v0.10.1rc1
This commit is contained in:
27
vllm_ascend/__init__.py
Normal file
27
vllm_ascend/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
|
||||
def register():
|
||||
"""Register the NPU platform."""
|
||||
|
||||
return "vllm_ascend.platform.NPUPlatform"
|
||||
|
||||
|
||||
def register_model():
|
||||
from .models import register_model
|
||||
register_model()
|
||||
215
vllm_ascend/ascend_config.py
Normal file
215
vllm_ascend/ascend_config.py
Normal file
@@ -0,0 +1,215 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
from vllm.logger import logger
|
||||
|
||||
TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"]
|
||||
|
||||
|
||||
def _check_torchair_supported(model_type: str):
|
||||
for supported_model in TORCHAIR_MODEL_LIST:
|
||||
if supported_model in model_type.lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class AscendConfig:
|
||||
"""
|
||||
Configuration Object for additional_config from vllm.configs.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config):
|
||||
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
|
||||
|
||||
torchair_graph_config = additional_config.get("torchair_graph_config",
|
||||
{})
|
||||
self.torchair_graph_config = TorchairGraphConfig(torchair_graph_config)
|
||||
|
||||
ascend_scheduler_config = additional_config.get(
|
||||
"ascend_scheduler_config", {})
|
||||
self.ascend_scheduler_config = AscendSchedulerConfig(
|
||||
ascend_scheduler_config)
|
||||
|
||||
self.expert_map_path = additional_config.get("expert_map_path", None)
|
||||
self.chunked_prefill_for_mla = additional_config.get(
|
||||
"chunked_prefill_for_mla", False)
|
||||
self.enable_shared_expert_dp = additional_config.get(
|
||||
"enable_shared_expert_dp", False
|
||||
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
|
||||
self.enable_prefetch = additional_config.get("enable_prefetch", False)
|
||||
self.lmhead_tensor_parallel_size = additional_config.get(
|
||||
"lmhead_tensor_parallel_size", None)
|
||||
if self.lmhead_tensor_parallel_size is not None:
|
||||
logger.info(
|
||||
f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario"
|
||||
)
|
||||
if vllm_config.parallel_config.tensor_parallel_size != 1:
|
||||
raise AssertionError(
|
||||
"lmhead_tensor_parallel_size is only supported in the pure DP scenario"
|
||||
)
|
||||
|
||||
|
||||
class TorchairGraphConfig:
|
||||
"""
|
||||
Configuration Object for torchair_graph_config from additional_config
|
||||
"""
|
||||
|
||||
def __init__(self, torchair_graph_config):
|
||||
self.enabled = torchair_graph_config.get("enabled", False)
|
||||
self.mode = torchair_graph_config.get("mode", '')
|
||||
self.use_cached_graph = torchair_graph_config.get(
|
||||
"use_cached_graph", False)
|
||||
self.use_cached_kv_cache_bytes = torchair_graph_config.get(
|
||||
"use_cached_kv_cache_bytes", False)
|
||||
self.graph_batch_sizes = torchair_graph_config.get(
|
||||
"graph_batch_sizes", [])
|
||||
self.graph_batch_sizes_init = torchair_graph_config.get(
|
||||
"graph_batch_sizes_init", False)
|
||||
self.enable_multistream_mla = torchair_graph_config.get(
|
||||
"enable_multistream_mla", False)
|
||||
self.enable_multistream_moe = torchair_graph_config.get(
|
||||
"enable_multistream_moe", False)
|
||||
self.enable_view_optimize = torchair_graph_config.get(
|
||||
"enable_view_optimize", True)
|
||||
self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)
|
||||
|
||||
if not isinstance(self.graph_batch_sizes, list):
|
||||
raise TypeError("graph_batch_sizes must be list[int]")
|
||||
if self.graph_batch_sizes_init and len(self.graph_batch_sizes) > 0:
|
||||
raise ValueError(
|
||||
"graph_batch_sizes_init is only valid when graph_batch_sizes is empty"
|
||||
)
|
||||
if not self.enabled:
|
||||
if self.mode:
|
||||
raise RuntimeError(
|
||||
"mode is valid only when Torchair graph mode is enabled")
|
||||
if self.use_cached_graph:
|
||||
raise RuntimeError(
|
||||
"use_cached_graph is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.use_cached_kv_cache_bytes:
|
||||
raise RuntimeError(
|
||||
"use_cached_kv_cache_bytes is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.graph_batch_sizes:
|
||||
raise RuntimeError(
|
||||
"graph_batch_sizes is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.graph_batch_sizes_init:
|
||||
raise RuntimeError(
|
||||
"graph_batch_sizes_init is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.enable_multistream_mla:
|
||||
raise RuntimeError(
|
||||
"enable_multistream_mla is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.enable_multistream_moe:
|
||||
raise RuntimeError(
|
||||
"enable_multistream_moe is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.enable_kv_nz:
|
||||
raise RuntimeError(
|
||||
"enable_kv_nz is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.use_cached_kv_cache_bytes and not self.use_cached_graph:
|
||||
raise RuntimeError(
|
||||
"use_cached_kv_cache_bytes is valid only when Torchair graph mode and use_cached_graph are enabled"
|
||||
)
|
||||
|
||||
|
||||
class AscendSchedulerConfig:
|
||||
"""
|
||||
Configuration Object for ascend_scheduler_config from additional_config
|
||||
"""
|
||||
|
||||
def __init__(self, ascend_scheduler_config: dict):
|
||||
self.enabled = ascend_scheduler_config.get("enabled", False)
|
||||
# Ascend scheduler is based on vllm v0 scheduler, so we should support
|
||||
# all vllm v0 scheduler configs as well.
|
||||
for k, v in ascend_scheduler_config.items():
|
||||
if not hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
_ASCEND_CONFIG: Optional[AscendConfig] = None
|
||||
|
||||
|
||||
def init_ascend_config(vllm_config):
|
||||
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
|
||||
refresh = additional_config.get("refresh",
|
||||
False) if additional_config else False
|
||||
global _ASCEND_CONFIG
|
||||
if _ASCEND_CONFIG is not None and not refresh:
|
||||
return _ASCEND_CONFIG
|
||||
_ASCEND_CONFIG = AscendConfig(vllm_config)
|
||||
return _ASCEND_CONFIG
|
||||
|
||||
|
||||
def clear_ascend_config():
|
||||
global _ASCEND_CONFIG
|
||||
_ASCEND_CONFIG = None
|
||||
|
||||
|
||||
def get_ascend_config():
|
||||
global _ASCEND_CONFIG
|
||||
if _ASCEND_CONFIG is None:
|
||||
raise RuntimeError(
|
||||
"Ascend config is not initialized. Please call init_ascend_config first."
|
||||
)
|
||||
return _ASCEND_CONFIG
|
||||
|
||||
|
||||
def check_ascend_config(vllm_config, enforce_eager):
|
||||
ascend_config = get_ascend_config()
|
||||
|
||||
# for eager mode
|
||||
if enforce_eager:
|
||||
# torchair_graph cannot be enabled with eager mode.
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
raise RuntimeError(
|
||||
"Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
|
||||
)
|
||||
# for graph mode
|
||||
else:
|
||||
# torchair_graph case
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
# torchair_graph is supported for deepseek/pangu/qwen model only.
|
||||
if vllm_config.model_config:
|
||||
model_type = vllm_config.model_config.hf_config.model_type
|
||||
if not _check_torchair_supported(model_type):
|
||||
raise NotImplementedError(
|
||||
"Torchair graph mode only works with following model types:"
|
||||
f"{TORCHAIR_MODEL_LIST}.")
|
||||
if ascend_config.enable_shared_expert_dp:
|
||||
logger.warning(
|
||||
"enable_shared_expert_dp is not supported for torchair graph mode currently, "
|
||||
"it has been disabled automatically.")
|
||||
# aclgraph case
|
||||
else:
|
||||
# aclgraph doesn't work with deepseek model and only qwen model is well tested.
|
||||
if vllm_config.model_config:
|
||||
model_type = vllm_config.model_config.hf_config.model_type
|
||||
if "deepseek" in model_type:
|
||||
raise NotImplementedError(
|
||||
"ACL Graph does not support deepseek. Please "
|
||||
"try torchair graph mode to serve deepseek models on vllm-ascend."
|
||||
" Or set `enforce_eager=True` to use eager mode.")
|
||||
if "qwen" not in model_type:
|
||||
logger.warning(
|
||||
"ACL Graph is currently experimental. Please "
|
||||
"raise an issue on https://github.com/vllm-project/vllm-ascend/issues"
|
||||
" if you encourage any Error")
|
||||
138
vllm_ascend/ascend_forward_context.py
Normal file
138
vllm_ascend/ascend_forward_context.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import math
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed import (get_dp_group, get_ep_group,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.forward_context import (BatchDescriptor, get_forward_context,
|
||||
set_forward_context)
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
|
||||
|
||||
class FusedMoEState(Enum):
|
||||
AllGather = 0
|
||||
All2All = 1
|
||||
MC2 = 2
|
||||
AllGatherEP = 3
|
||||
NaiveMulticast = 4
|
||||
All2AllSeq = 5
|
||||
|
||||
|
||||
# TODO(zzzzwwjj): add soc_version to choose branch
|
||||
def _get_fused_moe_state(ep_size: int, with_prefill: bool,
|
||||
is_deepseek_v3_r1: bool):
|
||||
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
||||
# only supports deepseek v3/r1
|
||||
if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
||||
and is_deepseek_v3_r1):
|
||||
return FusedMoEState.AllGatherEP
|
||||
elif ep_size == 1:
|
||||
if with_prefill:
|
||||
return FusedMoEState.NaiveMulticast
|
||||
else:
|
||||
return FusedMoEState.AllGather
|
||||
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
|
||||
elif ep_size < 16 or with_prefill:
|
||||
return FusedMoEState.All2All
|
||||
else:
|
||||
return FusedMoEState.MC2
|
||||
|
||||
|
||||
def get_dispatcher_name(ep_size: int, with_prefill: bool) -> str:
|
||||
if ep_size == 1:
|
||||
return "TokenDispatcherWithAllGather"
|
||||
|
||||
if ep_size < 16:
|
||||
return "TokenDispatcherWithAll2AllV"
|
||||
|
||||
if with_prefill:
|
||||
return "TokenDispatcherWithAll2AllV"
|
||||
return "TokenDispatcherWithMC2"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_ascend_forward_context(
|
||||
attn_metadata: Any,
|
||||
vllm_config: VllmConfig,
|
||||
virtual_engine: int = 0,
|
||||
num_tokens: Optional[int] = None,
|
||||
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
||||
with_prefill: bool = True,
|
||||
in_profile_run: bool = False,
|
||||
reserved_mc2_mask: Optional[torch.Tensor] = None,
|
||||
moe_comm_method: str = "",
|
||||
num_actual_tokens: Optional[int] = None,
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor: Optional[BatchDescriptor] = None):
|
||||
"""A context manager that stores the current forward context,
|
||||
can be attention metadata, etc.
|
||||
We add some additional param into forward_context.
|
||||
"""
|
||||
with set_forward_context(
|
||||
attn_metadata,
|
||||
vllm_config,
|
||||
virtual_engine=virtual_engine,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=aclgraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
):
|
||||
forward_context = get_forward_context()
|
||||
forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
|
||||
forward_context.with_prefill = with_prefill
|
||||
ep_size = (get_ep_group().world_size if
|
||||
vllm_config.parallel_config.enable_expert_parallel else 1)
|
||||
|
||||
is_deepseek_v3_r1 = hasattr(
|
||||
vllm_config.model_config.hf_config, 'n_routed_experts'
|
||||
) and vllm_config.model_config.hf_config.n_routed_experts == 256
|
||||
fused_moe_state = _get_fused_moe_state(ep_size, with_prefill,
|
||||
is_deepseek_v3_r1)
|
||||
forward_context.fused_moe_state = fused_moe_state
|
||||
forward_context.in_profile_run = in_profile_run
|
||||
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
||||
get_token_dispatcher
|
||||
dispatcher_name = get_dispatcher_name(ep_size, with_prefill)
|
||||
dispatcher = get_token_dispatcher(dispatcher_name)
|
||||
forward_context.token_dispatcher = dispatcher
|
||||
|
||||
# NOTE: This cannot be set using set_forward_context
|
||||
# due to multiple warmups before actual capturing
|
||||
forward_context.capturing = False
|
||||
|
||||
if num_tokens is None and attn_metadata is not None:
|
||||
num_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
dp_world_size = get_dp_group().world_size
|
||||
if dp_world_size > 1 and forward_context.dp_metadata is not None:
|
||||
max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
|
||||
)
|
||||
else:
|
||||
max_tokens_across_dp = num_tokens
|
||||
|
||||
forward_context.max_tokens_across_dp = max_tokens_across_dp
|
||||
|
||||
if num_tokens is not None:
|
||||
if num_actual_tokens is None:
|
||||
num_actual_tokens = num_tokens
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
# NOTE: token num which need to pad to when mc2
|
||||
forward_context.padded_num_tokens = math.ceil(
|
||||
max_tokens_across_dp / tp_world_size) * tp_world_size
|
||||
|
||||
if reserved_mc2_mask is not None:
|
||||
mc2_mask = reserved_mc2_mask[:forward_context.
|
||||
padded_num_tokens]
|
||||
mc2_mask[:num_actual_tokens] = True
|
||||
mc2_mask[num_actual_tokens:] = False
|
||||
forward_context.mc2_mask = mc2_mask
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
pass
|
||||
0
vllm_ascend/attention/__init__.py
Normal file
0
vllm_ascend/attention/__init__.py
Normal file
93
vllm_ascend/attention/attention_mask.py
Normal file
93
vllm_ascend/attention/attention_mask.py
Normal file
@@ -0,0 +1,93 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
|
||||
def _generate_attn_mask(max_seq_len, dtype):
|
||||
# Construct lower triangle matrix.
|
||||
mask_flag = torch.tril(
|
||||
torch.ones((max_seq_len, max_seq_len),
|
||||
dtype=torch.bool)).view(max_seq_len, max_seq_len)
|
||||
# Create upper triangle matrix used to mark mask positions.
|
||||
mask_flag = ~mask_flag
|
||||
# Currently for fp16 dtype, the mask value should be set to -inf.
|
||||
# TODO: Eliminate this part in the future.
|
||||
if dtype == torch.float16:
|
||||
mask_value = torch.finfo(torch.float32).min
|
||||
else:
|
||||
mask_value = 1
|
||||
attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
|
||||
mask_flag, mask_value).to(dtype)
|
||||
return attn_mask
|
||||
|
||||
|
||||
class AttentionMaskBuilder:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
attn_mask = _generate_attn_mask(max_seq_len, dtype)
|
||||
|
||||
self._seq_len_cached = attn_mask.shape[0]
|
||||
self.attn_mask_cache = attn_mask
|
||||
|
||||
@staticmethod
|
||||
def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
|
||||
if dtype == torch.float16:
|
||||
mask_scale_factor = 1
|
||||
elif dtype == torch.bfloat16:
|
||||
mask_scale_factor = -10000
|
||||
else:
|
||||
raise ValueError(
|
||||
"The current operation now only supports data types: torch.float16 and "
|
||||
"torch.bfloat16. Please ensure the input is of one of these types."
|
||||
)
|
||||
return mask_scale_factor
|
||||
|
||||
def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
|
||||
device: torch.device):
|
||||
self._update_attn_cache(max_seq_len, dtype)
|
||||
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
|
||||
).to(device)
|
||||
|
||||
def get_splitfuse_attn_mask(
|
||||
self,
|
||||
seq_lens: torch.Tensor,
|
||||
position: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
if dtype not in [torch.float16, torch.bfloat16]:
|
||||
raise ValueError(
|
||||
"splitfuse_attn_mask now only supports bf16 and fp16")
|
||||
max_seq_len = max(seq_lens, default=0)
|
||||
self._update_attn_cache(max_seq_len, dtype)
|
||||
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
|
||||
# is not the same. Fix this in the future when kernel is ready.
|
||||
mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
|
||||
attn_mask = torch.index_select(self.attn_mask_cache,
|
||||
dim=0,
|
||||
index=position)[:, :max_seq_len]
|
||||
attn_mask *= mask_scale_factor
|
||||
return attn_mask.contiguous().to(device, non_blocking=True)
|
||||
|
||||
def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
|
||||
if seqlen > self._seq_len_cached:
|
||||
self._seq_len_cached = seqlen
|
||||
self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
|
||||
if self.attn_mask_cache.dtype != dtype:
|
||||
self.attn_mask_cache = self.attn_mask_cache.to(dtype)
|
||||
604
vllm_ascend/attention/attention_v1.py
Normal file
604
vllm_ascend/attention/attention_v1.py
Normal file
@@ -0,0 +1,604 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer, AttentionType)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.utils import cdiv, direct_register_custom_op
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||
nd_to_nz_2d, nd_to_nz_spec)
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
|
||||
|
||||
class AscendAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ASCEND"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
|
||||
return AscendAttentionBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AscendMetadata"]:
|
||||
return AscendMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
|
||||
return AscendAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
if is_310p():
|
||||
return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
|
||||
16)
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_bsh_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: List[torch.Tensor],
|
||||
dst_kv_cache: List[torch.Tensor],
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
|
||||
dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
|
||||
src_indices = src_to_dst[:, 0]
|
||||
dst_indices = src_to_dst[:, 1]
|
||||
|
||||
dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
|
||||
dst_key_cache.device)
|
||||
dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
|
||||
dst_key_cache.device)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
src_indices = src_to_dists[:, 0]
|
||||
dst_indices = src_to_dists[:, 1]
|
||||
|
||||
for kv_cache in kv_caches:
|
||||
key_caches = kv_cache[0]
|
||||
value_caches = kv_cache[1]
|
||||
key_caches[dst_indices] = key_caches[src_indices]
|
||||
value_caches[dst_indices] = value_caches[src_indices]
|
||||
|
||||
|
||||
class AscendAttentionState(Enum):
|
||||
PrefillNoCache = 0
|
||||
PrefillCacheHit = 1
|
||||
DecodeOnly = 2
|
||||
ChunkedPrefill = 3
|
||||
SpecDecoding = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMetadata:
|
||||
|
||||
# **************************** Basic Properties ************************** #
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
# Current state of this attention run.
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
|
||||
# Number of tokens excluding padding.
|
||||
num_actual_tokens: int = 0
|
||||
|
||||
# The sequence length per sequence. Sequence length means the computed
|
||||
# tokens + new tokens (is None if it is a decoding).
|
||||
# (batch_size,)
|
||||
seq_lens: torch.Tensor = None
|
||||
|
||||
query_start_loc: torch.Tensor = None
|
||||
query_lens: torch.Tensor = None
|
||||
# Maximum query length in the batch (None for decoding).
|
||||
max_query_len: Optional[int] = None
|
||||
|
||||
# ********************** KV Cache Related Properties ********************* #
|
||||
# Block addresses per sequence (Seq id -> list of physical block).
|
||||
# (batch_size, max_blocks_per_seq)
|
||||
block_tables: torch.Tensor = None
|
||||
|
||||
# The indices of the token slots that input tokens will be stored into.
|
||||
# E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the
|
||||
# three tokens are stored in the 3rd slot in block 2, 2nd slot in block 0,
|
||||
# and 1st slot in block 1, respectively.
|
||||
# (num_tokens,)
|
||||
slot_mapping: torch.Tensor = None
|
||||
|
||||
# *************************** Other Properties *************************** #
|
||||
enable_dbo_across_dp: bool = False
|
||||
is_only_prefill: bool = False
|
||||
|
||||
|
||||
class AscendAttentionMetadataBuilder:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.device = device
|
||||
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
|
||||
vllm_config.cache_config.block_size)
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||
num_reqs
|
||||
+ 1]
|
||||
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
block_table[:num_reqs, :self.max_num_blocks_per_req] = (
|
||||
block_table[:num_reqs])
|
||||
|
||||
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
|
||||
num_actual_tokens].to(
|
||||
self.device,
|
||||
non_blocking=
|
||||
True)
|
||||
attn_mask = common_attn_metadata.attn_mask
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||
num_reqs
|
||||
+ 1]
|
||||
query_start_loc = query_start_loc_cpu.to(self.device,
|
||||
non_blocking=True)
|
||||
|
||||
if is_310p():
|
||||
if attn_state == AscendAttentionState.PrefillNoCache:
|
||||
mask_nz = nd_to_nz_2d(attn_mask)
|
||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
elif attn_state == AscendAttentionState.ChunkedPrefill:
|
||||
mask_nz = nd_to_nz_spec(attn_mask)
|
||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
attn_metadata = AscendMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
block_tables=block_table,
|
||||
query_start_loc=query_start_loc,
|
||||
query_lens=query_lens,
|
||||
seq_lens=seq_lens,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state,
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
|
||||
is_only_prefill=common_attn_metadata.is_only_prefill)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class AscendAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
self.hidden_size = self.num_heads * self.head_size
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.sliding_window = sliding_window
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes,
|
||||
dtype=torch.float32,
|
||||
device="npu")
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.attn_type = attn_type
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.key_cache = None
|
||||
self.value_cache = None
|
||||
|
||||
def _repeat_kv(self, hidden_states: torch.Tensor,
|
||||
n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, None, :, :].expand(
|
||||
num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(num_key_value_heads * n_rep, slen,
|
||||
head_dim)
|
||||
|
||||
def _forward_prefill_no_cache(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
num_tokens=0,
|
||||
) -> torch.Tensor:
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
|
||||
mask = attn_metadata.attn_mask
|
||||
|
||||
if is_310p():
|
||||
# align q k v output tensors
|
||||
query = aligned_16(query)
|
||||
key = aligned_16(key)
|
||||
value = aligned_16(value)
|
||||
output = aligned_16(output)
|
||||
# do reformat in case of broadcasted tensors
|
||||
mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
|
||||
mask = torch_npu.npu_format_cast(mask.contiguous(),
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
if self.sliding_window is not None and \
|
||||
attn_metadata.attn_mask.shape[0] > self.sliding_window:
|
||||
|
||||
key = self._repeat_kv(key, self.num_heads // self.num_kv_heads)
|
||||
value = self._repeat_kv(value, self.num_heads // self.num_kv_heads)
|
||||
|
||||
output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout="TND",
|
||||
pre_tokens=self.sliding_window,
|
||||
scale=self.scale,
|
||||
actual_seq_lengths=attn_metadata.seq_lens,
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens)
|
||||
output = output.view(num_tokens, self.num_heads, self.head_size)
|
||||
else:
|
||||
torch_npu._npu_flash_attention(query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seq_len=attn_metadata.seq_lens,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output)
|
||||
assert output is not None
|
||||
return output[:num_tokens, :, :]
|
||||
|
||||
def _forward_prefill_cache_hit(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
|
||||
compress_mask = attn_metadata.attn_mask
|
||||
batch_size = attn_metadata.query_lens.shape[0]
|
||||
block_table = attn_metadata.block_tables[:batch_size, :]
|
||||
|
||||
torch_npu._npu_flash_attention_qlens(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
block_table=block_table,
|
||||
mask=compress_mask,
|
||||
seq_len=attn_metadata.query_lens,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
return output
|
||||
|
||||
def _forward_decode_only(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if is_310p():
|
||||
# seq_lens_tensor needs to be transferred to the device for 310P.
|
||||
attn_metadata.seq_lens = \
|
||||
attn_metadata.seq_lens.to(device=query.device)
|
||||
if self.sliding_window is not None:
|
||||
batch_size = attn_metadata.seq_lens.shape[0]
|
||||
block_size = 128
|
||||
query = query.view(batch_size, 1, self.num_heads * self.head_size)
|
||||
key = self.key_cache
|
||||
value = self.value_cache
|
||||
if self.key_cache is not None and self.value_cache is not None:
|
||||
block_size = self.key_cache.shape[1]
|
||||
key = self.key_cache.flatten(2, 3).contiguous()
|
||||
value = self.value_cache.flatten(2, 3).contiguous()
|
||||
|
||||
output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout="BSH",
|
||||
block_size=block_size,
|
||||
pre_tokens=self.sliding_window,
|
||||
scale=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens)
|
||||
|
||||
output = output.view(batch_size, self.num_heads, self.head_size)
|
||||
else:
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
return output
|
||||
|
||||
def _forward_v1_style(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# Use chunked prefill for head size 192 scenario, like deepseek
|
||||
# paged_attention_splitfuse maybe crash at such scenario.
|
||||
# TODO: vanilla path will be removed after the kernel support
|
||||
# head_size 192 scenario.
|
||||
if self.head_size == 192:
|
||||
cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
|
||||
cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
|
||||
cu_seqlen_q = torch.tensor(cu_seqlen_q, device=query.device)
|
||||
cu_seqlen_k = torch.tensor(cu_seqlen_k, device=query.device)
|
||||
cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
|
||||
cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
|
||||
max_seqlen_q = torch.max(attn_metadata.query_lens)
|
||||
max_seqlen_k = torch.max(attn_metadata.seq_lens)
|
||||
vanilla_chunked_prefill(output, query, self.key_cache,
|
||||
self.value_cache,
|
||||
attn_metadata.block_tables, cu_seqlen_q,
|
||||
cu_seqlen_k, max_seqlen_q, max_seqlen_k,
|
||||
self.scale, None, True)
|
||||
return output
|
||||
|
||||
# Use paged attention.
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
|
||||
if is_310p():
|
||||
# Do reformat in case of broadcasted tensors.
|
||||
attn_metadata.attn_mask = \
|
||||
torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(),
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
attn_metadata.seq_lens = \
|
||||
attn_metadata.seq_lens.to(device=query.device)
|
||||
|
||||
torch_npu._npu_paged_attention_splitfuse(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
mask=attn_metadata.attn_mask,
|
||||
block_table=attn_metadata.block_tables,
|
||||
seq_len=attn_metadata.query_lens,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor],
|
||||
attn_metadata: AscendMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
trace_flag: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Ascend attention.
|
||||
Args:
|
||||
query: shape = [batch_size, seq_len, num_heads * head_size]
|
||||
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
kv_cache: shape = [key_cache, value_cache]
|
||||
key_cache = [num_blocks, block_size,
|
||||
num_kv_heads, head_size]
|
||||
value_cache = [num_blocks, block_size,
|
||||
num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [batch_size * seq_len, num_heads, head_size]
|
||||
"""
|
||||
num_tokens = query.shape[0]
|
||||
use_kv_cache_int8 = len(
|
||||
kv_cache) > 0 and kv_cache[0].dtype == torch.int8
|
||||
if output is None:
|
||||
output = torch.empty(num_tokens,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
ori_output = output
|
||||
if trace_flag:
|
||||
torch.ops.vllm.unified_ascend_attention_with_output(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
output=output,
|
||||
layer_name=layer.layer_name)
|
||||
|
||||
elif hasattr(layer, 'quant_method') and use_kv_cache_int8:
|
||||
output = layer.quant_method.apply(layer, query, key, value,
|
||||
kv_cache, attn_metadata,
|
||||
self.attn_type, self.scale,
|
||||
output)
|
||||
|
||||
else:
|
||||
if attn_metadata is None:
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
attn_type = self.attn_type
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl")
|
||||
# View q k v to BSH.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
# TODO: Remove this contiguous in the future.
|
||||
value = value.contiguous()
|
||||
|
||||
if len(kv_cache) > 1:
|
||||
if self.key_cache is None:
|
||||
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
||||
slots = attn_metadata.slot_mapping
|
||||
torch_npu._npu_reshape_and_cache(
|
||||
key=key[:num_actual_tokens],
|
||||
value=value[:num_actual_tokens],
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
slot_indices=slots)
|
||||
|
||||
# V0-Style scheduler situation.
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
output = self._forward_prefill_no_cache(
|
||||
query, key, value, attn_metadata, output, num_tokens)
|
||||
elif attn_metadata.attn_state == \
|
||||
AscendAttentionState.PrefillCacheHit:
|
||||
output = self._forward_prefill_cache_hit(
|
||||
query, attn_metadata, output)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
output = self._forward_decode_only(query, attn_metadata,
|
||||
output)
|
||||
# Normal V1 situation.
|
||||
else:
|
||||
output = self._forward_v1_style(query, attn_metadata, output)
|
||||
|
||||
# to make in-place change to the output tensor
|
||||
if hasattr(layer, 'quant_method') and use_kv_cache_int8:
|
||||
output = output.view(num_tokens, self.num_heads, self.head_size)
|
||||
ori_output[:, :, :] = output[:num_tokens, :, :]
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
|
||||
|
||||
def unified_ascend_attention_with_output(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self.impl.forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output,
|
||||
trace_flag=False)
|
||||
return
|
||||
|
||||
|
||||
def unified_attention_with_output_fake(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="unified_ascend_attention_with_output",
|
||||
op_func=unified_ascend_attention_with_output,
|
||||
mutates_args=["output"],
|
||||
fake_impl=unified_attention_with_output_fake,
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
1050
vllm_ascend/attention/mla_v1.py
Normal file
1050
vllm_ascend/attention/mla_v1.py
Normal file
File diff suppressed because it is too large
Load Diff
95
vllm_ascend/attention/utils.py
Normal file
95
vllm_ascend/attention/utils.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendCommonAttentionMetadata:
|
||||
"""
|
||||
Per-batch attention metadata, shared across layers and backends.
|
||||
AttentionMetadataBuilder instances use it to construct per-layer metadata.
|
||||
|
||||
For many of the tensors we keep both GPU and CPU versions.
|
||||
"""
|
||||
|
||||
query_start_loc: torch.Tensor
|
||||
query_start_loc_cpu: torch.Tensor
|
||||
"""(batch_size + 1,), the start location of each request in query Tensor"""
|
||||
|
||||
seq_lens_cpu: torch.Tensor
|
||||
"""(batch_size,), the length of each request including both computed tokens
|
||||
and newly scheduled tokens"""
|
||||
|
||||
num_reqs: int
|
||||
"""Number of requests"""
|
||||
num_actual_tokens: int
|
||||
"""Total number of tokens in batch"""
|
||||
|
||||
max_query_len: int
|
||||
"""Max token number of request in batch"""
|
||||
|
||||
decode_token_per_req: int
|
||||
"""decode token number per request"""
|
||||
|
||||
block_table_tensor: torch.Tensor
|
||||
|
||||
slot_mapping_cpu: torch.Tensor
|
||||
|
||||
actual_seq_lengths_q: list[int]
|
||||
|
||||
positions: torch.Tensor = None
|
||||
|
||||
attn_mask: torch.Tensor = None
|
||||
|
||||
spec_attn_mask: torch.Tensor = None
|
||||
|
||||
attn_state: Any = None
|
||||
|
||||
enable_dbo_across_dp: bool = False
|
||||
|
||||
is_only_prefill: bool = False
|
||||
|
||||
graph_pad_size: int = -1
|
||||
|
||||
|
||||
def split_decodes_and_prefills(
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
decode_threshold: int = 1,
|
||||
) -> tuple[int, int, int, int]:
|
||||
"""
|
||||
Assuming a reordered batch, finds the boundary between prefill and decode
|
||||
requests.
|
||||
|
||||
Args:
|
||||
common_attn_metadata: AscendCommonAttentionMetadata object containing the
|
||||
batch metadata.
|
||||
decode_threshold: The maximum query length to be considered a decode.
|
||||
|
||||
Returns:
|
||||
num_decodes: The number of decode requests.
|
||||
num_prefills: The number of prefill requests.
|
||||
num_decode_tokens: The number of tokens in the decode requests.
|
||||
num_prefill_tokens: The number of tokens in the prefill requests.
|
||||
"""
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
query_start_loc = common_attn_metadata.query_start_loc_cpu
|
||||
|
||||
if max_query_len <= decode_threshold:
|
||||
return num_reqs, 0, num_tokens, 0
|
||||
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
is_prefill = query_lens > decode_threshold
|
||||
if not torch.any(is_prefill):
|
||||
return num_reqs, 0, num_tokens, 0
|
||||
|
||||
first_prefill = is_prefill.int().argmax(dim=-1).item()
|
||||
assert torch.all(query_lens[first_prefill:] >= decode_threshold)
|
||||
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
|
||||
num_decodes = first_prefill
|
||||
num_prefills = num_reqs - num_decodes
|
||||
num_decode_tokens = query_start_loc[first_prefill].item()
|
||||
num_prefill_tokens = num_tokens - num_decode_tokens
|
||||
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
|
||||
0
vllm_ascend/compilation/__init__.py
Normal file
0
vllm_ascend/compilation/__init__.py
Normal file
185
vllm_ascend/compilation/acl_graph.py
Normal file
185
vllm_ascend/compilation/acl_graph.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.cuda_graph import CUDAGraphOptions
|
||||
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import weak_ref_tensors
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ACLGraphEntry:
|
||||
batch_descriptor: BatchDescriptor
|
||||
aclgraph: Optional[torch.npu.NPUGraph] = None
|
||||
output: Optional[Any] = None
|
||||
|
||||
# for aclgraph debugging, track the input addresses
|
||||
# during capture, and check if they are the same during replay
|
||||
input_addresses: Optional[list[int]] = None
|
||||
|
||||
|
||||
class ACLGraphWrapper:
|
||||
"""Wraps a runnable to add acl graph capturing and replaying ability. And
|
||||
provide attribute access to the underlying `runnable` via `__getattr__`.
|
||||
|
||||
The workflow of this wrapper in the aclgraph dispatching is as follows:
|
||||
1. At initialization, a runtime mode is assigned to the wrapper (FULL or
|
||||
PIECEWISE).
|
||||
2. At runtime, the wrapper receives a runtime_mode and a
|
||||
batch_descriptor(key) from the forward context and blindly trust them
|
||||
for aclgraph dispatching.
|
||||
3. If runtime_mode is NONE or runtime_mode does not match the mode of the
|
||||
wrapper, just call the runnable directly.
|
||||
4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
|
||||
the wrapper will perform aclgraph capture(if key does not exist, create
|
||||
a new entry and cache it) or replay (if key exists in the cache).
|
||||
|
||||
Note: ACLGraphWrapper does not store persistent buffers or copy any
|
||||
runtime inputs into that buffers for replay. We assume implementing them
|
||||
is done outside of the wrapper. That is because we do not make any
|
||||
assumption on the dynamic shape (batch size) of the runtime inputs, as a
|
||||
trade-off for staying orthogonal to compilation logic. Nevertheless,
|
||||
tracing and checking the input addresses to be consistent during replay is
|
||||
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
runnable: Callable,
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
graph_pool: Any = None,
|
||||
cudagraph_options: Optional[CUDAGraphOptions] = None):
|
||||
self.runnable = runnable
|
||||
self.vllm_config = vllm_config
|
||||
self.graph_pool = graph_pool
|
||||
self.runtime_mode = runtime_mode
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
|
||||
self.first_run_finished = False
|
||||
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
|
||||
|
||||
# assert runtime_mode is not NONE(no aclgraph), otherwise, we don't
|
||||
# need to initialize a ACLGraphWrapper.
|
||||
assert self.runtime_mode != CUDAGraphMode.NONE
|
||||
if self.graph_pool is None:
|
||||
self.graph_pool = current_platform.get_global_graph_pool()
|
||||
|
||||
if cudagraph_options is None:
|
||||
cudagraph_options = CUDAGraphOptions()
|
||||
self.aclgraph_options = cudagraph_options
|
||||
# the entries for different batch descriptors that we need to capture
|
||||
# aclgraphs for.
|
||||
self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry]\
|
||||
= {}
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
# allow accessing the attributes of the runnable.
|
||||
if hasattr(self.runnable, key):
|
||||
return getattr(self.runnable, key)
|
||||
raise AttributeError(f"Attribute {key} not exists in the runnable of "
|
||||
f"aclgraph wrapper: {self.runnable}")
|
||||
|
||||
def unwrap(self) -> Callable:
|
||||
# in case we need to access the original runnable.
|
||||
return self.runnable
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
forward_context = get_forward_context()
|
||||
batch_descriptor = forward_context.batch_descriptor
|
||||
aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||
|
||||
if aclgraph_runtime_mode == CUDAGraphMode.NONE or \
|
||||
aclgraph_runtime_mode != self.runtime_mode:
|
||||
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
|
||||
# running without aclgraphs.
|
||||
# We do not trigger capture/replay if the runtime mode is not
|
||||
# matches. This enables properly dispatching to the correct
|
||||
# CUDAGraphWrapper when nesting multiple instances with different
|
||||
# runtime modes.
|
||||
return self.runnable(*args, **kwargs)
|
||||
|
||||
if batch_descriptor not in self.concrete_aclgraph_entries:
|
||||
# create a new entry for this batch descriptor
|
||||
self.concrete_aclgraph_entries[batch_descriptor] = \
|
||||
ACLGraphEntry(batch_descriptor=batch_descriptor)
|
||||
|
||||
entry = self.concrete_aclgraph_entries[batch_descriptor]
|
||||
|
||||
if entry.aclgraph is None:
|
||||
if self.aclgraph_options.debug_log_enable:
|
||||
# Since we capture aclgraph for many different shapes and
|
||||
# capturing is fast, we don't need to log it for every
|
||||
# shape. E.g. we only log it for the first subgraph in
|
||||
# piecewise mode.
|
||||
logger.debug("Capturing a aclgraph on (%s,%s)",
|
||||
self.runtime_mode.name, entry.batch_descriptor)
|
||||
# validate that aclgraph capturing is legal at this point.
|
||||
validate_cudagraph_capturing_enabled()
|
||||
|
||||
input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
entry.input_addresses = input_addresses
|
||||
aclgraph = torch.npu.NPUGraph()
|
||||
|
||||
with ExitStack() as stack:
|
||||
if self.aclgraph_options.gc_disable:
|
||||
# during every model forward for piecewise aclgraph
|
||||
# mode, we will capture many pieces of aclgraphs
|
||||
# (roughly one per layer). running gc again and again
|
||||
# across layers will make the aclgraph capture very slow.
|
||||
# therefore, we only run gc for the first graph,
|
||||
# and disable gc for the rest of the graphs.
|
||||
stack.enter_context(patch("gc.collect", lambda: None))
|
||||
stack.enter_context(
|
||||
patch("torch.npu.empty_cache", lambda: None))
|
||||
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
with torch.npu.graph(aclgraph, pool=self.graph_pool):
|
||||
# `output` is managed by pytorch's aclgraph pool
|
||||
output = self.runnable(*args, **kwargs)
|
||||
if self.aclgraph_options.weak_ref_output:
|
||||
# by converting it to weak ref,
|
||||
# the original `output` will immediately be released
|
||||
# to save memory. It is only safe to do this for
|
||||
# the last graph in piecewise aclgraph mode, because
|
||||
# the output of the last graph will not be used by
|
||||
# any other acl graph.
|
||||
output = weak_ref_tensors(output)
|
||||
|
||||
# here we always use weak ref for the output
|
||||
# to save memory
|
||||
entry.output = weak_ref_tensors(output)
|
||||
entry.aclgraph = aclgraph
|
||||
|
||||
compilation_counter.num_cudagraph_captured += 1
|
||||
|
||||
# important: we need to return the output, rather than
|
||||
# the weak ref of the output, so that pytorch can correctly
|
||||
# manage the memory during acl graph capture
|
||||
return output
|
||||
|
||||
if self.is_debugging_mode:
|
||||
# check if the input addresses are the same
|
||||
new_input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
assert new_input_addresses == entry.input_addresses, (
|
||||
f"Input addresses for aclgraphs are different "
|
||||
f"during replay. Expected {entry.input_addresses}, "
|
||||
f"got {new_input_addresses}")
|
||||
|
||||
logger.info_once("Replaying aclgraph")
|
||||
entry.aclgraph.replay()
|
||||
return entry.output
|
||||
0
vllm_ascend/core/__init__.py
Normal file
0
vllm_ascend/core/__init__.py
Normal file
84
vllm_ascend/core/schedule_config.py
Normal file
84
vllm_ascend/core/schedule_config.py
Normal file
@@ -0,0 +1,84 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Type, Union
|
||||
|
||||
from vllm.config import SchedulerConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendSchedulerConfig(SchedulerConfig):
|
||||
enable_chunked_prefill: bool = False
|
||||
policy: str = "fcfs"
|
||||
num_scheduler_steps: int = 1
|
||||
scheduler_cls: Union[str, Type[object]] = (
|
||||
"vllm_ascend.core.scheduler.AscendScheduler")
|
||||
|
||||
@classmethod
|
||||
def initialize_from_config(
|
||||
cls,
|
||||
vllm_scheduler_config: SchedulerConfig,
|
||||
ascend_scheduler_config,
|
||||
):
|
||||
scheduler_config = {
|
||||
field.name: getattr(vllm_scheduler_config, field.name)
|
||||
for field in fields(vllm_scheduler_config) if field.init
|
||||
}
|
||||
# Override default values into original SchedulerConfig
|
||||
scheduler_config["enable_chunked_prefill"] = False
|
||||
scheduler_config["policy"] = "fcfs"
|
||||
scheduler_config["num_scheduler_steps"] = 1
|
||||
scheduler_config["scheduler_cls"] = (
|
||||
"vllm_ascend.core.scheduler.AscendScheduler")
|
||||
# Override params in original SchedulerConfig with params in ascend_scheduler_config
|
||||
for k, _ in scheduler_config.items():
|
||||
if hasattr(ascend_scheduler_config, k):
|
||||
scheduler_config[k] = getattr(ascend_scheduler_config, k)
|
||||
return cls(**scheduler_config)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.max_num_encoder_input_tokens = self.max_num_batched_tokens
|
||||
self.encoder_cache_size = self.max_num_batched_tokens
|
||||
self.chunked_prefill_enabled = self.enable_chunked_prefill
|
||||
if (self.max_num_batched_tokens < self.max_model_len
|
||||
and not self.chunked_prefill_enabled):
|
||||
raise ValueError(
|
||||
"Ascend scheduler is enabled without chunked prefill feature. "
|
||||
f"Argument max_num_batched_tokens ({self.max_num_batched_tokens}) is "
|
||||
f"smaller than max_model_len ({self.max_model_len}). "
|
||||
"This effectively limits the maximum sequence length to "
|
||||
"max_num_batched_tokens and makes vLLM reject longer "
|
||||
"sequences. Please increase max_num_batched_tokens or "
|
||||
"decrease max_model_len.")
|
||||
if self.policy != "fcfs":
|
||||
raise NotImplementedError(
|
||||
f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
|
||||
)
|
||||
if self.is_multimodal_model:
|
||||
raise NotImplementedError(
|
||||
"currently AscendScheduler only supports LLM models.")
|
||||
if self.num_scheduler_steps > 1:
|
||||
raise NotImplementedError(
|
||||
"currently AscendScheduler doesn't support multi-step.")
|
||||
if self.send_delta_data:
|
||||
raise NotImplementedError(
|
||||
"currently AscendScheduler doesn't support send_delta_data.")
|
||||
if self.delay_factor > 0:
|
||||
raise NotImplementedError(
|
||||
"currently AscendScheduler doesn't support scheduler_delay_factor."
|
||||
)
|
||||
538
vllm_ascend/core/scheduler.py
Normal file
538
vllm_ascend/core/scheduler.py
Normal file
@@ -0,0 +1,538 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Iterable, Union
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import KVEventBatch
|
||||
from vllm.logger import logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
else:
|
||||
KVCacheBlocks = None
|
||||
|
||||
|
||||
class AscendScheduler(Scheduler):
|
||||
"""This Scheduler extends vllm's original v1 scheduler
|
||||
with prefill-first scheduling strategy."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
structured_output_manager: StructuredOutputManager,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
include_finished_set: bool = False,
|
||||
log_stats: bool = False,
|
||||
) -> None:
|
||||
super().__init__(vllm_config, kv_cache_config,
|
||||
structured_output_manager, mm_registry,
|
||||
include_finished_set, log_stats)
|
||||
self.scheduled_req_ids: set[str] = set()
|
||||
self.running: list[Request] = []
|
||||
|
||||
def schedule(self) -> SchedulerOutput:
|
||||
if self.scheduler_config.chunked_prefill_enabled:
|
||||
return super().schedule()
|
||||
scheduled_new_reqs: list[Request] = []
|
||||
scheduled_resumed_reqs: list[Request] = []
|
||||
scheduled_running_reqs: list[Request] = []
|
||||
preempted_reqs: list[Request] = []
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
req_to_new_block_ids: dict[str, list[list[int]]] = {}
|
||||
else:
|
||||
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
|
||||
num_scheduled_tokens: dict[str, int] = {}
|
||||
token_budget = self.max_num_scheduled_tokens
|
||||
# Spec decode-related.
|
||||
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
|
||||
|
||||
# For logging.
|
||||
scheduled_timestamp = time.monotonic()
|
||||
|
||||
# Record scheduled LoRA requests.
|
||||
scheduled_loras: set[int] = set()
|
||||
|
||||
# Use a temporary deque to collect requests that need to be skipped
|
||||
# and put back at the head of the waiting queue later
|
||||
skipped_waiting_requests: deque[Request] = deque()
|
||||
|
||||
# Schedule prefill requests first.
|
||||
while self.waiting and token_budget > 0:
|
||||
if len(self.running) == self.max_num_running_reqs:
|
||||
break
|
||||
|
||||
request = self.waiting[0]
|
||||
|
||||
def skip_cur_request():
|
||||
self.waiting.popleft()
|
||||
skipped_waiting_requests.appendleft(request)
|
||||
|
||||
# P/D: skip request if still waiting for remote kvs.
|
||||
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
|
||||
is_ready = self._update_waiting_for_remote_kv(request)
|
||||
if is_ready:
|
||||
request.status = RequestStatus.WAITING
|
||||
else:
|
||||
skip_cur_request()
|
||||
continue
|
||||
|
||||
# Check that adding the request still respects the max_loras
|
||||
# constraint.
|
||||
if (self.lora_config and request.lora_request and
|
||||
(len(scheduled_loras) == self.lora_config.max_loras
|
||||
and request.lora_request.lora_int_id not in scheduled_loras)):
|
||||
# Scheduling would exceed max_loras, skip.
|
||||
skip_cur_request()
|
||||
continue
|
||||
|
||||
num_external_computed_tokens = 0
|
||||
load_kv_async = False
|
||||
|
||||
# Get already-cached tokens.
|
||||
if request.num_computed_tokens == 0:
|
||||
new_computed_blocks, num_new_local_computed_tokens = \
|
||||
self.kv_cache_manager.get_computed_blocks(
|
||||
request)
|
||||
|
||||
# Get externally-cached tokens if using a KVConnector.
|
||||
if self.connector is not None:
|
||||
num_external_computed_tokens, load_kv_async = (
|
||||
self.connector.get_num_new_matched_tokens(
|
||||
request, num_new_local_computed_tokens))
|
||||
|
||||
# Total computed tokens (local + external).
|
||||
num_computed_tokens = (num_new_local_computed_tokens +
|
||||
num_external_computed_tokens)
|
||||
else:
|
||||
# P/D: skip checking prefix cache if loaded from remote kvs.
|
||||
new_computed_blocks = (
|
||||
self.kv_cache_manager.create_empty_block_list())
|
||||
num_new_local_computed_tokens = 0
|
||||
num_computed_tokens = request.num_computed_tokens
|
||||
|
||||
# P/D: loading remote KV, do not allocate for new work.
|
||||
if load_kv_async:
|
||||
assert num_external_computed_tokens > 0
|
||||
num_new_tokens = 0
|
||||
blocks = None
|
||||
# Number of tokens to be scheduled.
|
||||
else:
|
||||
prompt_limit = self._get_prompt_limit(request)
|
||||
# We use `request.num_tokens` instead of
|
||||
# `request.num_prompt_tokens` to consider the resumed
|
||||
# requests, which have output tokens.
|
||||
num_new_tokens = request.num_tokens - num_computed_tokens
|
||||
max_tokens_in_kvcache = (self.kv_cache_config.num_blocks *
|
||||
self.block_size)
|
||||
prompt_limit = min(prompt_limit, max_tokens_in_kvcache)
|
||||
|
||||
# Finish request that exceeds prompt_limit or kv cache size.
|
||||
if num_new_tokens > prompt_limit:
|
||||
logger.warning(
|
||||
"Input prompt (%d tokens) is too long"
|
||||
" and exceeds limit of %d",
|
||||
num_new_tokens,
|
||||
prompt_limit,
|
||||
)
|
||||
request.status = RequestStatus.FINISHED_IGNORED
|
||||
self.finished_req_ids.add( # type: ignore
|
||||
request.request_id) # type: ignore
|
||||
self.waiting.popleft()
|
||||
continue
|
||||
|
||||
if num_new_tokens > token_budget:
|
||||
# Scheduling would exceed token_budget, skip.
|
||||
skip_cur_request()
|
||||
continue
|
||||
assert num_new_tokens > 0
|
||||
blocks = new_computed_blocks.blocks[0]
|
||||
|
||||
watermark = getattr(self.scheduler_config, "watermark", 0.01)
|
||||
if not self._check_watermark_for_prefill(request, num_new_tokens,
|
||||
blocks, watermark):
|
||||
# Scheduling would exceed watermark, skip.
|
||||
skip_cur_request()
|
||||
continue
|
||||
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens + num_external_computed_tokens,
|
||||
num_new_local_computed_tokens,
|
||||
new_computed_blocks=new_computed_blocks,
|
||||
num_lookahead_tokens=self.num_lookahead_tokens,
|
||||
delay_cache_blocks=load_kv_async)
|
||||
if new_blocks is None:
|
||||
# The request cannot be scheduled.
|
||||
break
|
||||
|
||||
# KVConnector: update internal state after allocation.
|
||||
# This information is used to determine if a load is
|
||||
# needed for this request.
|
||||
if self.connector is not None:
|
||||
self.connector.update_state_after_alloc(
|
||||
request,
|
||||
new_computed_blocks + new_blocks,
|
||||
num_external_computed_tokens,
|
||||
)
|
||||
|
||||
self.waiting.popleft()
|
||||
if load_kv_async:
|
||||
# If loading async, allocate memory and put request
|
||||
# into the WAITING_FOR_REMOTE_KV state.
|
||||
skipped_waiting_requests.appendleft(request)
|
||||
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
continue
|
||||
|
||||
self.running.append(request)
|
||||
if self.log_stats:
|
||||
request.record_event(EngineCoreEventType.SCHEDULED,
|
||||
scheduled_timestamp)
|
||||
self.scheduled_req_ids.add(request.request_id)
|
||||
# Check request status.
|
||||
if request.status == RequestStatus.WAITING:
|
||||
scheduled_new_reqs.append(request)
|
||||
elif request.status == RequestStatus.PREEMPTED:
|
||||
scheduled_resumed_reqs.append(request)
|
||||
else:
|
||||
raise RuntimeError(f"Invalid request status: {request.status}")
|
||||
|
||||
if self.lora_config and request.lora_request:
|
||||
scheduled_loras.add(request.lora_request.lora_int_id)
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
req_to_new_block_ids[request.request_id] = (
|
||||
self.kv_cache_manager.get_block_ids(request.request_id))
|
||||
else:
|
||||
req_to_new_blocks[
|
||||
request.request_id] = self.kv_cache_manager.get_blocks(
|
||||
request.request_id)
|
||||
# Update request info.
|
||||
num_scheduled_tokens[request.request_id] = num_new_tokens
|
||||
token_budget -= num_new_tokens
|
||||
request.status = RequestStatus.RUNNING
|
||||
request.num_computed_tokens = num_computed_tokens
|
||||
# Count the number of prefix cached tokens.
|
||||
if request.num_cached_tokens < 0:
|
||||
request.num_cached_tokens = num_computed_tokens
|
||||
|
||||
# Put back any skipped requests at the head of the waiting queue
|
||||
if skipped_waiting_requests:
|
||||
self.waiting.extendleft(skipped_waiting_requests)
|
||||
|
||||
# If no prefill requests are scheduled,
|
||||
# Schedule decode requests next.
|
||||
if len(self.scheduled_req_ids) == 0:
|
||||
req_index = 0
|
||||
while req_index < len(self.running) and token_budget > 0:
|
||||
request = self.running[req_index]
|
||||
if request.request_id in self.scheduled_req_ids:
|
||||
# This request has already been scheduled.
|
||||
req_index += 1
|
||||
continue
|
||||
|
||||
num_new_tokens = (request.num_tokens_with_spec -
|
||||
request.num_computed_tokens)
|
||||
assert (request.num_tokens - request.num_computed_tokens) == 1
|
||||
num_new_tokens = min(num_new_tokens, token_budget)
|
||||
# Make sure the input position does not exceed the max model len.
|
||||
# This is necessary when using spec decoding.
|
||||
num_new_tokens = min(
|
||||
num_new_tokens,
|
||||
self.max_model_len - request.num_computed_tokens)
|
||||
# Check that adding the request still respects the max_loras
|
||||
# constraint.
|
||||
if self.lora_config and request.lora_request and (
|
||||
len(scheduled_loras) == self.lora_config.max_loras
|
||||
and request.lora_request.lora_int_id
|
||||
not in scheduled_loras):
|
||||
# Scheduling would exceed max_loras, skip.
|
||||
num_new_tokens = 0
|
||||
|
||||
if num_new_tokens == 0:
|
||||
# The request cannot be scheduled because one of the following
|
||||
# reason:
|
||||
# 1. No new tokens to schedule. This may happen when PP>1 and
|
||||
# we have already scheduled all prompt tokens but they are
|
||||
# not finished yet.
|
||||
# 2. Adding the request exceeds the max_loras constraint.
|
||||
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
|
||||
# we do not strictly follow the FCFS scheduling policy and
|
||||
# allow the lower-priority requests to be scheduled.
|
||||
req_index += 1
|
||||
continue
|
||||
|
||||
while True:
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens,
|
||||
num_lookahead_tokens=self.num_lookahead_tokens)
|
||||
if new_blocks is None:
|
||||
# The request cannot be scheduled.
|
||||
# Preempt the lowest-priority request.
|
||||
preempted_req = self.running.pop()
|
||||
self.kv_cache_manager.free(preempted_req)
|
||||
preempted_req.status = RequestStatus.PREEMPTED
|
||||
preempted_req.num_computed_tokens = 0
|
||||
if self.log_stats:
|
||||
preempted_req.record_event(
|
||||
EngineCoreEventType.PREEMPTED,
|
||||
scheduled_timestamp)
|
||||
self.waiting.appendleft(preempted_req)
|
||||
preempted_reqs.append(preempted_req)
|
||||
if preempted_req == request:
|
||||
# No more request to preempt.
|
||||
can_schedule = False
|
||||
break
|
||||
else:
|
||||
# The request can be scheduled.
|
||||
can_schedule = True
|
||||
break
|
||||
if not can_schedule:
|
||||
break
|
||||
assert new_blocks is not None
|
||||
|
||||
# Schedule the request.
|
||||
scheduled_running_reqs.append(request)
|
||||
self.scheduled_req_ids.add(request.request_id)
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
req_to_new_block_ids[request.request_id] = (
|
||||
new_blocks.get_block_ids())
|
||||
else:
|
||||
req_to_new_blocks[request.request_id] = new_blocks
|
||||
num_scheduled_tokens[request.request_id] = num_new_tokens
|
||||
token_budget -= num_new_tokens
|
||||
req_index += 1
|
||||
|
||||
# Speculative decode related.
|
||||
if request.spec_token_ids:
|
||||
num_scheduled_spec_tokens = (num_new_tokens +
|
||||
request.num_computed_tokens -
|
||||
request.num_tokens)
|
||||
if num_scheduled_spec_tokens > 0:
|
||||
# Trim spec_token_ids list to num_scheduled_spec_tokens.
|
||||
del request.spec_token_ids[num_scheduled_spec_tokens:]
|
||||
scheduled_spec_decode_tokens[request.request_id] = (
|
||||
request.spec_token_ids)
|
||||
|
||||
# Record scheduled LoRA requests.
|
||||
if self.lora_config and request.lora_request:
|
||||
scheduled_loras.add(request.lora_request.lora_int_id)
|
||||
|
||||
# Check if the scheduling constraints are satisfied.
|
||||
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
|
||||
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
|
||||
assert token_budget >= 0
|
||||
assert len(self.running) <= self.max_num_running_reqs
|
||||
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
|
||||
scheduled_running_reqs) <= len(self.running)
|
||||
|
||||
# Get the longest common prefix among all requests in the running queue.
|
||||
# This can be potentially used for cascade attention.
|
||||
num_common_prefix_blocks = [0] * len(
|
||||
self.kv_cache_config.kv_cache_groups)
|
||||
if self.running:
|
||||
any_request = self.running[0]
|
||||
num_common_prefix_blocks = (
|
||||
self.kv_cache_manager.get_num_common_prefix_blocks(
|
||||
any_request, len(self.running)))
|
||||
|
||||
# Construct the scheduler output.
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
new_reqs_data = [
|
||||
NewRequestData.from_request(
|
||||
req, req_to_new_block_ids[req.request_id])
|
||||
for req in scheduled_new_reqs
|
||||
]
|
||||
cached_reqs_data = self._make_cached_request_data(
|
||||
scheduled_running_reqs, scheduled_resumed_reqs,
|
||||
num_scheduled_tokens, scheduled_spec_decode_tokens,
|
||||
req_to_new_block_ids)
|
||||
else:
|
||||
new_reqs_data = [
|
||||
NewRequestData.from_request(
|
||||
req, req_to_new_blocks[req.request_id].get_block_ids())
|
||||
for req in scheduled_new_reqs
|
||||
]
|
||||
|
||||
cached_reqs_data = self._make_cached_request_data(
|
||||
scheduled_running_reqs, scheduled_resumed_reqs,
|
||||
num_scheduled_tokens, scheduled_spec_decode_tokens,
|
||||
req_to_new_blocks)
|
||||
scheduled_cached_reqs = cached_reqs_data
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=new_reqs_data,
|
||||
scheduled_cached_reqs=scheduled_cached_reqs,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=num_common_prefix_blocks,
|
||||
# finished_req_ids is an existing state in the scheduler,
|
||||
# instead of being newly scheduled in this step.
|
||||
# It contains the request IDs that are finished in between
|
||||
# the previous and the current steps.
|
||||
finished_req_ids=self.finished_req_ids, # type: ignore
|
||||
free_encoder_input_ids=self.encoder_cache_manager.
|
||||
get_freed_ids(),
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
else:
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=new_reqs_data,
|
||||
scheduled_cached_reqs=scheduled_cached_reqs,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=num_common_prefix_blocks,
|
||||
# finished_req_ids is an existing state in the scheduler,
|
||||
# instead of being newly scheduled in this step.
|
||||
# It contains the request IDs that are finished in between
|
||||
# the previous and the current steps.
|
||||
finished_req_ids=self.finished_req_ids, # type: ignore
|
||||
free_encoder_mm_hashes=self.encoder_cache_manager.
|
||||
get_freed_mm_hashes(),
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
# NOTE(Kuntai): this function is designed for multiple purposes:
|
||||
# 1. Plan the KV cache store
|
||||
# 2. Wrap up all the KV cache load / save ops into an opaque object
|
||||
# 3. Clear the internal states of the connector
|
||||
if self.connector is not None:
|
||||
meta = self.connector.build_connector_meta(scheduler_output)
|
||||
scheduler_output.kv_connector_metadata = meta
|
||||
|
||||
events = self.kv_cache_manager.take_events()
|
||||
if events:
|
||||
batch = KVEventBatch(ts=time.time(), events=events)
|
||||
self.kv_event_publisher.publish(batch)
|
||||
|
||||
# Advance the number of computed tokens for the request AFTER
|
||||
# the request is scheduled.
|
||||
# 1. The scheduler_output of the current step has to include the
|
||||
# original number of scheduled tokens to determine input IDs.
|
||||
# 2. Advance the number of computed tokens here allowing us to
|
||||
# schedule the prefill request again immediately in the next
|
||||
# scheduling step.
|
||||
# 3. If some tokens (e.g. spec tokens) are rejected later, the number of
|
||||
# computed tokens will be adjusted in update_from_output.
|
||||
for req_id, num_scheduled_token in num_scheduled_tokens.items():
|
||||
self.requests[req_id].num_computed_tokens += num_scheduled_token
|
||||
|
||||
self.finished_req_ids = set() # type: ignore
|
||||
return scheduler_output
|
||||
|
||||
def _check_watermark_for_prefill(self,
|
||||
request,
|
||||
num_new_tokens,
|
||||
computed_blocks,
|
||||
watermark=0.01):
|
||||
computed_blocks = computed_blocks or []
|
||||
watermark_blocks = self.kv_cache_config.num_blocks * watermark
|
||||
num_computed_tokens = (request.num_computed_tokens +
|
||||
len(computed_blocks) * self.block_size)
|
||||
num_required_blocks = cdiv(num_new_tokens + num_computed_tokens,
|
||||
self.block_size)
|
||||
req_blocks = self.kv_cache_manager.coordinator.get_blocks(
|
||||
request.request_id)
|
||||
num_new_blocks = (num_required_blocks - len(req_blocks[0]) -
|
||||
len(computed_blocks))
|
||||
num_evictable_computed_blocks = sum(1 for blk in computed_blocks
|
||||
if blk.ref_cnt == 0)
|
||||
# If number of free blocks is less than water mark after allocating, don't allocate.
|
||||
if (self.kv_cache_manager.block_pool.get_num_free_blocks() -
|
||||
num_evictable_computed_blocks -
|
||||
num_new_blocks) < watermark_blocks:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _get_prompt_limit(self, request: Request) -> int:
|
||||
if (self.scheduler_config.chunked_prefill_enabled
|
||||
and not self.scheduler_config.is_multi_step):
|
||||
prompt_limit = self.scheduler_config.max_model_len
|
||||
else:
|
||||
prompt_limit = min(
|
||||
self.scheduler_config.max_model_len,
|
||||
self.scheduler_config.max_num_batched_tokens,
|
||||
)
|
||||
|
||||
# Model is fine tuned with long context. Return the fine tuned max_len.
|
||||
if request.lora_request and request.lora_request.long_lora_max_len:
|
||||
assert prompt_limit <= request.lora_request.long_lora_max_len
|
||||
return request.lora_request.long_lora_max_len
|
||||
else:
|
||||
return prompt_limit
|
||||
|
||||
def finish_requests(
|
||||
self,
|
||||
request_ids: Union[str, Iterable[str]],
|
||||
finished_status: RequestStatus,
|
||||
) -> None:
|
||||
"""Handles the finish signal from outside the scheduler.
|
||||
|
||||
For example, the API server can abort a request when the client
|
||||
disconnects.
|
||||
"""
|
||||
for req_id in request_ids:
|
||||
request = self.requests.get(req_id)
|
||||
if request is None:
|
||||
# Invalid request ID.
|
||||
continue
|
||||
if request.status == RequestStatus.RUNNING:
|
||||
self.scheduled_req_ids.discard(request.request_id)
|
||||
super().finish_requests(request_ids, finished_status)
|
||||
|
||||
def update_from_output(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
model_runner_output: ModelRunnerOutput,
|
||||
) -> EngineCoreOutputs:
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
|
||||
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
|
||||
# loop can be a performance bottleneck. We should do our best to avoid
|
||||
# expensive operations inside the loop.
|
||||
for request in self.running:
|
||||
req_id = request.request_id
|
||||
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
|
||||
if num_tokens_scheduled == 0:
|
||||
# The request was not scheduled in this step.
|
||||
continue
|
||||
if req_id in self.scheduled_req_ids:
|
||||
self.scheduled_req_ids.remove(req_id)
|
||||
|
||||
return super().update_from_output(scheduler_output,
|
||||
model_runner_output)
|
||||
0
vllm_ascend/device_allocator/__init__.py
Normal file
0
vllm_ascend/device_allocator/__init__.py
Normal file
278
vllm_ascend/device_allocator/camem.py
Normal file
278
vllm_ascend/device_allocator/camem.py
Normal file
@@ -0,0 +1,278 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# CANN-mem-based pytorch pluggable allocator to implement sleep mode.
|
||||
#
|
||||
import dataclasses
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from acl.rt import memcpy # type: ignore # noqa: F401
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
|
||||
|
||||
def find_loaded_library(lib_name) -> Optional[str]:
|
||||
"""
|
||||
According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
|
||||
the file `/proc/self/maps` contains the memory maps of the process, which includes the
|
||||
shared libraries loaded by the process. We can use this file to find the path of the
|
||||
a loaded library.
|
||||
""" # noqa
|
||||
found_line = None
|
||||
with open("/proc/self/maps") as f:
|
||||
for line in f:
|
||||
if lib_name in line:
|
||||
found_line = line
|
||||
break
|
||||
if found_line is None:
|
||||
# the library is not loaded in the current process
|
||||
return None
|
||||
# if lib_name is libcudart, we need to match a line with:
|
||||
# address /path/to/libcudart-hash.so.11.0
|
||||
start = found_line.index("/")
|
||||
path = found_line[start:].strip()
|
||||
filename = path.split("/")[-1]
|
||||
assert filename.rpartition(".so")[0].startswith(lib_name), \
|
||||
f"Unexpected filename: {filename} for library {lib_name}"
|
||||
return path
|
||||
|
||||
|
||||
camem_available = False
|
||||
try:
|
||||
from vllm_ascend.vllm_ascend_C import ( # type: ignore # noqa: F401
|
||||
init_module, python_create_and_map, python_unmap_and_release)
|
||||
lib_name = find_loaded_library("vllm_ascend_C")
|
||||
camem_available = True
|
||||
except ImportError as e:
|
||||
logger.warning(
|
||||
"Failed to import vllm_ascend_C:%s. Sleep mode will be disabled. ", e)
|
||||
init_module = None
|
||||
python_create_and_map = None
|
||||
python_unmap_and_release = None
|
||||
lib_name = None
|
||||
libcudart = None
|
||||
|
||||
# py_device, py_alignedSize, py_d_mem, py_p_memHandle
|
||||
HandleType = Tuple[int, int, int, int]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AllocationData:
|
||||
handle: HandleType
|
||||
tag: str
|
||||
cpu_backup_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def create_and_map(allocation_handle: HandleType) -> None:
|
||||
python_create_and_map(*allocation_handle)
|
||||
|
||||
|
||||
def unmap_and_release(allocation_handle: HandleType) -> None:
|
||||
python_unmap_and_release(*allocation_handle)
|
||||
|
||||
|
||||
def get_pluggable_allocator(
|
||||
python_malloc_fn: Callable[[tuple[int, int, int, int]], None],
|
||||
python_free_func: Callable[[int], tuple[int, int, int, int]]
|
||||
) -> torch.npu.memory.NPUPluggableAllocator:
|
||||
init_module(python_malloc_fn, python_free_func)
|
||||
new_alloc = torch.npu.memory.NPUPluggableAllocator(lib_name, 'my_malloc',
|
||||
'my_free')
|
||||
return new_alloc
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_memory_pool_with_allocator(
|
||||
python_malloc_fn: Callable[[tuple[int, int, int, int]], None],
|
||||
python_free_func: Callable[[int], tuple[int, int, int, int]]):
|
||||
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
|
||||
mem_pool = torch.npu.memory.MemPool(new_alloc._allocator)
|
||||
with torch.npu.memory.use_mem_pool(mem_pool):
|
||||
yield mem_pool, new_alloc
|
||||
|
||||
|
||||
class CaMemAllocator:
|
||||
"""
|
||||
A singleton class that manages a memory pool for CANN tensors.
|
||||
The memory in this pool can be offloaded or discarded when the
|
||||
allocator sleeps.
|
||||
Inside the `use_memory_pool(tag)` context, all tensors created will
|
||||
be allocated in the memory pool, and has the same tag as the
|
||||
tag passed to the context.
|
||||
When we call `sleep`, all tensors with the specified tag will be
|
||||
offloaded to CPU memory, and the rest of the tensors will be discarded.
|
||||
When we call `wake_up`, all tensors that are previously offloaded
|
||||
will be loaded back to GPU memory, and the rest of the tensors will
|
||||
have empty memory.
|
||||
Why it needs to be a singleton?
|
||||
When allocated tensors are garbage collected, PyTorch will call
|
||||
the free callback, which will call the `python_free_callback` method.
|
||||
The C-extension uses a global variable to store the function of an
|
||||
instance of this class. If we create multiple instances of this class,
|
||||
the global variable will be overwritten and the free callback will
|
||||
not work as expected.
|
||||
"""
|
||||
instance = None
|
||||
default_tag: str = "default"
|
||||
|
||||
@staticmethod
|
||||
def get_instance() -> "CaMemAllocator":
|
||||
"""
|
||||
CaMemAllocator is a singleton class.
|
||||
We cannot call the constructor directly.
|
||||
Call this method to get the instance.
|
||||
"""
|
||||
if CaMemAllocator.instance is None:
|
||||
CaMemAllocator.instance = CaMemAllocator()
|
||||
return CaMemAllocator.instance
|
||||
|
||||
def __init__(self):
|
||||
conf = os.environ.get("PYTORCH_NPU_ALLOC_CONF", "")
|
||||
assert "expandable_segments:True" not in conf, \
|
||||
("Expandable segments are not compatible with memory pool. "
|
||||
"Please track https://github.com/pytorch/pytorch/issues/147851 "
|
||||
"for the latest updates.")
|
||||
|
||||
self.pointer_to_data: Dict[int, AllocationData] = {}
|
||||
self.current_tag: str = CaMemAllocator.default_tag
|
||||
self.allocator_and_pools: Dict[str, Any] = {}
|
||||
|
||||
def python_malloc_callback(self, allocation_handle: HandleType) -> None:
|
||||
"""
|
||||
Internal method to store the allocation data
|
||||
when memory is allocated in the memory pool."""
|
||||
py_d_mem = allocation_handle[2]
|
||||
self.pointer_to_data[py_d_mem] = AllocationData(
|
||||
allocation_handle, self.current_tag)
|
||||
return
|
||||
|
||||
def python_free_callback(self, ptr: int) -> HandleType:
|
||||
"""
|
||||
Internal method to look up the allocation data
|
||||
when memory is freed in the memory pool."""
|
||||
data = self.pointer_to_data.pop(ptr)
|
||||
if data.cpu_backup_tensor is not None:
|
||||
data.cpu_backup_tensor = None
|
||||
return data.handle
|
||||
|
||||
def sleep(
|
||||
self,
|
||||
offload_tags: Optional[Union[Tuple[str, ...],
|
||||
str]] = None) -> None:
|
||||
"""
|
||||
Put the allocator in sleep mode.
|
||||
All data in the memory allocation with the specified tag will be
|
||||
offloaded to CPU memory, and others will be discarded.
|
||||
:param offload_tags: The tags of the memory allocation that will be
|
||||
offloaded. The rest of the memory allocation will be discarded.
|
||||
"""
|
||||
if offload_tags is None:
|
||||
# by default, allocated tensors are offloaded
|
||||
# when the allocator sleeps
|
||||
offload_tags = (CaMemAllocator.default_tag, )
|
||||
elif isinstance(offload_tags, str):
|
||||
offload_tags = (offload_tags, )
|
||||
|
||||
assert isinstance(offload_tags, tuple)
|
||||
|
||||
for ptr, data in self.pointer_to_data.items():
|
||||
handle = data.handle
|
||||
if data.tag in offload_tags:
|
||||
size_in_bytes = handle[1]
|
||||
cpu_backup_tensor = torch.empty(
|
||||
size_in_bytes,
|
||||
dtype=torch.uint8,
|
||||
device='cpu',
|
||||
pin_memory=NPUPlatform.is_pin_memory_available())
|
||||
cpu_ptr = cpu_backup_tensor.data_ptr()
|
||||
ACL_MEMCPY_DEVICE_TO_HOST = 2
|
||||
dest_max = cpu_ptr + size_in_bytes * 2
|
||||
memcpy(cpu_ptr, dest_max, ptr, size_in_bytes,
|
||||
ACL_MEMCPY_DEVICE_TO_HOST)
|
||||
data.cpu_backup_tensor = cpu_backup_tensor
|
||||
unmap_and_release(handle)
|
||||
|
||||
def wake_up(self, tags: Optional[list[str]] = None) -> None:
|
||||
"""
|
||||
Wake up the allocator from sleep mode.
|
||||
All data that is previously offloaded will be loaded back to GPU
|
||||
memory, and the rest of the data will have empty memory."""
|
||||
for ptr, data in self.pointer_to_data.items():
|
||||
if tags is None or data.tag in tags:
|
||||
handle = data.handle
|
||||
create_and_map(handle)
|
||||
if data.cpu_backup_tensor is not None:
|
||||
cpu_backup_tensor = data.cpu_backup_tensor
|
||||
if cpu_backup_tensor is not None:
|
||||
size_in_bytes = cpu_backup_tensor.numel(
|
||||
) * cpu_backup_tensor.element_size()
|
||||
cpu_ptr = cpu_backup_tensor.data_ptr()
|
||||
ACL_MEMCPY_HOST_TO_DEVICE = 1
|
||||
dest_max = ptr + size_in_bytes * 2
|
||||
memcpy(ptr, dest_max, cpu_ptr, size_in_bytes,
|
||||
ACL_MEMCPY_HOST_TO_DEVICE)
|
||||
data.cpu_backup_tensor = None
|
||||
|
||||
@contextmanager
|
||||
def use_memory_pool(self, tag: Optional[str] = None):
|
||||
"""
|
||||
A context manager to use the memory pool.
|
||||
All memory allocation created inside the context will be allocated
|
||||
in the memory pool, and has the specified tag.
|
||||
:param tag: The tag of the memory allocation. If None, the default tag
|
||||
will be used.
|
||||
"""
|
||||
if tag is None:
|
||||
tag = CaMemAllocator.default_tag
|
||||
|
||||
assert isinstance(tag, str)
|
||||
|
||||
old_tag = self.current_tag
|
||||
self.current_tag = tag
|
||||
with use_memory_pool_with_allocator(self.python_malloc_callback,
|
||||
self.python_free_callback) as data:
|
||||
# start to hit another PyTorch bug in PyTorch 2.6,
|
||||
# possibly because of gc-related issue w.r.t. the allocator and
|
||||
# the memory pool.
|
||||
# to avoid the issue, we keep a reference of the data.
|
||||
# see https://github.com/pytorch/pytorch/issues/146431 .
|
||||
self.allocator_and_pools[tag] = data
|
||||
yield
|
||||
# PyTorch's bug, calling torch.cuda.empty_cache() will error
|
||||
# when using pluggable allocator, see
|
||||
# https://github.com/pytorch/pytorch/issues/145168 .
|
||||
# if we have some memory allocated and then freed,
|
||||
# the memory will not be released.
|
||||
# right now it is fine, because we only use this allocator
|
||||
# during weight loading and kv cache creation, where we only
|
||||
# allocate memory.
|
||||
# TODO: we need to find a way to release the memory,
|
||||
# i.e. calling torch.cuda.empty_cache()
|
||||
self.current_tag = old_tag
|
||||
|
||||
def get_current_usage(self) -> int:
|
||||
"""
|
||||
Get the total number of bytes allocated in the memory pool.
|
||||
"""
|
||||
sum_bytes: int = 0
|
||||
for ptr, data in self.pointer_to_data.items():
|
||||
handle = data.handle
|
||||
sum_bytes += handle[1]
|
||||
return sum_bytes
|
||||
28
vllm_ascend/distributed/__init__.py
Normal file
28
vllm_ascend/distributed/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import \
|
||||
KVConnectorFactory
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"LLMDataDistCMgrConnector",
|
||||
"vllm_ascend.distributed.llmdatadist_c_mgr_connector",
|
||||
"LLMDataDistCMgrConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeConnectorV1", "vllm_ascend.distributed.mooncake_connector",
|
||||
"MooncakeConnector")
|
||||
25
vllm_ascend/distributed/communication_op.py
Normal file
25
vllm_ascend/distributed/communication_op.py
Normal file
@@ -0,0 +1,25 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
|
||||
|
||||
def data_parallel_reduce_scatter(input_: torch.Tensor,
|
||||
dim: int = -1) -> torch.Tensor:
|
||||
"""Reduce-Scatter the input tensor across data parallel group."""
|
||||
return get_dp_group().reduce_scatter(input_, dim)
|
||||
75
vllm_ascend/distributed/communicator.py
Normal file
75
vllm_ascend/distributed/communicator.py
Normal file
@@ -0,0 +1,75 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from vllm.distributed.device_communicators.base_device_communicator import \
|
||||
DeviceCommunicatorBase
|
||||
|
||||
|
||||
class NPUCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def __init__(self,
|
||||
cpu_group: dist.ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[dist.ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
# TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
|
||||
# init device according to rank
|
||||
self.device = torch.npu.current_device()
|
||||
|
||||
def all_to_all(self,
|
||||
input_: torch.Tensor,
|
||||
scatter_dim: int = 0,
|
||||
gather_dim: int = -1,
|
||||
scatter_sizes: Optional[List[int]] = None,
|
||||
gather_sizes: Optional[List[int]] = None) -> torch.Tensor:
|
||||
|
||||
if scatter_dim < 0:
|
||||
scatter_dim += input_.dim()
|
||||
if gather_dim < 0:
|
||||
gather_dim += input_.dim()
|
||||
|
||||
if scatter_sizes is not None and gather_sizes is not None:
|
||||
input_list = [
|
||||
t.contiguous()
|
||||
for t in torch.split(input_, scatter_sizes, scatter_dim)
|
||||
]
|
||||
output_list = []
|
||||
tensor_shape_base = input_list[self.rank].size()
|
||||
for i in range(self.world_size):
|
||||
tensor_shape = list(tensor_shape_base)
|
||||
tensor_shape[gather_dim] = gather_sizes[i]
|
||||
output_list.append(
|
||||
torch.empty(tensor_shape,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device))
|
||||
|
||||
else:
|
||||
input_list = [
|
||||
t.contiguous() for t in torch.tensor_split(
|
||||
input_, self.world_size, scatter_dim)
|
||||
]
|
||||
output_list = [
|
||||
torch.empty_like(input_list[i]) for i in range(self.world_size)
|
||||
]
|
||||
|
||||
dist.all_to_all(output_list, input_list, group=self.device_group)
|
||||
output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
|
||||
return output_tensor
|
||||
165
vllm_ascend/distributed/device_communicators/pyhccl.py
Normal file
165
vllm_ascend/distributed/device_communicators/pyhccl.py
Normal file
@@ -0,0 +1,165 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
|
||||
HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum,
|
||||
hcclRedOpTypeEnum, hcclUniqueId)
|
||||
from vllm_ascend.utils import current_stream
|
||||
|
||||
|
||||
class PyHcclCommunicator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group: Union[ProcessGroup, StatelessProcessGroup],
|
||||
device: Union[int, str, torch.device],
|
||||
library_path: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the PyHcclCommunicator to. If None,
|
||||
it will be bind to f"npu:{local_rank}".
|
||||
library_path: the path to the HCCL library. If None, it will
|
||||
use the default library path.
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device.
|
||||
"""
|
||||
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
assert dist.is_initialized()
|
||||
assert dist.get_backend(group) != dist.Backend.HCCL, (
|
||||
"PyHcclCommunicator should be attached to a non-HCCL group.")
|
||||
# note: this rank is the rank in the group
|
||||
self.rank = dist.get_rank(group)
|
||||
self.world_size = dist.get_world_size(group)
|
||||
else:
|
||||
self.rank = group.rank
|
||||
self.world_size = group.world_size
|
||||
|
||||
self.group = group
|
||||
|
||||
# if world_size == 1, no need to create communicator
|
||||
if self.world_size == 1:
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
return
|
||||
|
||||
try:
|
||||
self.hccl = HCCLLibrary(library_path)
|
||||
except Exception:
|
||||
# disable because of missing HCCL library
|
||||
# e.g. in a non-NPU environment
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
return
|
||||
|
||||
self.available = True
|
||||
self.disabled = False
|
||||
|
||||
logger.info("vLLM is using pyhccl")
|
||||
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"npu:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
|
||||
if self.rank == 0:
|
||||
# get the unique id from HCCL
|
||||
with torch.npu.device(device):
|
||||
self.unique_id = self.hccl.hcclGetUniqueId()
|
||||
else:
|
||||
# construct an empty unique id
|
||||
self.unique_id = hcclUniqueId()
|
||||
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
tensor = torch.ByteTensor(list(self.unique_id.internal))
|
||||
ranks = dist.get_process_group_ranks(group)
|
||||
# arg `src` in `broadcast` is the global rank
|
||||
dist.broadcast(tensor, src=ranks[0], group=group)
|
||||
byte_list = tensor.tolist()
|
||||
for i, byte in enumerate(byte_list):
|
||||
self.unique_id.internal[i] = byte
|
||||
else:
|
||||
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
|
||||
|
||||
# hccl communicator and stream will use this device
|
||||
# `torch.npu.device` is a context manager that changes the
|
||||
# current npu device to the specified one
|
||||
with torch.npu.device(device):
|
||||
self.comm: hcclComm_t = self.hccl.hcclCommInitRank(
|
||||
self.world_size, self.unique_id, self.rank)
|
||||
|
||||
stream = current_stream()
|
||||
# A small all_reduce for warmup.
|
||||
data = torch.zeros(1, device=device)
|
||||
self.all_reduce(data)
|
||||
stream.synchronize()
|
||||
del data
|
||||
|
||||
def all_reduce(self,
|
||||
in_tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None) -> torch.Tensor:
|
||||
if self.disabled:
|
||||
return None
|
||||
# hccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert in_tensor.device == self.device, (
|
||||
f"this hccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {in_tensor.device}")
|
||||
|
||||
out_tensor = torch.empty_like(in_tensor)
|
||||
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()),
|
||||
buffer_type(out_tensor.data_ptr()),
|
||||
in_tensor.numel(),
|
||||
hcclDataTypeEnum.from_torch(in_tensor.dtype),
|
||||
hcclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
aclrtStream_t(stream.npu_stream))
|
||||
return out_tensor
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this hccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
if src == self.rank:
|
||||
buffer = buffer_type(tensor.data_ptr())
|
||||
else:
|
||||
buffer = buffer_type(tensor.data_ptr())
|
||||
self.hccl.hcclBroadcast(buffer, tensor.numel(),
|
||||
hcclDataTypeEnum.from_torch(tensor.dtype), src,
|
||||
self.comm, aclrtStream_t(stream.npu_stream))
|
||||
253
vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py
Normal file
253
vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py
Normal file
@@ -0,0 +1,253 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import ctypes
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.utils import find_hccl_library
|
||||
|
||||
# export types and functions from hccl to Python ===
|
||||
# for the original hccl definition, please check
|
||||
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl.h#L90
|
||||
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl_types.h#L48
|
||||
|
||||
hcclResult_t = ctypes.c_int
|
||||
hcclComm_t = ctypes.c_void_p
|
||||
|
||||
|
||||
class hcclUniqueId(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 4108)]
|
||||
|
||||
|
||||
aclrtStream_t = ctypes.c_void_p
|
||||
buffer_type = ctypes.c_void_p
|
||||
|
||||
hcclDataType_t = ctypes.c_int
|
||||
|
||||
|
||||
class hcclDataTypeEnum:
|
||||
hcclInt8 = 0
|
||||
hcclInt16 = 1
|
||||
hcclInt32 = 2
|
||||
hcclFloat16 = 3
|
||||
hcclFloat32 = 4
|
||||
hcclInt64 = 5
|
||||
hcclUint64 = 6
|
||||
hcclUint8 = 7
|
||||
hcclUint16 = 8
|
||||
hcclUint32 = 9
|
||||
hcclFloat64 = 10
|
||||
hcclBfloat16 = 11
|
||||
hcclInt128 = 12
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, dtype: torch.dtype) -> int:
|
||||
if dtype == torch.int8:
|
||||
return cls.hcclInt8
|
||||
if dtype == torch.uint8:
|
||||
return cls.hcclUint8
|
||||
if dtype == torch.int32:
|
||||
return cls.hcclInt32
|
||||
if dtype == torch.int64:
|
||||
return cls.hcclInt64
|
||||
if dtype == torch.float16:
|
||||
return cls.hcclFloat16
|
||||
if dtype == torch.float32:
|
||||
return cls.hcclFloat32
|
||||
if dtype == torch.float64:
|
||||
return cls.hcclFloat64
|
||||
if dtype == torch.bfloat16:
|
||||
return cls.hcclBfloat16
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
hcclRedOp_t = ctypes.c_int
|
||||
|
||||
|
||||
class hcclRedOpTypeEnum:
|
||||
hcclSum = 0
|
||||
hcclProd = 1
|
||||
hcclMax = 2
|
||||
hcclMin = 3
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, op: ReduceOp) -> int:
|
||||
if op == ReduceOp.SUM:
|
||||
return cls.hcclSum
|
||||
if op == ReduceOp.PRODUCT:
|
||||
return cls.hcclProd
|
||||
if op == ReduceOp.MAX:
|
||||
return cls.hcclMax
|
||||
if op == ReduceOp.MIN:
|
||||
return cls.hcclMin
|
||||
raise ValueError(f"Unsupported op: {op}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: List[Any]
|
||||
|
||||
|
||||
class HCCLLibrary:
|
||||
exported_functions = [
|
||||
# const char* HcclGetErrorString(HcclResult code);
|
||||
Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]),
|
||||
|
||||
# HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo);
|
||||
Function("HcclGetRootInfo", hcclResult_t,
|
||||
[ctypes.POINTER(hcclUniqueId)]),
|
||||
|
||||
# HcclResult HcclCommInitRootInfo(
|
||||
# uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);
|
||||
# note that HcclComm is a pointer type, so the last argument is a pointer to a pointer
|
||||
Function("HcclCommInitRootInfo", hcclResult_t, [
|
||||
ctypes.c_int,
|
||||
ctypes.POINTER(hcclUniqueId),
|
||||
ctypes.c_int,
|
||||
ctypes.POINTER(hcclComm_t),
|
||||
]),
|
||||
|
||||
# HcclResult HcclAllReduce(
|
||||
# void *sendBuf, void *recvBuf, uint64_t count,
|
||||
# HcclDataType dataType, HcclReduceOp op, HcclComm comm,
|
||||
# aclrtStream stream);
|
||||
Function("HcclAllReduce", hcclResult_t, [
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
hcclDataType_t,
|
||||
hcclRedOp_t,
|
||||
hcclComm_t,
|
||||
aclrtStream_t,
|
||||
]),
|
||||
|
||||
# HcclResult HcclBroadcast(
|
||||
# void *buf, uint64_t count,
|
||||
# HcclDataType dataType, uint32_t root,
|
||||
# HcclComm comm, aclrtStream stream);
|
||||
Function("HcclBroadcast", hcclResult_t, [
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
hcclDataType_t,
|
||||
ctypes.c_int,
|
||||
hcclComm_t,
|
||||
aclrtStream_t,
|
||||
]),
|
||||
|
||||
# HcclResult HcclCommDestroy(HcclComm comm);
|
||||
Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: Dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the correspongding directory
|
||||
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
|
||||
so_file = so_file or find_hccl_library()
|
||||
|
||||
try:
|
||||
if so_file not in HCCLLibrary.path_to_dict_mapping:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
HCCLLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = HCCLLibrary.path_to_library_cache[so_file]
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to load HCCL library from %s. "
|
||||
"It is expected if you are not running on Ascend NPUs."
|
||||
"Otherwise, the hccl library might not exist, be corrupted "
|
||||
"or it does not support the current platform %s. "
|
||||
"If you already have the library, please set the "
|
||||
"environment variable HCCL_SO_PATH"
|
||||
" to point to the correct hccl library path.", so_file,
|
||||
platform.platform())
|
||||
raise e
|
||||
|
||||
if so_file not in HCCLLibrary.path_to_dict_mapping:
|
||||
_funcs: Dict[str, Any] = {}
|
||||
for func in HCCLLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
_funcs[func.name] = f
|
||||
HCCLLibrary.path_to_dict_mapping[so_file] = _funcs
|
||||
self._funcs = HCCLLibrary.path_to_dict_mapping[so_file]
|
||||
|
||||
def hcclGetErrorString(self, result: hcclResult_t) -> str:
|
||||
return self._funcs["HcclGetErrorString"](result).decode("utf-8")
|
||||
|
||||
def HCCL_CHECK(self, result: hcclResult_t) -> None:
|
||||
if result != 0:
|
||||
error_str = self.hcclGetErrorString(result)
|
||||
raise RuntimeError(f"HCCL error: {error_str}")
|
||||
|
||||
def hcclGetUniqueId(self) -> hcclUniqueId:
|
||||
unique_id = hcclUniqueId()
|
||||
self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](
|
||||
ctypes.byref(unique_id)))
|
||||
return unique_id
|
||||
|
||||
def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId,
|
||||
rank: int) -> hcclComm_t:
|
||||
comm = hcclComm_t()
|
||||
self.HCCL_CHECK(self._funcs["HcclCommInitRootInfo"](
|
||||
world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm)))
|
||||
return comm
|
||||
|
||||
def hcclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, op: int, comm: hcclComm_t,
|
||||
stream: aclrtStream_t) -> None:
|
||||
# `datatype` actually should be `hcclDataType_t`
|
||||
# and `op` should be `hcclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count,
|
||||
datatype, op, comm,
|
||||
stream))
|
||||
|
||||
def hcclBroadcast(self, buf: buffer_type, count: int, datatype: int,
|
||||
root: int, comm: hcclComm_t,
|
||||
stream: aclrtStream_t) -> None:
|
||||
self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype,
|
||||
root, comm, stream))
|
||||
|
||||
def hcclCommDestroy(self, comm: hcclComm_t) -> None:
|
||||
self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"HCCLLibrary",
|
||||
"hcclDataTypeEnum",
|
||||
"hcclRedOpTypeEnum",
|
||||
"hcclUniqueId",
|
||||
"hcclComm_t",
|
||||
"aclrtStream_t",
|
||||
"buffer_type",
|
||||
]
|
||||
894
vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
Normal file
894
vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
Normal file
@@ -0,0 +1,894 @@
|
||||
import contextlib
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
import llm_datadist # type: ignore
|
||||
import msgspec
|
||||
import torch
|
||||
import zmq
|
||||
from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist,
|
||||
LLMException, LLMRole)
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.distributed.parallel_state import get_tp_group, get_world_group
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.utils import get_ip, logger
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
||||
|
||||
TORCH_DTYPE_TO_NPU_DTYPE = {
|
||||
torch.half: llm_datadist.DataType.DT_FLOAT16,
|
||||
torch.float16: llm_datadist.DataType.DT_FLOAT16,
|
||||
torch.bfloat16: llm_datadist.DataType.DT_BF16,
|
||||
torch.float: llm_datadist.DataType.DT_FLOAT,
|
||||
torch.float32: llm_datadist.DataType.DT_FLOAT,
|
||||
torch.int8: llm_datadist.DataType.DT_INT8,
|
||||
torch.int64: llm_datadist.DataType.DT_INT64,
|
||||
torch.int32: llm_datadist.DataType.DT_INT32
|
||||
}
|
||||
|
||||
|
||||
class LLMDataDistCMgrEvent(Enum):
|
||||
ReqForMetadata = 0
|
||||
ReqForFinished = 1
|
||||
|
||||
|
||||
class LLMDataDistCMgrAgentMetadata(msgspec.Struct):
|
||||
super_pod_id: str
|
||||
server_id: str
|
||||
device_id: str
|
||||
device_ip: str
|
||||
super_device_id: str
|
||||
cluster_id: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
local_block_ids: list[int]
|
||||
remote_block_ids: list[int]
|
||||
remote_host: str
|
||||
remote_port: str
|
||||
engine_id: str
|
||||
remote_tp_size: str
|
||||
|
||||
|
||||
class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
|
||||
|
||||
def __init__(self):
|
||||
self.requests: dict[str, ReqMeta] = {}
|
||||
|
||||
def add_new_req(self, request_id: str, local_block_ids: list[int],
|
||||
kv_transfer_params: dict[str, Any]):
|
||||
self.requests[request_id] = ReqMeta(
|
||||
local_block_ids=local_block_ids,
|
||||
remote_block_ids=kv_transfer_params["remote_block_ids"],
|
||||
engine_id=kv_transfer_params["remote_engine_id"],
|
||||
remote_host=kv_transfer_params["remote_host"],
|
||||
remote_port=kv_transfer_params["remote_port"],
|
||||
remote_tp_size=kv_transfer_params["remote_tp_size"],
|
||||
)
|
||||
|
||||
|
||||
class LLMDataDistCMgrConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
self.engine_id = vllm_config.kv_transfer_config.engine_id
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler: Optional[
|
||||
LLMDataDistCMgrConnectorScheduler] = LLMDataDistCMgrConnectorScheduler(
|
||||
vllm_config, self.engine_id)
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = LLMDataDistCMgrConnectorWorker(vllm_config)
|
||||
|
||||
############################################################
|
||||
# Scheduler Side Methods
|
||||
############################################################
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.update_state_after_alloc(
|
||||
request, blocks, num_external_tokens)
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.build_connector_meta(scheduler_output)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, block_ids)
|
||||
|
||||
############################################################
|
||||
# Worker Side Methods
|
||||
############################################################
|
||||
def register_kv_caches(
|
||||
self,
|
||||
kv_caches: dict[
|
||||
str, # type: ignore[override]
|
||||
Tuple[torch.Tensor]]):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
"""Get the finished recving and sending requests."""
|
||||
assert self.connector_worker is not None
|
||||
return self.connector_worker.get_finished(finished_req_ids)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._connector_metadata,
|
||||
LLMDataDistCMgrConnectorMetadata)
|
||||
self.connector_worker.start_load_kv(self._connector_metadata)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""LLMDataDistCMgrConnector does not do layerwise saving, the load is in blocking manager."""
|
||||
pass
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata, **kwargs) -> None:
|
||||
"""LLMDataDistCMgrConnector does not save explicitly."""
|
||||
pass
|
||||
|
||||
def wait_for_save(self):
|
||||
"""LLMDataDistCMgrConnector does not save explicitly."""
|
||||
pass
|
||||
|
||||
|
||||
class LLMDataDistCMgrConnectorScheduler():
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]):
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.engine_id = engine_id
|
||||
self.local_ip = get_ip()
|
||||
# Can not retrieve the parallel config since it is not initialized.
|
||||
self.local_dp_rank = None
|
||||
self.tp_size = None
|
||||
dp_rank_local = self.vllm_config.parallel_config.data_parallel_rank_local
|
||||
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
||||
|
||||
self.port = dp_rank_local * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT
|
||||
|
||||
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
"""
|
||||
For remote prefill, pull all prompt blocks from remote
|
||||
asynchronously relative to engine execution.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
Returns:
|
||||
* the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
* true if the external KV cache tokens will be loaded
|
||||
asynchronously (between scheduler steps).
|
||||
"""
|
||||
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
f"LLMDataDistCMgrConnector get_num_new_matched_tokens: num_computed_tokens={num_computed_tokens}, kv_transfer_params={params}"
|
||||
)
|
||||
|
||||
if params is not None and params.get("do_remote_prefill"):
|
||||
# Remote prefill: get all prompt blocks from remote.
|
||||
assert num_computed_tokens % self.block_size == 0
|
||||
# Note: We use the full token count as transmit data here.
|
||||
count = max(len(request.prompt_token_ids) - num_computed_tokens, 0)
|
||||
return count, count > 0
|
||||
|
||||
# No remote prefill for this request.
|
||||
return 0, False
|
||||
|
||||
def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks,
|
||||
num_externel_tokens: int):
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
f"LLMDataDistCMgrConnector update states num_externel_tokens: {num_externel_tokens} kv_transfer_params: {params}"
|
||||
)
|
||||
if params is not None and params.get("do_remote_prefill"):
|
||||
if params.get("remote_block_ids"):
|
||||
if all(p in params for p in ("remote_engine_id", "remote_host",
|
||||
"remote_port", "remote_tp_size")):
|
||||
self._reqs_need_recv[request.request_id] = (
|
||||
request, blocks.get_unhashed_block_ids())
|
||||
else:
|
||||
logger.warning("" \
|
||||
f"Invalid KVTransferParams {params}, This request will be discard")
|
||||
else:
|
||||
assert num_externel_tokens == 0
|
||||
params["do_remote_prefill"] = False
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
meta = LLMDataDistCMgrConnectorMetadata()
|
||||
|
||||
for req_id, (req, block_ids) in self._reqs_need_recv.items():
|
||||
assert req.kv_transfer_params is not None
|
||||
meta.add_new_req(request_id=req_id,
|
||||
local_block_ids=block_ids,
|
||||
kv_transfer_params=req.kv_transfer_params)
|
||||
self._reqs_need_recv.clear()
|
||||
|
||||
return meta
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
"LLMDataDistCMgrConnector request_finished, request_status=%s, "
|
||||
"kv_transfer_params=%s", request.status, params)
|
||||
|
||||
if (params is None or not params.get("do_remote_decode")
|
||||
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
|
||||
return False, None
|
||||
|
||||
# note: NIXL transfer the full block only, but I don't see any reason to do that, so here
|
||||
# we just transfer any data that computed from prefill node
|
||||
# note: there might be some issue on this, check it if there is any unexpected result
|
||||
computed_block_ids = block_ids
|
||||
delay_free_blocks = len(computed_block_ids) > 0
|
||||
if delay_free_blocks:
|
||||
logger.info("Delaying free of %d blocks for request %s",
|
||||
len(computed_block_ids), request.request_id)
|
||||
return delay_free_blocks, dict(
|
||||
do_remote_prefill=True,
|
||||
do_remote_decode=False,
|
||||
remote_block_ids=computed_block_ids,
|
||||
remote_engine_id=self.engine_id,
|
||||
remote_host=self.local_ip,
|
||||
remote_port=self.port,
|
||||
remote_tp_size=str(
|
||||
self.vllm_config.parallel_config.tensor_parallel_size),
|
||||
)
|
||||
|
||||
|
||||
class LLMDataDistCMgrConnectorWorker():
|
||||
"""
|
||||
Implementation of Worker side methods
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
logger.info("Initialize the LLMDataDistCMgrConnectorWorker")
|
||||
# we assume the local node only contains dp and tp, and tp will not communicate inter-node.
|
||||
# for any scenario beyond this scope, the functionality of this connector is not guaranteed.
|
||||
self.local_rank_on_node = get_world_group().rank % (
|
||||
vllm_config.parallel_config.data_parallel_size_local *
|
||||
vllm_config.parallel_config.tensor_parallel_size)
|
||||
self.local_rank = get_world_group().local_rank
|
||||
self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
|
||||
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
self.rank = get_world_group().rank
|
||||
self.local_ip = get_ip()
|
||||
self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config
|
||||
self.local_agent_metadata: Optional[
|
||||
LLMDataDistCMgrAgentMetadata] = None
|
||||
self.vllm_config = vllm_config
|
||||
self.executor = ThreadPoolExecutor(1)
|
||||
self.thread_lock = threading.Lock()
|
||||
|
||||
self.llm_datadist_role = None
|
||||
self.llm_datadist_remote_role = None
|
||||
if self.kv_transfer_config.kv_role == "kv_producer":
|
||||
self.llm_datadist_role = LLMRole.PROMPT
|
||||
self.llm_datadist_remote_role = LLMRole.DECODER
|
||||
elif self.kv_transfer_config.kv_role == "kv_consumer":
|
||||
self.llm_datadist_role = LLMRole.DECODER
|
||||
self.llm_datadist_remote_role = LLMRole.PROMPT
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistWorker: Receive unexpected kv role in LLMDataDistWorker, this worker now only support kv_producer and kv_consumer, but receiving {vllm_config.kv_transfer_config.kv_role}"
|
||||
)
|
||||
|
||||
# linked_cluster record the cluster that already build the connection its format should be {"cluster_id": "comm_name"}
|
||||
self.linked_cluster: dict[Any, Any] = {}
|
||||
self.prefill_device_list: list[tuple[int, int]] = []
|
||||
self.decode_device_list: list[tuple[int, int]] = []
|
||||
global_rank_table = self.read_offline_rank_table()
|
||||
self.local_agent_metadata = self.read_agent_metadata(global_rank_table)
|
||||
self.llm_datadist = LLMDataDist(self.llm_datadist_role,
|
||||
self.local_agent_metadata.cluster_id)
|
||||
self.init_llm_datadist()
|
||||
self.finished_reqs: set[str] = set()
|
||||
self.soc_info = get_ascend_soc_version()
|
||||
# Set hccl deterministic for model execute
|
||||
os.environ["HCCL_DETERMINISTIC"] = "true"
|
||||
self.done_receiving_counts: defaultdict[str,
|
||||
set[int]] = defaultdict(set)
|
||||
|
||||
def listen_for_agent_metadata_req(self, event: threading.Event):
|
||||
assert self.local_agent_metadata is not None
|
||||
port = envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.local_dp_rank * self.tp_size + self.tp_rank if self.local_dp_rank is not None else envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.tp_size + self.tp_rank
|
||||
url = f"tcp://{envs_ascend.VLLM_ASCEND_LLMDD_RPC_IP}:{port}"
|
||||
msg_encoder = msgspec.msgpack.Encoder()
|
||||
msg_decoder = msgspec.msgpack.Decoder()
|
||||
msg_to_send = msg_encoder.encode(self.local_agent_metadata)
|
||||
logger.debug(f"Start to listen to address: {url}")
|
||||
logger.debug(
|
||||
f"The local agent metadata have {len(msg_to_send)} bytes here")
|
||||
logger.info(
|
||||
f"LLMDataDistCMgrConnectorWorker: Cluster {self.local_agent_metadata.cluster_id} start to listen request from peers"
|
||||
)
|
||||
with zmq_ctx(zmq.ROUTER, url) as sock: # type: ignore[attr-defined]
|
||||
event.set()
|
||||
while True:
|
||||
identity, _, msg = sock.recv_multipart()
|
||||
event_msg, decode_msg = msg_decoder.decode(msg)
|
||||
event_msg = LLMDataDistCMgrEvent(event_msg)
|
||||
if event_msg == LLMDataDistCMgrEvent.ReqForMetadata:
|
||||
if "cluster_id" in decode_msg:
|
||||
decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg)
|
||||
logger.info(
|
||||
f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}"
|
||||
)
|
||||
sock.send_multipart((identity, b"", msg_to_send))
|
||||
self.add_remote_agent(decode_msg)
|
||||
else:
|
||||
logger.warning(
|
||||
f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}"
|
||||
)
|
||||
elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
|
||||
finished_req_id = decode_msg[0]
|
||||
decode_tp_rank = decode_msg[1]
|
||||
decode_tp_size = decode_msg[2]
|
||||
with self.thread_lock:
|
||||
if self._increment_task_count(finished_req_id,
|
||||
decode_tp_rank,
|
||||
decode_tp_size):
|
||||
logger.debug(
|
||||
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
|
||||
)
|
||||
self.finished_reqs.add(finished_req_id)
|
||||
sock.send_multipart(
|
||||
(identity, b"", b"receiving decode finished"))
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
|
||||
)
|
||||
|
||||
def _increment_task_count(self, request_id: str, tp_rank: int,
|
||||
decode_tp_size: int):
|
||||
if request_id not in self.done_receiving_counts:
|
||||
self.done_receiving_counts[request_id] = set()
|
||||
if tp_rank in self.done_receiving_counts[request_id]:
|
||||
logger.warning(
|
||||
f"Received duplicate done signal for request {request_id} "
|
||||
f"from tp rank {tp_rank}. Ignoring.")
|
||||
return False
|
||||
self.done_receiving_counts[request_id].add(tp_rank)
|
||||
if len(self.done_receiving_counts[request_id]) == decode_tp_size:
|
||||
self.done_receiving_counts.pop(request_id)
|
||||
logger.info("All transfers completed for request: "
|
||||
f"{request_id}. Total ranks: "
|
||||
f"{decode_tp_size}.")
|
||||
return True
|
||||
return False
|
||||
|
||||
def init_llm_datadist(self):
|
||||
assert self.local_agent_metadata is not None
|
||||
llm_config = LLMConfig()
|
||||
llm_config.device_id = self.local_rank
|
||||
llm_config.sync_kv_timeout = 20000
|
||||
llm_config.enable_switch_role = True
|
||||
llm_config.enable_cache_manager = True
|
||||
llm_config.enable_remote_cache_accessible = True
|
||||
llm_config_options = llm_config.generate_options()
|
||||
self.llm_datadist.init(llm_config_options)
|
||||
self.cache_manager = self.llm_datadist.cache_manager
|
||||
logger.info(
|
||||
f"Done initialize llm_datadist in rank {self.rank}, local rank {self.local_rank}, cluster id {self.local_agent_metadata.cluster_id}"
|
||||
)
|
||||
|
||||
def read_offline_rank_table(self):
|
||||
assert (
|
||||
envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
|
||||
), "Please set path of rank_table to env variable DISAGGREGATED_PREFILL_RANK_TABLE_PATH"
|
||||
rank_table_path = envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
|
||||
with open(rank_table_path, "r", encoding="utf-8") as f:
|
||||
global_rank_table = json.load(f)
|
||||
decode_device_list = global_rank_table["decode_device_list"]
|
||||
for decode_device in decode_device_list:
|
||||
server_id = decode_device["server_id"]
|
||||
device_id = decode_device["device_id"]
|
||||
self.decode_device_list.append((server_id, device_id))
|
||||
prefill_device_list = global_rank_table["prefill_device_list"]
|
||||
for prefill_device in prefill_device_list:
|
||||
server_id = prefill_device["server_id"]
|
||||
device_id = prefill_device["device_id"]
|
||||
self.prefill_device_list.append((server_id, device_id))
|
||||
|
||||
# global_rank_table = json.dumps(global_rank_table)
|
||||
return global_rank_table
|
||||
|
||||
@staticmethod
|
||||
def _get_visible_devices() -> Callable[[str], bool]:
|
||||
"""
|
||||
Return a test function that check if the given device ID is visible.
|
||||
i.e. ASCEND_RT_VISIBLE_DEVICES is not set or contains the device_id.
|
||||
"""
|
||||
visible_devices = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "")
|
||||
if not visible_devices:
|
||||
return lambda device_id: True
|
||||
visible_device_list = visible_devices.split(",")
|
||||
return lambda device_id: device_id in visible_device_list
|
||||
|
||||
def read_agent_metadata(self, global_rank_table):
|
||||
device_filter = LLMDataDistCMgrConnectorWorker._get_visible_devices()
|
||||
devices_type_list = []
|
||||
agent_metadata = None
|
||||
if self.llm_datadist_role == LLMRole.PROMPT:
|
||||
devices_type_list.append("prefill_device_list")
|
||||
elif self.llm_datadist_role == LLMRole.DECODER:
|
||||
devices_type_list.append("decode_device_list")
|
||||
else:
|
||||
devices_type_list.append("prefill_device_list")
|
||||
devices_type_list.append("decode_device_list")
|
||||
for device_type in devices_type_list:
|
||||
device_list = global_rank_table[device_type]
|
||||
device_list = [
|
||||
d for d in device_list if d.get("server_id") == self.local_ip
|
||||
and device_filter(d.get("device_id", ""))
|
||||
]
|
||||
if len(device_list) <= self.tp_rank:
|
||||
continue
|
||||
device_info = device_list[self.tp_rank]
|
||||
super_pod_id_ = device_info.get("super_pod_id", None)
|
||||
server_id_ = device_info["server_id"]
|
||||
device_id_ = device_info["device_id"]
|
||||
device_ip_ = device_info["device_ip"]
|
||||
super_device_id_ = device_info.get("super_device_id", None)
|
||||
cluster_id_ = int(device_info["cluster_id"])
|
||||
agent_metadata = LLMDataDistCMgrAgentMetadata(
|
||||
super_pod_id=super_pod_id_,
|
||||
server_id=server_id_,
|
||||
device_id=device_id_,
|
||||
device_ip=device_ip_,
|
||||
super_device_id=super_device_id_,
|
||||
cluster_id=cluster_id_,
|
||||
)
|
||||
assert agent_metadata is not None, f"Can't read the target server_id {self.local_ip} and device_rank {self.rank} from rank table"
|
||||
return agent_metadata
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]):
|
||||
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
|
||||
first_kv_cache = first_kv_cache_tuple[0]
|
||||
assert len(first_kv_cache_tuple) > 1
|
||||
assert self.local_agent_metadata is not None
|
||||
kv_cache_dtype = first_kv_cache.dtype
|
||||
self.use_mla: bool = first_kv_cache_tuple[0].size(
|
||||
-1) != first_kv_cache_tuple[1].size(-1)
|
||||
# MLA case. [2 (k_normed, k_pe), num_blocks, ...]
|
||||
# MHA case. [2 (k and v), num_blocks, ...]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 3 # [block_size, latent_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
|
||||
self.block_len = math.prod(block_shape)
|
||||
self.cache_addr: list[int] = []
|
||||
alignment = 2 * 1024 * 1024
|
||||
if self.use_mla:
|
||||
cache_k_normed_addr_list = []
|
||||
cache_k_pe_addr_list = []
|
||||
k_normed = None
|
||||
k_pe = None
|
||||
for cache_or_caches in kv_caches.values():
|
||||
assert len(cache_or_caches) > 1
|
||||
k_normed, k_pe = cache_or_caches[0], cache_or_caches[1]
|
||||
cache_k_normed_addr_list.append(k_normed.data_ptr())
|
||||
cache_k_pe_addr_list.append(k_pe.data_ptr())
|
||||
self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list)
|
||||
|
||||
cache_desc_k_normed = CacheDesc(
|
||||
len(self.cache_addr[0]), [*k_normed.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_desc_k_pe = CacheDesc(
|
||||
len(self.cache_addr[1]), [*k_pe.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_key_k_normed = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=0)
|
||||
cache_key_k_pe = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=1)
|
||||
self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe)
|
||||
self.cache_key = (cache_key_k_normed, cache_key_k_pe)
|
||||
try:
|
||||
cache_k_normed = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
|
||||
cache_k_pe = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
|
||||
self.cache = (cache_k_normed, cache_k_pe)
|
||||
logger.info("LLMDataDistWorker: End of register Paged Cache.")
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
|
||||
)
|
||||
else:
|
||||
for cache_or_caches in kv_caches.values():
|
||||
for cache in cache_or_caches:
|
||||
base_addr = cache.data_ptr()
|
||||
assert base_addr % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
|
||||
self.cache_addr.append(base_addr)
|
||||
# register paged kv cache into the llm_cache manager
|
||||
self.cache_desc = CacheDesc(
|
||||
len(self.cache_addr), [*cache.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
self.cache_key = BlocksCacheKey(
|
||||
cluster_id=int(self.local_agent_metadata.cluster_id))
|
||||
logger.info(
|
||||
f"num of cache: {len(self.cache_addr)}, size of cache: {[*cache.shape]}, real size of cache: {first_kv_cache.shape}"
|
||||
)
|
||||
try:
|
||||
self.cache = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc, self.cache_addr, self.cache_key)
|
||||
logger.info(
|
||||
"LLMDataDistCMgrConnectorWorker: End of register Paged Cache."
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
|
||||
)
|
||||
self.ready_event = threading.Event()
|
||||
self.metadata_agent_listener_t = threading.Thread(
|
||||
target=self.listen_for_agent_metadata_req,
|
||||
args=(self.ready_event, ),
|
||||
daemon=True,
|
||||
name="metadata_agent_listener")
|
||||
self.metadata_agent_listener_t.start()
|
||||
self.ready_event.wait()
|
||||
|
||||
def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata):
|
||||
futures = []
|
||||
for req_id, meta in metadata.requests.items():
|
||||
logger.debug(f"Start to transmit {req_id}")
|
||||
future = self.executor.submit(
|
||||
self._read_blocks,
|
||||
local_block_ids=meta.local_block_ids,
|
||||
remote_block_ids=meta.remote_block_ids,
|
||||
remote_ip=meta.remote_host,
|
||||
remote_port=int(meta.remote_port),
|
||||
remote_engine_id=meta.engine_id,
|
||||
request_id=req_id,
|
||||
remote_tp_size=meta.remote_tp_size,
|
||||
)
|
||||
futures.append(future)
|
||||
|
||||
def handle_exception(future):
|
||||
if future.exception():
|
||||
logger.error(f"KV transfer task failed: {future.exception()}")
|
||||
|
||||
for future in futures:
|
||||
future.add_done_callback(handle_exception)
|
||||
|
||||
def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
|
||||
assert self.local_agent_metadata is not None
|
||||
remote_cluster_id = metadata.cluster_id
|
||||
if remote_cluster_id in self.linked_cluster:
|
||||
logger.debug(
|
||||
f"LLMDataDistCMgrConnectorWorker: remote cluster_id: {metadata.cluster_id} already linked with this server, skip the connection"
|
||||
)
|
||||
return remote_cluster_id
|
||||
remote_super_pod_id = metadata.super_pod_id
|
||||
remote_server_id = metadata.server_id
|
||||
is_same_server = remote_server_id == self.local_agent_metadata.server_id
|
||||
is_same_pod = remote_super_pod_id == self.local_agent_metadata.super_pod_id
|
||||
if self.llm_datadist_role == LLMRole.PROMPT:
|
||||
prefill_metadata = self.local_agent_metadata
|
||||
decode_metadata = metadata
|
||||
else:
|
||||
prefill_metadata = metadata
|
||||
decode_metadata = self.local_agent_metadata
|
||||
comm_name = f"pd_comm_{prefill_metadata.device_ip}_{decode_metadata.device_ip}"
|
||||
cluster_rank_info = {
|
||||
prefill_metadata.cluster_id: 0,
|
||||
decode_metadata.cluster_id: 1
|
||||
}
|
||||
rank_table = {}
|
||||
rank_table["version"] = "1.2"
|
||||
rank_table["server_count"] = "1" if is_same_server else "2"
|
||||
rank_table["status"] = "completed"
|
||||
|
||||
# generate server_list for rank table
|
||||
rank_table["server_list"] = [] # type: ignore[assignment]
|
||||
decode_server_device_info = None
|
||||
prefill_server_device_info = {
|
||||
"device": [{
|
||||
k: v
|
||||
for k, v in [(
|
||||
"device_id", prefill_metadata.device_id
|
||||
), ("device_ip", prefill_metadata.device_ip
|
||||
), ("super_device_id",
|
||||
prefill_metadata.super_device_id), ("rank_id", "0")]
|
||||
if v is not None
|
||||
}],
|
||||
"server_id":
|
||||
prefill_metadata.server_id
|
||||
}
|
||||
if is_same_server:
|
||||
prefill_server_device_info["device"].append( # type: ignore[attr-defined]
|
||||
{
|
||||
k: v
|
||||
for k, v in [(
|
||||
"device_id", decode_metadata.device_id
|
||||
), ("device_ip", decode_metadata.device_ip
|
||||
), ("super_device_id",
|
||||
decode_metadata.super_device_id), ("rank_id", "1")]
|
||||
if v is not None
|
||||
})
|
||||
else:
|
||||
decode_server_device_info = {
|
||||
"device": [{
|
||||
k: v
|
||||
for k, v in [(
|
||||
"device_id", decode_metadata.device_id
|
||||
), ("device_ip", decode_metadata.device_ip
|
||||
), ("super_device_id",
|
||||
decode_metadata.super_device_id), ("rank_id", "1")]
|
||||
if v is not None
|
||||
}],
|
||||
"server_id":
|
||||
decode_metadata.server_id
|
||||
}
|
||||
rank_table["server_list"].append( # type: ignore[attr-defined]
|
||||
prefill_server_device_info)
|
||||
if decode_server_device_info is not None:
|
||||
rank_table["server_list"].append( # type: ignore[attr-defined]
|
||||
decode_server_device_info)
|
||||
|
||||
if self.soc_info == AscendSocVersion.A3:
|
||||
# generate super_pod_list for rank table
|
||||
super_pod_list = []
|
||||
prefill_super_pod_info = {
|
||||
"super_pod_id": prefill_metadata.super_pod_id,
|
||||
"server_list": [{
|
||||
"server_id": prefill_metadata.server_id
|
||||
}],
|
||||
}
|
||||
if is_same_pod and not is_same_server:
|
||||
prefill_super_pod_info[
|
||||
"server_list"].append( # type: ignore[attr-defined]
|
||||
{"server_id": decode_metadata.server_id})
|
||||
super_pod_list.append(prefill_super_pod_info)
|
||||
if not is_same_pod:
|
||||
decode_super_pod_id = {
|
||||
"super_pod_id": decode_metadata.super_pod_id,
|
||||
"server_list": [{
|
||||
"server_id": decode_metadata.server_id
|
||||
}],
|
||||
}
|
||||
super_pod_list.append(decode_super_pod_id)
|
||||
rank_table[
|
||||
"super_pod_list"] = super_pod_list # type: ignore[assignment]
|
||||
logger.info(
|
||||
f"LLMDataDistCMgrConnectorWorker: try link with remote, comm id: {comm_name}"
|
||||
)
|
||||
logger.info(f"rank table \n{rank_table}")
|
||||
logger.info(f"comm name: {comm_name}")
|
||||
logger.info(f"cluster rank info: {cluster_rank_info}")
|
||||
comm_id = self.llm_datadist.link(comm_name, cluster_rank_info,
|
||||
json.dumps(rank_table))
|
||||
while True:
|
||||
ret = self.llm_datadist.query_register_mem_status(comm_id=comm_id)
|
||||
if ret == llm_datadist.RegisterMemStatus.OK:
|
||||
logger.info(
|
||||
f"LLMDataDistCMgrConnectorWorker: Linking success, comm id: {comm_id}"
|
||||
)
|
||||
break
|
||||
elif ret == llm_datadist.RegisterMemStatus.FAILED:
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Linking failed, comm id: {comm_id}"
|
||||
)
|
||||
time.sleep(1)
|
||||
logger.info("Checking query_register_mem_status again")
|
||||
self.linked_cluster.update({remote_cluster_id: comm_id})
|
||||
logger.info(f"cached linked cluster: {self.linked_cluster}")
|
||||
logger.info(
|
||||
f"Successfully build link with cluster id {remote_cluster_id} with cluster name {comm_name} !"
|
||||
)
|
||||
return remote_cluster_id
|
||||
|
||||
def remove_remote_agent(self, cluster_id: int):
|
||||
if cluster_id not in self.linked_cluster:
|
||||
logger.warning(
|
||||
f"LLMDataDistCMgrConnectorWorker: Warning! Can't remove remote client with cluster id {cluster_id} for its not exist in linked_cluster list"
|
||||
)
|
||||
comm_id = self.linked_cluster[cluster_id]
|
||||
try:
|
||||
self.llm_datadist.unlink(comm_id)
|
||||
self.linked_cluster.pop(cluster_id)
|
||||
except LLMException:
|
||||
logger.error(
|
||||
f"Try to remove remote client with cluster id {cluster_id} failed!, program won't terminate, but please carefully check your environment"
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully remove remote client with cluster id {cluster_id} !"
|
||||
)
|
||||
|
||||
def connect_to_remote_agent(self, host: str, port: int) -> int:
|
||||
url = f"tcp://{host}:{port}"
|
||||
logger.debug(f"Querying metadata from url: {url}")
|
||||
msg_encoder = msgspec.msgpack.Encoder()
|
||||
msg_send = msg_encoder.encode(
|
||||
[LLMDataDistCMgrEvent.ReqForMetadata, self.local_agent_metadata])
|
||||
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
|
||||
logger.info("Try request remote metadata from socket......")
|
||||
sock.send(msg_send)
|
||||
metadata_bytes = sock.recv()
|
||||
decoder = msgspec.msgpack.Decoder()
|
||||
metadata = decoder.decode(metadata_bytes)
|
||||
metadata = LLMDataDistCMgrAgentMetadata(**metadata)
|
||||
logger.info(f"recving metadata: {metadata}")
|
||||
cluster_id = self.add_remote_agent(metadata)
|
||||
return cluster_id
|
||||
|
||||
def send_finish_to_remote(self, host: str, port: int, request_id):
|
||||
url = f"tcp://{host}:{port}"
|
||||
logger.debug(f"Sending finished to remote: {url}")
|
||||
msg_encoder = msgspec.msgpack.Encoder()
|
||||
msg_send = msg_encoder.encode([
|
||||
LLMDataDistCMgrEvent.ReqForFinished,
|
||||
[request_id, self.tp_rank, self.tp_size]
|
||||
])
|
||||
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
|
||||
try:
|
||||
sock.send(msg_send)
|
||||
logger.debug(
|
||||
f"Request id {request_id} finished message send to remote {url}"
|
||||
)
|
||||
_ = sock.recv()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to send reqest_id {request_id} to prefill: {e}")
|
||||
|
||||
def _read_blocks(
|
||||
self,
|
||||
local_block_ids: list[int],
|
||||
remote_block_ids: list[int],
|
||||
remote_ip: str,
|
||||
remote_port: int,
|
||||
remote_engine_id: str,
|
||||
request_id: str,
|
||||
remote_tp_size: str,
|
||||
):
|
||||
# if remote_ip not in self.linked_cluster:
|
||||
tp_offset = self.tp_rank % int(remote_tp_size)
|
||||
remote_cluster_id = self.connect_to_remote_agent(
|
||||
remote_ip, remote_port + tp_offset)
|
||||
num_local_blocks = len(local_block_ids)
|
||||
if num_local_blocks == 0:
|
||||
return
|
||||
num_remote_blocks = len(remote_block_ids)
|
||||
assert num_local_blocks <= num_remote_blocks
|
||||
if num_local_blocks < num_remote_blocks:
|
||||
remote_block_ids = remote_block_ids[-num_local_blocks:]
|
||||
|
||||
logger.info(f"remote cluster id is: {remote_cluster_id}")
|
||||
if self.use_mla:
|
||||
remote_cache_key_k_normed = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=0)
|
||||
remote_cache_key_k_pe = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=1)
|
||||
logger.info("Try pull blocks from remote server")
|
||||
try:
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_normed,
|
||||
self.cache[0], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_pe,
|
||||
self.cache[1], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
|
||||
)
|
||||
except LLMException:
|
||||
raise RuntimeError(
|
||||
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
||||
)
|
||||
else:
|
||||
remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
|
||||
logger.info("Try pull blocks from remote server")
|
||||
try:
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key,
|
||||
self.cache, # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
|
||||
)
|
||||
except LLMException:
|
||||
raise RuntimeError(
|
||||
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
||||
)
|
||||
self.send_finish_to_remote(remote_ip, remote_port, request_id)
|
||||
with self.thread_lock:
|
||||
self.finished_reqs.add(request_id)
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
"""Get the finished recving and sending requuests."""
|
||||
import copy
|
||||
with self.thread_lock:
|
||||
req_ids_to_ret = copy.deepcopy(self.finished_reqs)
|
||||
self.finished_reqs.clear()
|
||||
if self.llm_datadist_role == LLMRole.PROMPT:
|
||||
return req_ids_to_ret, None
|
||||
else:
|
||||
return None, req_ids_to_ret
|
||||
|
||||
|
||||
# adopt this from https://github.com/vllm-project/vllm/blob/main/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
|
||||
@contextlib.contextmanager
|
||||
def zmq_ctx(socket_type: Any,
|
||||
addr: str) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
|
||||
"""Context manager for a ZMQ socket"""
|
||||
|
||||
ctx: Optional[zmq.Context] = None # type: ignore[name-defined]
|
||||
try:
|
||||
ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
|
||||
if socket_type == zmq.ROUTER: # type: ignore[attr-defined]
|
||||
socket = ctx.socket(zmq.ROUTER) # type: ignore[attr-defined]
|
||||
socket.bind(addr)
|
||||
elif socket_type == zmq.REQ: # type: ignore[attr-defined]
|
||||
socket = ctx.socket(zmq.REQ) # type: ignore[attr-defined]
|
||||
socket.connect(addr)
|
||||
else:
|
||||
raise ValueError(f"Unexpected socket type: {socket_type}")
|
||||
|
||||
yield socket
|
||||
finally:
|
||||
if ctx is not None:
|
||||
ctx.destroy(linger=0)
|
||||
556
vllm_ascend/distributed/moe_comm_method.py
Normal file
556
vllm_ascend/distributed/moe_comm_method.py
Normal file
@@ -0,0 +1,556 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.distributed.communication_op import \
|
||||
data_parallel_reduce_scatter
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
||||
|
||||
|
||||
class MoECommMethod(ABC):
|
||||
"""Base class for MoE communication methods."""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
self.moe_config = moe_config
|
||||
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Prepare the MoE communication method.
|
||||
|
||||
This method is called before quant_method.apply to prepare the
|
||||
communication method. It can be used to initialize any necessary
|
||||
resources or configurations.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""Finalize the MoE communication method.
|
||||
|
||||
This method is called after quant_method.apply to finalize the
|
||||
communication method. It can be used to clean up any resources or
|
||||
configurations.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def permute(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
num_experts: int,
|
||||
apply_a8_quantization: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
|
||||
"""Pre-process before MLP.
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Tensor of shape (num_tokens, hidden_size)
|
||||
topk_ids (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
|
||||
topk_weights (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
|
||||
expert_map (torch.Tensor): Tensor of shape (global_num_experts, )
|
||||
Mapping from global expert IDs to local expert IDs.
|
||||
num_experts (int): Number of local experts (experts on this device).
|
||||
apply_a8_quantization (bool): Whether to apply A8 quantization (W4A8 and W8A8).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing:
|
||||
- permuted_hidden_states (torch.Tensor): Tensor of shape
|
||||
(num_tokens * top_k_num, hidden_size) after permuting
|
||||
hidden_states based on topk_ids.
|
||||
- expert_tokens (torch.Tensor): Tensor of shape (num_experts, )
|
||||
Number of tokens assigned to each expert.
|
||||
- dynamic_scale (torch.Tensor, optional): Tensor of shape (num_experts, )
|
||||
Dynamic scale for each expert, used for quantization.
|
||||
- group_list_type (int): Type of group list, 0 for `cumsum`
|
||||
and 1 for `count`. This is mainly for `npu_grouped_matmul`
|
||||
to determine how to handle the output.
|
||||
Raises:
|
||||
NotImplementedError: If the method is not implemented in the subclass.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unpermute(self, mlp_output: torch.Tensor,
|
||||
hidden_states: torch.Tensor) -> None:
|
||||
"""Post-process after MLP.
|
||||
|
||||
Args:
|
||||
mlp_output (torch.Tensor): Tensor of shape
|
||||
(num_tokens * top_k_num, hidden_size) after MLP.
|
||||
hidden_states (torch.Tensor): Tensor of shape
|
||||
(num_tokens, hidden_size) to be updated with the final output.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class AllGatherCommImpl(MoECommMethod):
|
||||
"""This implementation is the same as NativeAllGatherCommImpl,
|
||||
but uses NPU-specific ops for better performance.
|
||||
|
||||
This implementation should be compatible with all scenarios, and
|
||||
thus it is the default implementation for MoE communication methods.
|
||||
It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
|
||||
and `torch_npu.npu_moe_token_unpermute` for post-processing
|
||||
to handle the token-to-expert mapping and communication efficiently.
|
||||
|
||||
NOTE(Yizhou): TBH, it is really weird that we were supposed to use
|
||||
`torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
|
||||
or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
|
||||
for pre-processing and post-processing, respectively.
|
||||
But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
|
||||
use `torch_npu.npu_moe_token_unpermute` instead.
|
||||
This is a workaround and should be removed after the issue is fixed.
|
||||
"""
|
||||
|
||||
def prepare(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""When DP size > 1, pad the hidden states and router logits for communication."""
|
||||
if self.moe_config.dp_size > 1:
|
||||
forward_context = get_forward_context()
|
||||
max_tokens_across_dp = forward_context.max_tokens_across_dp
|
||||
|
||||
self.num_tokens = hidden_states.shape[0]
|
||||
pad_size = max_tokens_across_dp - self.num_tokens
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
hidden_states = self.moe_config.dp_group.all_gather(
|
||||
hidden_states, 0)
|
||||
router_logits = self.moe_config.dp_group.all_gather(
|
||||
router_logits, 0)
|
||||
|
||||
return hidden_states, router_logits
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""When DP size > 1, reduce-scatter the hidden states to get the final output.
|
||||
|
||||
When TP size > 1, all-reduce the hidden states to get the final output.
|
||||
"""
|
||||
if self.moe_config.dp_size > 1:
|
||||
hidden_states = data_parallel_reduce_scatter(hidden_states, dim=0)
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
|
||||
if reduce_results and (self.moe_config.tp_size > 1
|
||||
or self.moe_config.ep_size > 1):
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def permute(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
expert_map: torch.Tensor, # noqa: F841
|
||||
num_experts: int,
|
||||
apply_a8_quantization: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
|
||||
num_tokens = hidden_states.shape[0]
|
||||
|
||||
self.topk_weights = topk_weights
|
||||
self.topk_ids = topk_ids
|
||||
|
||||
first_expert_idx = 0
|
||||
if expert_map is not None:
|
||||
# FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
||||
# So we need to filter out invalid tokens by zeroing their weights.
|
||||
# This is a workaround and should be removed after the issue is fixed
|
||||
mask = expert_map[topk_ids] != -1
|
||||
# NOTE: This is equivalent to self.topk_weights[~mask] = 0.0,
|
||||
# but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph
|
||||
self.topk_weights = torch.where(mask, topk_weights, 0.0)
|
||||
|
||||
first_expert_idx = self.moe_config.ep_rank * num_experts
|
||||
last_expert_idx = first_expert_idx + num_experts
|
||||
|
||||
permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
|
||||
torch_npu.npu_moe_init_routing_v2(
|
||||
hidden_states,
|
||||
topk_ids,
|
||||
active_num=num_tokens * self.moe_config.experts_per_token,
|
||||
expert_num=self.moe_config.num_experts,
|
||||
expert_tokens_num_type=1, # Only support `count` mode now
|
||||
expert_tokens_num_flag=True, # Output `expert_tokens`
|
||||
active_expert_range=[first_expert_idx, last_expert_idx],
|
||||
quant_mode=-1,
|
||||
))
|
||||
self.expanded_row_idx = expanded_row_idx
|
||||
permuted_hidden_states = permuted_hidden_states
|
||||
|
||||
group_list_type = 1 # `count` mode
|
||||
|
||||
return permuted_hidden_states, expert_tokens, None, group_list_type
|
||||
|
||||
def unpermute(self, mlp_output: torch.Tensor,
|
||||
hidden_states: torch.Tensor) -> None:
|
||||
hidden_states[:] = torch_npu.npu_moe_token_unpermute(
|
||||
permuted_tokens=mlp_output,
|
||||
sorted_indices=self.expanded_row_idx,
|
||||
probs=self.topk_weights)
|
||||
|
||||
|
||||
class NativeAllGatherCommImpl(AllGatherCommImpl):
|
||||
"""This implementation should be compatible with all scenarios.
|
||||
|
||||
Note that this implementation purely consists of native PyTorch ops
|
||||
and does not use any NPU-specific ops. So the performance may not be optimal.
|
||||
But it is a good fallback for scenarios where NPU-specific ops are not available.
|
||||
"""
|
||||
|
||||
def permute(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
num_experts: int,
|
||||
apply_a8_quantization: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
|
||||
num_tokens = hidden_states.shape[0]
|
||||
|
||||
# Generate token indices and flatten
|
||||
token_indices = torch.arange(num_tokens,
|
||||
device=hidden_states.device,
|
||||
dtype=torch.int64)
|
||||
token_indices = (token_indices.unsqueeze(1).expand(
|
||||
-1, self.moe_config.experts_per_token).reshape(-1))
|
||||
|
||||
# Flatten token-to-expert mappings and map to local experts
|
||||
weights_flat = topk_weights.view(-1)
|
||||
experts_flat = topk_ids.view(-1)
|
||||
local_experts_flat = (expert_map[experts_flat]
|
||||
if expert_map is not None else experts_flat)
|
||||
|
||||
# Filter valid token-expert pairs
|
||||
mask = local_experts_flat != -1
|
||||
# FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
||||
# So we need to filter out invalid tokens by zeroing their weights.
|
||||
# This is a workaround and should be removed after the issue is fixed
|
||||
filtered_weights = torch.where(mask, weights_flat,
|
||||
torch.zeros_like(weights_flat)).to(
|
||||
topk_weights.dtype)
|
||||
filtered_experts = torch.where(
|
||||
mask,
|
||||
local_experts_flat,
|
||||
torch.full_like(local_experts_flat, num_experts),
|
||||
).to(topk_ids.dtype)
|
||||
|
||||
# Sort by local expert IDs
|
||||
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
|
||||
self.sorted_token_indices = token_indices[sort_indices]
|
||||
self.sorted_weights = filtered_weights[sort_indices]
|
||||
|
||||
# Compute token counts with minlength of num_experts
|
||||
# This is equivalent to but faster than:
|
||||
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
|
||||
token_counts = torch.zeros(num_experts + 1,
|
||||
device=hidden_states.device,
|
||||
dtype=torch.int64)
|
||||
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
|
||||
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
|
||||
expert_tokens = token_counts[:num_experts]
|
||||
|
||||
# Rearrange hidden_states
|
||||
permuted_hidden_states = hidden_states[self.sorted_token_indices]
|
||||
|
||||
group_list_type = 1 # `count` mode
|
||||
|
||||
return permuted_hidden_states, expert_tokens, None, group_list_type
|
||||
|
||||
def unpermute(self, mlp_output: torch.Tensor,
|
||||
hidden_states: torch.Tensor) -> None:
|
||||
mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)
|
||||
|
||||
final_hidden_states = torch.zeros_like(hidden_states)
|
||||
final_hidden_states.index_add_(0, self.sorted_token_indices,
|
||||
mlp_output)
|
||||
|
||||
hidden_states[:] = final_hidden_states
|
||||
|
||||
|
||||
class MC2CommImpl(MoECommMethod):
|
||||
"""This implementation is for the scenarios listed below:
|
||||
1. `enable_expert_parallel=True`.
|
||||
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
|
||||
3. `enable_expert_parallel=False` is not supported.
|
||||
|
||||
This implementation uses the MC2 communication method, which is optimized for
|
||||
Communication and Computation parallelism on Ascend devices.
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: Optional[FusedMoEConfig]):
|
||||
super().__init__(moe_config)
|
||||
|
||||
# NOTE: We do not need to use mc2_group's rank and world size
|
||||
# because ep_group and mc2_group basically have the same init params.
|
||||
# We only init another group because of the restriction of MC2:
|
||||
# "No other groups can be used in the same process as the MC2 group."
|
||||
self.mc2_comm_name = get_mc2_group().device_group._get_backend(
|
||||
torch.device("npu")).get_hccl_comm_name(self.moe_config.ep_rank)
|
||||
|
||||
# Feature flags
|
||||
self.enable_dispatch_v2 = hasattr(torch_npu,
|
||||
"npu_moe_distribute_dispatch_v2")
|
||||
self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3
|
||||
self.need_extra_args = self.is_ascend_a3
|
||||
self._restore_tp_across_dp()
|
||||
|
||||
def _restore_tp_across_dp(self):
|
||||
# NOTE: Since vLLM flatten tp across dp, we need to restore the original
|
||||
# tp_size and tp_rank.
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
def prepare(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""The target_pad_length is calculated in forward_context, here we pad the
|
||||
hidden states and router logits. And if TP size > 1, we also need to split
|
||||
the tensors accordingly.
|
||||
"""
|
||||
self.num_tokens, _ = hidden_states.shape
|
||||
forward_context = get_forward_context()
|
||||
self.mc2_mask = forward_context.mc2_mask
|
||||
target_pad_length = forward_context.padded_num_tokens
|
||||
pad_size = target_pad_length - self.num_tokens
|
||||
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
if self.tp_size > 1:
|
||||
split_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_mc2_mask = torch.tensor_split(self.mc2_mask,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
self.split_hidden_states = split_hidden_states
|
||||
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
self.mc2_mask = split_mc2_mask[self.tp_rank]
|
||||
|
||||
return hidden_states, router_logits
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""If TP size > 1, all-gather the hidden states to get the final output.
|
||||
|
||||
Also, unpad the hidden states if needed.
|
||||
"""
|
||||
if self.tp_size > 1:
|
||||
dist.all_gather(list(self.split_hidden_states), hidden_states,
|
||||
self.moe_config.tp_group.device_group)
|
||||
hidden_states = torch.cat(self.split_hidden_states, dim=0)
|
||||
|
||||
if self.num_tokens < hidden_states.shape[0]:
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
|
||||
return hidden_states
|
||||
|
||||
def permute(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
num_experts: int,
|
||||
apply_a8_quantization: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
|
||||
# Store tensors needed for post_process
|
||||
self.topk_ids = topk_ids
|
||||
self.topk_weights = topk_weights.to(torch.float32)
|
||||
|
||||
dispatch_kwargs = {
|
||||
"x": hidden_states,
|
||||
"expert_ids": self.topk_ids,
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": self.moe_config.num_experts,
|
||||
"global_bs": 0,
|
||||
"scales": None,
|
||||
"quant_mode": 2 if apply_a8_quantization else 0,
|
||||
"group_ep": self.mc2_comm_name,
|
||||
"ep_world_size": self.moe_config.ep_size,
|
||||
"ep_rank_id": self.moe_config.ep_rank,
|
||||
}
|
||||
|
||||
if self.need_extra_args:
|
||||
dispatch_kwargs.update({
|
||||
"group_tp": self.mc2_comm_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if self.is_ascend_a3 and self.enable_dispatch_v2:
|
||||
dispatch_kwargs.update({
|
||||
"x_active_mask": self.mc2_mask,
|
||||
})
|
||||
|
||||
dispatch = torch_npu.npu_moe_distribute_dispatch_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch
|
||||
|
||||
(
|
||||
permuted_hidden_states,
|
||||
dynamic_scale,
|
||||
self.assist_info_for_combine,
|
||||
expert_tokens,
|
||||
self.ep_recv_counts,
|
||||
self.tp_recv_counts,
|
||||
) = dispatch(**dispatch_kwargs)[:6]
|
||||
|
||||
group_list_type = 1
|
||||
|
||||
return permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type
|
||||
|
||||
def unpermute(self, mlp_output: torch.Tensor,
|
||||
hidden_states: torch.Tensor) -> None:
|
||||
combine_kwargs = {
|
||||
"expand_x": mlp_output,
|
||||
"expert_ids": self.topk_ids,
|
||||
"expert_scales": self.topk_weights,
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": self.moe_config.num_experts,
|
||||
"global_bs": 0,
|
||||
"ep_send_counts": self.ep_recv_counts,
|
||||
"group_ep": self.mc2_comm_name,
|
||||
"ep_world_size": self.moe_config.ep_size,
|
||||
"ep_rank_id": self.moe_config.ep_rank,
|
||||
}
|
||||
|
||||
if self.enable_dispatch_v2:
|
||||
combine_kwargs[
|
||||
"assist_info_for_combine"] = self.assist_info_for_combine
|
||||
else:
|
||||
combine_kwargs["expand_idx"] = self.assist_info_for_combine
|
||||
|
||||
if self.need_extra_args:
|
||||
combine_kwargs.update({
|
||||
"tp_send_counts": self.tp_recv_counts,
|
||||
"group_tp": self.mc2_comm_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if self.is_ascend_a3 and self.enable_dispatch_v2:
|
||||
combine_kwargs.update({
|
||||
"x_active_mask": self.mc2_mask,
|
||||
})
|
||||
|
||||
combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine
|
||||
|
||||
hidden_states[:] = combine(**combine_kwargs)
|
||||
|
||||
|
||||
class AlltoAllCommImpl(MoECommMethod):
|
||||
"""This implementation is for the scenarios listed below:
|
||||
1. `enable_expert_parallel=True`.
|
||||
2. `npu_grouped_matmul` is available.
|
||||
|
||||
This implementation uses all-to-all communication to exchange tokens
|
||||
between data parallel ranks before and after the MLP computation. It should
|
||||
have better performance than AllGatherCommImpl when DP size > 1.
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: Optional[FusedMoEConfig]):
|
||||
super().__init__(moe_config)
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
||||
get_token_dispatcher
|
||||
self.token_dispatcher = get_token_dispatcher(
|
||||
"TokenDispatcherWithAll2AllV")
|
||||
self._restore_tp_across_dp()
|
||||
|
||||
def _restore_tp_across_dp(self):
|
||||
# NOTE: Since vLLM flatten tp across dp, we need to restore the original
|
||||
# tp_size and tp_rank.
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
def prepare(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
self.num_tokens, _ = hidden_states.shape
|
||||
pad_size = self.tp_size - self.num_tokens
|
||||
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
if self.tp_size > 1:
|
||||
split_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
self.split_hidden_states = split_hidden_states
|
||||
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
|
||||
return hidden_states, router_logits
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""If TP size > 1, all-gather the hidden states to get the final output.
|
||||
|
||||
Also, unpad the hidden states if needed.
|
||||
"""
|
||||
if self.tp_size > 1:
|
||||
dist.all_gather(list(self.split_hidden_states), hidden_states,
|
||||
self.moe_config.tp_group.device_group)
|
||||
hidden_states = torch.cat(self.split_hidden_states, dim=0)
|
||||
|
||||
if self.num_tokens < hidden_states.shape[0]:
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
|
||||
return hidden_states
|
||||
|
||||
def permute(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
num_experts: int,
|
||||
apply_a8_quantization: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
|
||||
results = self.token_dispatcher.token_dispatch(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
None,
|
||||
log2phy=None,
|
||||
with_quant=apply_a8_quantization)
|
||||
return results["hidden_states"], results["group_list"], results[
|
||||
"dynamic_scale"], results["group_list_type"]
|
||||
|
||||
def unpermute(self, mlp_output: torch.Tensor,
|
||||
hidden_states: torch.Tensor) -> None:
|
||||
hidden_states[:] = self.token_dispatcher.token_combine(mlp_output)
|
||||
1070
vllm_ascend/distributed/mooncake_connector.py
Normal file
1070
vllm_ascend/distributed/mooncake_connector.py
Normal file
File diff suppressed because it is too large
Load Diff
119
vllm_ascend/distributed/parallel_state.py
Normal file
119
vllm_ascend/distributed/parallel_state.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
|
||||
init_model_parallel_group)
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
|
||||
# Currently, mc2 op need their own group coordinator.
|
||||
_MC2: Optional[GroupCoordinator] = None
|
||||
_MLP_TP: Optional[GroupCoordinator] = None
|
||||
|
||||
_LMTP: Optional[GroupCoordinator] = None
|
||||
|
||||
|
||||
def get_mc2_group() -> GroupCoordinator:
|
||||
assert _MC2 is not None, ("mc2 group is not initialized")
|
||||
return _MC2
|
||||
|
||||
|
||||
def get_lmhead_tp_group() -> GroupCoordinator:
|
||||
assert _LMTP is not None, (
|
||||
"lm head tensor parallel group is not initialized")
|
||||
return _LMTP
|
||||
|
||||
|
||||
def get_mlp_tp_group() -> GroupCoordinator:
|
||||
assert _MLP_TP is not None, ("mlp group is not initialized")
|
||||
return _MLP_TP
|
||||
|
||||
|
||||
def model_parallel_initialized():
|
||||
return (_MC2 is not None)
|
||||
|
||||
|
||||
def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
||||
if model_parallel_initialized():
|
||||
return
|
||||
assert torch.distributed.is_initialized()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
backend = torch.distributed.get_backend(get_world_group().device_group)
|
||||
|
||||
# The layout of all ranks: ExternalDP * EP
|
||||
# ExternalDP is the data parallel group that is not part of the model,
|
||||
# every dp rank can generate independently (in verl integration).
|
||||
all_ranks = torch.arange(world_size).reshape(
|
||||
-1, parallel_config.data_parallel_size *
|
||||
parallel_config.tensor_parallel_size)
|
||||
global _MC2
|
||||
group_ranks = all_ranks.unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
|
||||
_MC2 = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="mc2")
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE:
|
||||
global _MLP_TP
|
||||
assert _MLP_TP is None, (
|
||||
"mlp tensor model parallel group is already initialized")
|
||||
|
||||
mlp_tp = parallel_config.data_parallel_size
|
||||
|
||||
all_ranks_mlp_head = torch.arange(world_size).reshape(
|
||||
-1, mlp_tp, parallel_config.pipeline_parallel_size, 1) # noqa
|
||||
group_ranks = all_ranks_mlp_head.view(-1, mlp_tp).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
|
||||
# message queue broadcaster is only used in tensor model parallel group
|
||||
_MLP_TP = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="mlp_tp")
|
||||
|
||||
lmhead_tensor_parallel_size = get_ascend_config(
|
||||
).lmhead_tensor_parallel_size
|
||||
if lmhead_tensor_parallel_size is not None:
|
||||
group_ranks = []
|
||||
global _LMTP
|
||||
num_lmhead_tensor_parallel_groups: int = (world_size //
|
||||
lmhead_tensor_parallel_size)
|
||||
for i in range(num_lmhead_tensor_parallel_groups):
|
||||
ranks = list(
|
||||
range(i * lmhead_tensor_parallel_size,
|
||||
(i + 1) * lmhead_tensor_parallel_size))
|
||||
group_ranks.append(ranks)
|
||||
_LMTP = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="lmheadtp")
|
||||
|
||||
|
||||
def get_mlp_tensor_model_parallel_world_size():
|
||||
"""Return world size for the tensor model parallel group."""
|
||||
return get_mlp_tp_group().world_size
|
||||
|
||||
|
||||
def get_mlp_tensor_model_parallel_rank():
|
||||
"""Return world size for the tensor model parallel group."""
|
||||
return get_mlp_tp_group().rank_in_group
|
||||
|
||||
|
||||
def destroy_ascend_model_parallel():
|
||||
global _MC2
|
||||
if _MC2:
|
||||
_MC2.destroy()
|
||||
_MC2 = None
|
||||
|
||||
global _MLP_TP
|
||||
if _MLP_TP:
|
||||
_MLP_TP.destroy()
|
||||
_MLP_TP = None
|
||||
|
||||
global _LMTP
|
||||
if _LMTP:
|
||||
_LMTP.destroy()
|
||||
_LMTP = None
|
||||
248
vllm_ascend/distributed/tensor_parallel.py
Normal file
248
vllm_ascend/distributed/tensor_parallel.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Adapts from: Megatron/megatron/core/tensor_parallel/mappings.py.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
import torch
|
||||
|
||||
|
||||
def _gather_along_first_dim(input_, group, output_split_sizes=None):
|
||||
"""Gather tensors and concatenate along the first dimension.
|
||||
|
||||
Args:
|
||||
input_tensor (torch.Tensor):
|
||||
A tensor to be gathered.
|
||||
output_split_sizes (List[int], optional):
|
||||
A list specifying the sizes of the output splits along the first dimension.
|
||||
If None, equal splitting is assumed. Default: None.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Gathered tensor.
|
||||
"""
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
dim_size = list(input_.size())
|
||||
if output_split_sizes is None:
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
torch.distributed.all_gather_into_tensor(output,
|
||||
input_.contiguous(),
|
||||
group=group)
|
||||
else:
|
||||
dim_size[0] = sum(output_split_sizes)
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
output_tensor_list = list(
|
||||
torch.split(output, output_split_sizes, dim=0))
|
||||
torch.distributed.all_gather(output_tensor_list, input_, group=group)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _gather_along_last_dim(input_, group):
|
||||
"""Gather tensors and concatenate along the last dimension."""
|
||||
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
dim_size = list(input_.size())
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
torch.distributed.all_gather_into_tensor(output,
|
||||
input_.contiguous(),
|
||||
group=group)
|
||||
tensor_list = output.chunk(world_size, dim=0)
|
||||
output = torch.cat(tensor_list, dim=-1).contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _reduce_scatter_along_first_dim(input_,
|
||||
group,
|
||||
input_split_sizes=None,
|
||||
use_global_buffer=False):
|
||||
"""Reduce-scatter the input tensor across model parallel group.
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor): The input tensor to be reduce-scattered.
|
||||
input_split_sizes (List[int], optional): A list specifying the sizes of
|
||||
the input splits along the first dimension for each rank. If None,
|
||||
equal splitting is assumed. Default: None.
|
||||
"""
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
if input_split_sizes is None:
|
||||
dim_size = list(input_.size())
|
||||
assert (
|
||||
dim_size[0] % world_size == 0
|
||||
), "First dimension of the tensor should be divisible by tensor parallel size"
|
||||
|
||||
dim_size[0] = dim_size[0] // world_size
|
||||
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
torch.distributed.reduce_scatter_tensor(output,
|
||||
input_.contiguous(),
|
||||
group=group)
|
||||
else:
|
||||
rank = torch.distributed.get_rank(group)
|
||||
input_tensor_list = list(torch.split(input_, input_split_sizes, dim=0))
|
||||
|
||||
output = torch.empty_like(input_tensor_list[rank])
|
||||
torch.distributed.reduce_scatter(output,
|
||||
input_tensor_list,
|
||||
group=group)
|
||||
return output
|
||||
|
||||
|
||||
def _reduce_scatter_along_last_dim(input_, group):
|
||||
"""Reduce-scatter tensors on the last dimension."""
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
target_shape = list(input_.size())
|
||||
target_shape[-1] = target_shape[-1] // world_size
|
||||
input_ = input_.reshape(-1, input_.shape[-1])
|
||||
split_tensors = torch.split(input_,
|
||||
split_size_or_sections=input_.shape[-1] //
|
||||
world_size,
|
||||
dim=1)
|
||||
concat_tensor = torch.cat(split_tensors, dim=0)
|
||||
output = _reduce_scatter_along_first_dim(concat_tensor,
|
||||
group).reshape(target_shape)
|
||||
return output
|
||||
|
||||
|
||||
def all_gather_last_dim_from_tensor_parallel_region(input_, group):
|
||||
"""Wrapper for autograd function: forward: AG, backward RS <last dim>"""
|
||||
return _gather_along_last_dim(input_, group)
|
||||
|
||||
|
||||
def reduce_scatter_to_sequence_parallel_region(input_,
|
||||
group,
|
||||
input_split_sizes=None):
|
||||
"""Wrapper for autograd function: forward: RS, backward AG <first dim>"""
|
||||
return _reduce_scatter_along_first_dim(input_, group, input_split_sizes)
|
||||
|
||||
|
||||
def reduce_scatter_last_dim_to_tensor_parallel_region(input_, group):
|
||||
"""Wrapper for autograd function: forward: RS, backward AG: AG <last dim>"""
|
||||
return _reduce_scatter_along_last_dim(input_, group)
|
||||
|
||||
|
||||
def gather_from_sequence_parallel_region(
|
||||
input_,
|
||||
group,
|
||||
output_split_sizes=None,
|
||||
):
|
||||
"""Wrapper for autograd function: forward: AG, backward: RS <first dim>"""
|
||||
return _gather_along_first_dim(input_, group, output_split_sizes)
|
||||
|
||||
|
||||
def all_to_all(group, input, output_split_sizes=None, input_split_sizes=None):
|
||||
world_size = torch.distributed.get_world_size(group=group)
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input
|
||||
|
||||
input = input.contiguous()
|
||||
if output_split_sizes is None:
|
||||
# Equal split (all2all)
|
||||
output = torch.empty_like(input)
|
||||
else:
|
||||
# Unequal split (all2all-v)
|
||||
output = input.new_empty(
|
||||
size=[sum(output_split_sizes)] + list(input.size()[1:]),
|
||||
dtype=input.dtype,
|
||||
device=torch.npu.current_device(),
|
||||
)
|
||||
torch.distributed.all_to_all_single(
|
||||
output,
|
||||
input,
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def all_to_all_sp2hp(input_, group):
|
||||
"""
|
||||
Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape
|
||||
[num_tokens/TP, H] to [num_tokens, H/TP].
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor):
|
||||
The input tensor which has been distributed along the sequence
|
||||
dimension.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor with shape [num_tokens, H/TP].
|
||||
|
||||
"""
|
||||
if group is None:
|
||||
return input_
|
||||
world_size = torch.distributed.get_world_size(group=group)
|
||||
tp_group = group
|
||||
input_ = input_.reshape(-1, input_.shape[-1])
|
||||
split_tensors = torch.split(input_,
|
||||
split_size_or_sections=input_.shape[-1] //
|
||||
world_size,
|
||||
dim=1)
|
||||
concat_tensor = torch.cat(split_tensors, dim=0)
|
||||
output = all_to_all(tp_group, concat_tensor)
|
||||
return output
|
||||
|
||||
|
||||
def all_to_all_hp2sp(input_, group):
|
||||
"""
|
||||
Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape
|
||||
[num_tokens, H/TP] to [num_tokens/TP, H].
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor):
|
||||
The input tensor which has been distributed along the hidden
|
||||
dimension.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor with shape [num_tokens/TP, H].
|
||||
"""
|
||||
if group is None:
|
||||
return input_
|
||||
world_size = torch.distributed.get_world_size(group=group)
|
||||
input_ = input_.reshape(-1, input_.shape[-1])
|
||||
tp_group = group
|
||||
input_exchanged = all_to_all(tp_group, input_)
|
||||
input_reshaped = input_exchanged.reshape(-1, input_exchanged.shape[-1])
|
||||
split_tensors = torch.split(
|
||||
input_reshaped,
|
||||
split_size_or_sections=input_reshaped.shape[0] // world_size,
|
||||
dim=0)
|
||||
output = torch.cat(split_tensors, dim=-1)
|
||||
return output
|
||||
160
vllm_ascend/envs.py
Normal file
160
vllm_ascend/envs.py
Normal file
@@ -0,0 +1,160 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# This file is mainly Adapted from vllm-project/vllm/vllm/envs.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
# The begin-* and end* here are used by the documentation generator
|
||||
# to extract the used env vars.
|
||||
|
||||
# begin-env-vars-definition
|
||||
|
||||
env_variables: Dict[str, Callable[[], Any]] = {
|
||||
# max compile thread number for package building. Usually, it is set to
|
||||
# the number of CPU cores. If not set, the default value is None, which
|
||||
# means all number of CPU cores will be used.
|
||||
"MAX_JOBS":
|
||||
lambda: os.getenv("MAX_JOBS", None),
|
||||
# The build type of the package. It can be one of the following values:
|
||||
# Release, Debug, RelWithDebugInfo. If not set, the default value is Release.
|
||||
"CMAKE_BUILD_TYPE":
|
||||
lambda: os.getenv("CMAKE_BUILD_TYPE"),
|
||||
# Whether to compile custom kernels. If not set, the default value is True.
|
||||
# If set to False, the custom kernels will not be compiled. Please note that
|
||||
# the sleep mode feature will be disabled as well if custom kernels are not
|
||||
# compiled.
|
||||
"COMPILE_CUSTOM_KERNELS":
|
||||
lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
|
||||
# The CXX compiler used for compiling the package. If not set, the default
|
||||
# value is None, which means the system default CXX compiler will be used.
|
||||
"CXX_COMPILER":
|
||||
lambda: os.getenv("CXX_COMPILER", None),
|
||||
# The C compiler used for compiling the package. If not set, the default
|
||||
# value is None, which means the system default C compiler will be used.
|
||||
"C_COMPILER":
|
||||
lambda: os.getenv("C_COMPILER", None),
|
||||
# The version of the Ascend chip. If not set, the default value is
|
||||
# ASCEND910B1(Available for A2 and A3 series). It's used for package building.
|
||||
# Please make sure that the version is correct.
|
||||
"SOC_VERSION":
|
||||
lambda: os.getenv("SOC_VERSION", "ASCEND910B1"),
|
||||
# If set, vllm-ascend will print verbose logs during compilation
|
||||
"VERBOSE":
|
||||
lambda: bool(int(os.getenv('VERBOSE', '0'))),
|
||||
# The home path for CANN toolkit. If not set, the default value is
|
||||
# /usr/local/Ascend/ascend-toolkit/latest
|
||||
"ASCEND_HOME_PATH":
|
||||
lambda: os.getenv("ASCEND_HOME_PATH", None),
|
||||
# The path for HCCL library, it's used by pyhccl communicator backend. If
|
||||
# not set, the default value is libhccl.so。
|
||||
"HCCL_SO_PATH":
|
||||
lambda: os.environ.get("HCCL_SO_PATH", None),
|
||||
# The version of vllm is installed. This value is used for developers who
|
||||
# installed vllm from source locally. In this case, the version of vllm is
|
||||
# usually changed. For example, if the version of vllm is "0.9.0", but when
|
||||
# it's installed from source, the version of vllm is usually set to "0.9.1".
|
||||
# In this case, developers need to set this value to "0.9.0" to make sure
|
||||
# that the correct package is installed.
|
||||
"VLLM_VERSION":
|
||||
lambda: os.getenv("VLLM_VERSION", None),
|
||||
# Whether to enable the trace recompiles from pytorch.
|
||||
"VLLM_ASCEND_TRACE_RECOMPILES":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
|
||||
# Whether to enable fused_experts_allgather_ep. MoeInitRoutingV3 and
|
||||
# GroupedMatmulFinalizeRouting operators are combined to implement EP.
|
||||
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
|
||||
lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0'))
|
||||
),
|
||||
# Whether to enable DBO feature for deepseek model.
|
||||
"VLLM_ASCEND_ENABLE_DBO":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
|
||||
# Whether to enable the model execute time observe profile. Disable it when
|
||||
# running vllm ascend in production environment.
|
||||
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
|
||||
),
|
||||
# Some models are optimized by vllm ascend. While in some case, e.g. rlhf
|
||||
# training, the optimized model may not be suitable. In this case, set this
|
||||
# value to False to disable the optimized model.
|
||||
"USE_OPTIMIZED_MODEL":
|
||||
lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
|
||||
# The tolerance of the kv cache size, if the difference between the
|
||||
# actual kv cache size and the cached kv cache size is less than this value,
|
||||
# then the cached kv cache size will be used.
|
||||
"VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE":
|
||||
lambda: int(
|
||||
os.getenv("VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)),
|
||||
# Whether to enable the topk optimization. It's enabled by default. Please set to False if you hit any issue.
|
||||
# We'll remove this flag in the future once it's stable enough.
|
||||
"VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION":
|
||||
lambda: bool(
|
||||
int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '1'))),
|
||||
# `LLMDataDistCMgrConnector` required variable. `DISAGGREGATED_PREFILL_RANK_TABLE_PATH` is
|
||||
# used for llmdatadist to build the communication topology for kv cache transfer, it is
|
||||
# a required variable if `LLMDataDistCMgrConnector` is used as kv connector for disaggregated
|
||||
# pd. The rank table can be generated by adopting the script `gen_ranktable.sh`
|
||||
# in vllm_ascend's example folder.
|
||||
"DISAGGREGATED_PREFILL_RANK_TABLE_PATH":
|
||||
lambda: os.getenv("DISAGGREGATED_PREFILL_RANK_TABLE_PATH", None),
|
||||
# `LLMDataDistCMgrConnector` required variable. `VLLM_ASCEND_LLMDD_RPC_IP` is used as the
|
||||
# rpc communication listening ip, which will be used to receive the agent metadata from the
|
||||
# remote worker.
|
||||
"VLLM_ASCEND_LLMDD_RPC_IP":
|
||||
lambda: os.getenv("VLLM_ASCEND_LLMDD_RPC_IP", "0.0.0.0"),
|
||||
# `LLMDataDistCMgrConnector` required variable. `VLLM_ASCEND_LLMDD_RPC_PORT` is used as the
|
||||
# rpc communication listening port, which will be used to receive the agent metadata from the
|
||||
# remote worker.
|
||||
"VLLM_ASCEND_LLMDD_RPC_PORT":
|
||||
lambda: int(os.getenv("VLLM_ASCEND_LLMDD_RPC_PORT", 5557)),
|
||||
# Whether to enable mla_pa for deepseek mla decode, this flag will be removed after its available torch_npu is public accessible
|
||||
# and the mla_pa will be the default path of deepseek decode path.
|
||||
"VLLM_ASCEND_MLA_PA":
|
||||
lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0)),
|
||||
# Whether to enable MatmulAllReduce fusion kernel when tensor parallel is enabled.
|
||||
# this feature is supported in A2, and eager mode will get better performance.
|
||||
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
|
||||
# Whether to enable mlp optimize when tensor parallel is enabled.
|
||||
# this feature in eager mode will get better performance.
|
||||
"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLP_OPTIMIZE", '0'))),
|
||||
# Determine the number of physical devices in a non-full-use scenario
|
||||
# caused by the initialization of the Mooncake connector.
|
||||
"PHYSICAL_DEVICES":
|
||||
lambda: os.getenv("PHYSICAL_DEVICES", None),
|
||||
# Timeout (in seconds) for delayed KVCache block release. In the prefill
|
||||
# node, if a request is marked for delayed KV block release and the blocks
|
||||
# are not freed within this timeout, they will be forcibly released.
|
||||
"VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT":
|
||||
lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)),
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
# lazy evaluation of environment variables
|
||||
if name in env_variables:
|
||||
return env_variables[name]()
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
def __dir__():
|
||||
return list(env_variables.keys())
|
||||
0
vllm_ascend/lora/__init__.py
Normal file
0
vllm_ascend/lora/__init__.py
Normal file
0
vllm_ascend/lora/punica_wrapper/__init__.py
Normal file
0
vllm_ascend/lora/punica_wrapper/__init__.py
Normal file
112
vllm_ascend/lora/punica_wrapper/lora_ops.py
Normal file
112
vllm_ascend/lora/punica_wrapper/lora_ops.py
Normal file
@@ -0,0 +1,112 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def bgmv_shrink(inputs: torch.Tensor,
|
||||
lora_a_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0):
|
||||
return torch.ops._C.bgmv_shrink(
|
||||
inputs,
|
||||
lora_a_weights,
|
||||
lora_indices_tensor,
|
||||
output_tensor,
|
||||
scaling,
|
||||
)
|
||||
|
||||
|
||||
def bgmv_expand(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True):
|
||||
return torch.ops._C.bgmv_expand(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
lora_indices_tensor,
|
||||
output_tensor,
|
||||
0,
|
||||
output_tensor.size(1),
|
||||
)
|
||||
|
||||
|
||||
def bgmv_expand_slice(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True):
|
||||
return torch.ops._C.bgmv_expand(inputs, lora_b_weights,
|
||||
lora_indices_tensor, output_tensor,
|
||||
slice_offset, slice_size)
|
||||
|
||||
|
||||
def sgmv_shrink(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
):
|
||||
return torch.ops._C.sgmv_shrink(inputs, lora_a_weights,
|
||||
lora_indices_tensor, seq_len_tensor,
|
||||
output_tensor, scaling)
|
||||
|
||||
|
||||
def sgmv_expand(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
add_inputs: bool = False):
|
||||
return torch.ops._C.sgmv_expand(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
output_tensor,
|
||||
0,
|
||||
output_tensor.size(1),
|
||||
)
|
||||
|
||||
|
||||
def sgmv_expand_slice(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False):
|
||||
return torch.ops._C.sgmv_expand(inputs, lora_b_weights,
|
||||
lora_indices_tensor, seq_len_tensor,
|
||||
output_tensor, slice_offset, slice_size)
|
||||
364
vllm_ascend/lora/punica_wrapper/punica_npu.py
Normal file
364
vllm_ascend/lora/punica_wrapper/punica_npu.py
Normal file
@@ -0,0 +1,364 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
|
||||
if is_310p():
|
||||
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
|
||||
bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice, sgmv_shrink)
|
||||
else:
|
||||
from vllm_ascend.lora.punica_wrapper.lora_ops import (
|
||||
bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice, sgmv_shrink)
|
||||
|
||||
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
|
||||
|
||||
|
||||
# The platforms that are compatible with the PyTorch-native implementation can
|
||||
# inherit this class
|
||||
class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
"""
|
||||
PunicaWrapperNPU is designed to manage and provide metadata for the punica
|
||||
kernel. The main function is to maintain the state information for
|
||||
Multi-LoRA, and to provide the interface for the pytorch punica ops.
|
||||
"""
|
||||
|
||||
def __init__(self, max_num_batched_tokens: int, max_batches: int,
|
||||
device: Union[torch.device, str], **kwargs):
|
||||
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
|
||||
device)
|
||||
|
||||
def _shrink_prefill(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
#No LoRA request, so return directly
|
||||
if self.no_lora:
|
||||
return
|
||||
sgmv_shrink(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.prefill_metadata,
|
||||
scale,
|
||||
)
|
||||
|
||||
def _shrink_decode(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
|
||||
|
||||
def _expand_prefill(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
add_inputs: bool,
|
||||
):
|
||||
#No LoRA request, so return directly
|
||||
if self.no_lora:
|
||||
return
|
||||
sgmv_expand(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.prefill_metadata,
|
||||
add_inputs,
|
||||
)
|
||||
|
||||
def _expand_decode(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
add_inputs: bool,
|
||||
):
|
||||
bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
|
||||
|
||||
def _expand_slice_prefill(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
y_offset: int,
|
||||
y_slice_size: int,
|
||||
add_inputs: bool,
|
||||
):
|
||||
#No LoRA request, so return directly
|
||||
if self.no_lora:
|
||||
return
|
||||
sgmv_expand_slice(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.prefill_metadata,
|
||||
y_offset,
|
||||
y_slice_size,
|
||||
add_inputs,
|
||||
)
|
||||
|
||||
def _expand_slice_decode(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
y_offset: int,
|
||||
y_slice_size: int,
|
||||
add_inputs: bool,
|
||||
):
|
||||
bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
|
||||
y_slice_size, add_inputs)
|
||||
|
||||
def _apply_expand(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
y_offset: int,
|
||||
y_slice_size: int,
|
||||
add_inputs: bool = True,
|
||||
):
|
||||
"""
|
||||
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
|
||||
computation, which is suitable for the
|
||||
GEMM of lora'b.
|
||||
"""
|
||||
|
||||
expand_slice_fun: Callable = (self._expand_slice_prefill
|
||||
if self.is_prefill else
|
||||
self._expand_slice_decode)
|
||||
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
|
||||
|
||||
def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
|
||||
w_t_all: torch.Tensor, scale: float):
|
||||
"""
|
||||
Perform the ` y+=x@w_t_all` computation, which is suitable for the
|
||||
GEMM of lora'a.
|
||||
When `is_prefill is` true, it indicates that it is currently the
|
||||
prefill stage, and the `_shrink_prefill` function should be called.
|
||||
Otherwise, it is the decode stage, and the _shrink_decode function
|
||||
should be called.
|
||||
"""
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
shrink_fun: Callable = (self._shrink_prefill
|
||||
if self.is_prefill else self._shrink_decode)
|
||||
shrink_fun(y, x, w_t_all, scale)
|
||||
y = y.view_as(y_org)
|
||||
|
||||
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
||||
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
|
||||
scale: float, **kwargs):
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_a.
|
||||
When `is_prefill is` true, it indicates that it is currently the
|
||||
prefill stage, and the `_shrink_prefill` function should be called.
|
||||
Otherwise, it is the decode stage, and the _shrink_decode function
|
||||
should be called.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (x @ lora_a_stacked[i]) * scale
|
||||
|
||||
Args:
|
||||
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
|
||||
x (torch.Tensor): Input tensor
|
||||
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
|
||||
scale (float): Scaling factor for the operation
|
||||
"""
|
||||
|
||||
x = x.view(-1, x.shape[-1])
|
||||
# TODO fuse these kernels
|
||||
for slice_idx in range(len(lora_a_stacked)):
|
||||
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
|
||||
scale)
|
||||
|
||||
def add_expand(self,
|
||||
y: torch.Tensor,
|
||||
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
||||
lora_b_stacked: Tuple[torch.Tensor, ...],
|
||||
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
|
||||
output_slices: Tuple[int, ...],
|
||||
offset_start: int = 0,
|
||||
add_inputs=True,
|
||||
**kwargs) -> None:
|
||||
"""
|
||||
Performs GEMM and bias addition for multiple slices of lora_b.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_b_stacked)):
|
||||
slice = output_slices[i]
|
||||
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
|
||||
lora_bias_stacked[i]
|
||||
offset += slice
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
|
||||
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
|
||||
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
|
||||
bias's weight
|
||||
output_slices (Tuple[int, ...]): Every slice's size
|
||||
add_inputs (bool): Defaults to True.
|
||||
"""
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
offset_left = offset_start
|
||||
if lora_bias_stacked is not None:
|
||||
self._apply_bias(self.token_lora_indices, y, output_slices,
|
||||
lora_bias_stacked)
|
||||
for slice_idx in range(len(lora_b_stacked)):
|
||||
self._apply_expand(
|
||||
y,
|
||||
x[slice_idx],
|
||||
lora_b_stacked[slice_idx],
|
||||
offset_left,
|
||||
output_slices[slice_idx],
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
offset_left += output_slices[slice_idx]
|
||||
y = y.view_as(y_org)
|
||||
|
||||
def add_lora_embedding(self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
**kwargs) -> None:
|
||||
"""
|
||||
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
|
||||
|
||||
Semantics:
|
||||
y += x @ lora_b_stacked
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensor.
|
||||
lora_b_stacked (torch.Tensor): lora_b's weights.
|
||||
add_inputs (bool): Default to True.
|
||||
"""
|
||||
|
||||
# Embedding layer only need expand op
|
||||
expand_fun: Callable = (self._expand_prefill
|
||||
if self.is_prefill else self._expand_decode)
|
||||
expand_fun(y, x, lora_b_stacked, add_inputs)
|
||||
|
||||
def add_lora_linear(self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: Tuple[torch.Tensor, ...],
|
||||
lora_b_stacked: Tuple[torch.Tensor, ...],
|
||||
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
|
||||
scale: float,
|
||||
output_slices: Tuple[int, ...],
|
||||
*,
|
||||
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
|
||||
**kwargs) -> None:
|
||||
"""
|
||||
Applicable to linear-related lora.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (
|
||||
x[i].unsqueeze(0)
|
||||
@ lora_a_stacked[indices[i], layer_idx, :, :]
|
||||
@ lora_b_stacked[indices[i], layer_idx, :, :]
|
||||
* scale
|
||||
).squeeze(0)+lora_bias_stacked[i]
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor. Will be changed in-place.
|
||||
x (torch.Tensor): Input tensor
|
||||
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
|
||||
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
|
||||
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
|
||||
scale (float): Scaling factor.
|
||||
output_slices (Tuple[int, ...]): Every slice's size.
|
||||
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
|
||||
"""
|
||||
|
||||
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
|
||||
if lora_bias_stacked is not None:
|
||||
assert len(lora_bias_stacked) == len(output_slices)
|
||||
y = self._apply_bias(self.token_lora_indices, y, output_slices,
|
||||
lora_bias_stacked)
|
||||
|
||||
if buffer is None:
|
||||
r = lora_b_stacked[0].size(-1)
|
||||
# We set the buffer to be float32 by default, consistent with the
|
||||
# triton op
|
||||
buffer = tuple(
|
||||
torch.zeros(
|
||||
(x.size(0), r), dtype=torch.float32, device=x.device)
|
||||
for _ in range(len(output_slices)))
|
||||
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
|
||||
self.add_expand(y,
|
||||
buffer,
|
||||
lora_b_stacked,
|
||||
None,
|
||||
output_slices,
|
||||
add_inputs=True,
|
||||
**kwargs)
|
||||
|
||||
def add_lora_logits(self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
scale,
|
||||
*,
|
||||
buffer: Optional[torch.Tensor] = None,
|
||||
**kwargs) -> None:
|
||||
"""
|
||||
Applies lora specifically for LogitsProcessorWithLoRA.
|
||||
|
||||
Semantics:
|
||||
buffer = (x @ lora_a_stacked) * scale
|
||||
y += buffer @ lora_b_stacked
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensor.
|
||||
lora_a_stacked (torch.Tensor): lora_a's weights.
|
||||
lora_b_stacked (torch.Tensor):lora_b's weights.
|
||||
scale (float): Scaling factor.
|
||||
buffer (Optional[torch.Tensor]):Default to None.
|
||||
"""
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
if lora_a_stacked.dim() == 2:
|
||||
lora_a_stacked = lora_a_stacked.unsqueeze(0)
|
||||
if lora_b_stacked.dim() == 2:
|
||||
lora_b_stacked = lora_b_stacked.unsqueeze(0)
|
||||
|
||||
r = lora_a_stacked.size(-1)
|
||||
|
||||
if buffer is None:
|
||||
buffer = torch.zeros((x.size(0), r),
|
||||
dtype=torch.float32,
|
||||
device=x.device)
|
||||
|
||||
indices = self.sampler_indices
|
||||
if indices.max() >= lora_a_stacked.size(0):
|
||||
indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1)
|
||||
|
||||
lora_a_reshaped = lora_a_stacked.transpose(1, 2)
|
||||
lora_b_reshaped = lora_b_stacked.transpose(1, 2)
|
||||
|
||||
bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale)
|
||||
bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True)
|
||||
|
||||
y = y.view_as(y_org)
|
||||
104
vllm_ascend/meta_registration.py
Normal file
104
vllm_ascend/meta_registration.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import torch
|
||||
from torch.library import Library
|
||||
|
||||
# This file provides a template and registration utilities for writing "meta" implementations
|
||||
# of custom operators in Python for the vllm_ascend project.
|
||||
#
|
||||
# We offer two ways to implement meta implementations for custom ops:
|
||||
# 1. Python meta implementation (as shown in this file): Write a Python function that
|
||||
# takes the same arguments as your operator and returns empty tensors with the correct
|
||||
# shapes and dtypes. This is useful for rapid prototyping and for ops that are only
|
||||
# used in Python.
|
||||
# 2. C++ meta implementation: You can also implement the meta function in C++ for better
|
||||
# performance or to match the C++ op logic more closely. See `torch_binding_meta.cpp`
|
||||
# for examples of C++ meta implementations and how to register them.
|
||||
#
|
||||
# Both approaches enable tracing, export, and shape inference in PyTorch and vLLM, which
|
||||
# is essential for supporting `torch.compile` and aclgraph.
|
||||
|
||||
# How to add a new meta implementation in Python:
|
||||
# -------------------------------------
|
||||
# 1. Write a Python function that takes the same arguments as your operator, and returns
|
||||
# empty tensors (using torch.empty_like, torch.empty, etc.) with the correct shapes and dtypes.
|
||||
# Do NOT perform any real computation or allocate device memory.
|
||||
#
|
||||
# 2. Register your meta function using `register_meta_if_necessary`, providing:
|
||||
# - The namespace (usually "_C" for custom ops)
|
||||
# - The operator name (as registered in C++)
|
||||
# - The Python meta function
|
||||
# - (Optional) The overload name, if your op has overloads
|
||||
#
|
||||
# 3. The registration utility will check if a meta implementation already exists for your op,
|
||||
# and only register if necessary. This avoids duplicate registrations.
|
||||
#
|
||||
# 4. Example meta implementations are provided below for rotary_embedding and get_masked_input_and_mask.
|
||||
#
|
||||
# 5. When developing new custom ops, always provide a meta implementation to enable tracing,
|
||||
# export, and shape inference in PyTorch and vLLM to enable the capture of `torch.compile`
|
||||
# and aclgraph.
|
||||
#
|
||||
# For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors
|
||||
|
||||
lib = Library("_C", "IMPL")
|
||||
|
||||
|
||||
def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
|
||||
if overload != "":
|
||||
op_name = op_name + "." + overload
|
||||
schema_to_find = ns + "::" + op_name
|
||||
meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key(
|
||||
"Meta")
|
||||
if schema_to_find in meta_impl_list:
|
||||
return
|
||||
lib.impl(op_name, fn, "Meta")
|
||||
|
||||
|
||||
def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor,
|
||||
key: torch.Tensor, head_size: int,
|
||||
cos_sin_cache: torch.Tensor, is_neox: bool):
|
||||
|
||||
num_tokens = positions.numel()
|
||||
query_hidden_size = query.numel() // num_tokens
|
||||
key_hidden_size = key.numel() // num_tokens
|
||||
num_heads = query_hidden_size // head_size
|
||||
num_kv_heads = key_hidden_size // head_size
|
||||
|
||||
query_dst = torch.empty_like(query).view(num_tokens, num_heads, head_size)
|
||||
key_dst = torch.empty_like(key).view(num_tokens, num_kv_heads, head_size)
|
||||
return query_dst, key_dst
|
||||
|
||||
|
||||
def get_masked_input_and_mask_meta(input: torch.Tensor,
|
||||
org_vocab_start_index: int,
|
||||
org_vocab_end_index: int,
|
||||
num_org_vocab_padding: int,
|
||||
added_vocab_start_index: int,
|
||||
added_vocab_end_index: int):
|
||||
|
||||
masked_input = torch.empty_like(input)
|
||||
mask = torch.empty_like(input).to(torch.bool)
|
||||
|
||||
return masked_input, mask
|
||||
|
||||
|
||||
def bgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor,
|
||||
indices: torch.Tensor, y: torch.Tensor, slice_offset: int,
|
||||
slice_size: int):
|
||||
|
||||
y_out = torch.empty_like(y)
|
||||
return y_out
|
||||
|
||||
|
||||
def sgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor,
|
||||
lora_indices: torch.Tensor, seq_len: torch.Tensor,
|
||||
y: torch.Tensor, slice_offset: int, slice_size: int):
|
||||
|
||||
y_out = torch.empty_like(y)
|
||||
return y_out
|
||||
|
||||
|
||||
register_meta_if_necessary("_C", "rotary_embedding", rotary_embedding_meta)
|
||||
register_meta_if_necessary("_C", "get_masked_input_and_mask",
|
||||
get_masked_input_and_mask_meta)
|
||||
register_meta_if_necessary("_C", "bgmv_expand", bgmv_expand_meta)
|
||||
register_meta_if_necessary("_C", "sgmv_expand", sgmv_expand_meta)
|
||||
61
vllm_ascend/models/__init__.py
Normal file
61
vllm_ascend/models/__init__.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from vllm import ModelRegistry
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
|
||||
|
||||
def register_model():
|
||||
from .deepseek_dbo import CustomDeepseekDBOForCausalLM # noqa: F401
|
||||
from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401
|
||||
from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401
|
||||
from .deepseek_v3 import CustomDeepseekV3ForCausalLM # noqa: F401
|
||||
from .qwen2_5_vl import \
|
||||
AscendQwen2_5_VLForConditionalGeneration # noqa: F401
|
||||
from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401
|
||||
from .qwen3 import CustomQwen3ForCausalLM # noqa: F401
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepSeekMTPModel",
|
||||
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen2VLForConditionalGeneration",
|
||||
"vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")
|
||||
|
||||
if envs_ascend.USE_OPTIMIZED_MODEL:
|
||||
ModelRegistry.register_model(
|
||||
"Qwen2_5_VLForConditionalGeneration",
|
||||
"vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration"
|
||||
)
|
||||
else:
|
||||
ModelRegistry.register_model(
|
||||
"Qwen2_5_VLForConditionalGeneration",
|
||||
"vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding"
|
||||
)
|
||||
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_DBO:
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV2ForCausalLM",
|
||||
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV3ForCausalLM",
|
||||
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
|
||||
else:
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV2ForCausalLM",
|
||||
"vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV3ForCausalLM",
|
||||
"vllm_ascend.models.deepseek_v3:CustomDeepseekV3ForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3MoeForCausalLM",
|
||||
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"PanguProMoEForCausalLM",
|
||||
"vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")
|
||||
1046
vllm_ascend/models/deepseek_dbo.py
Normal file
1046
vllm_ascend/models/deepseek_dbo.py
Normal file
File diff suppressed because it is too large
Load Diff
218
vllm_ascend/models/deepseek_mtp.py
Normal file
218
vllm_ascend/models/deepseek_mtp.py
Normal file
@@ -0,0 +1,218 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/deepseek_mtp.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.deepseek_mtp import (
|
||||
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
|
||||
SharedHead)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .deepseek_v2 import CustomDeepseekV2DecoderLayer
|
||||
|
||||
|
||||
class CustomDeepSeekShareHead(SharedHead):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "head"))
|
||||
|
||||
|
||||
class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
model_config: ModelConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
|
||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.shared_head = CustomDeepSeekShareHead(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "shared_head"))
|
||||
self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix,
|
||||
model_config,
|
||||
cache_config,
|
||||
quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_index: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
# masking inputs at position 0, as not needed by MTP
|
||||
inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
|
||||
torch.zeros_like(inputs_embeds),
|
||||
inputs_embeds)
|
||||
inputs_embeds = self.enorm(inputs_embeds)
|
||||
previous_hidden_states = self.hnorm(previous_hidden_states)
|
||||
|
||||
hidden_states = self.eh_proj(
|
||||
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
||||
|
||||
hidden_states, residual = self.mtp_block(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=None)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.mtp_start_layer_idx = config.num_hidden_layers
|
||||
self.num_mtp_layers = config.num_nextn_predict_layers
|
||||
# to map the exact layer index from weights
|
||||
self.layers = torch.nn.ModuleDict({
|
||||
str(idx):
|
||||
CustomDeepSeekMultiTokenPredictorLayer(
|
||||
config,
|
||||
f"{prefix}.layers.{idx}",
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=vllm_config.quant_config,
|
||||
)
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
})
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
|
||||
# Note: torch._dynamo.exc.Unsupported: builtin: str
|
||||
self.layers_list = [
|
||||
self.layers[str(idx)]
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
]
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
step_kv_cache = kv_caches[
|
||||
current_step_idx] if kv_caches is not None else None
|
||||
return self.layers_list[current_step_idx](
|
||||
input_ids,
|
||||
positions,
|
||||
step_kv_cache,
|
||||
attn_metadata,
|
||||
previous_hidden_states,
|
||||
inputs_embeds,
|
||||
current_step_idx,
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
mtp_layer = self.layers_list[current_step_idx]
|
||||
logits = self.logits_processor(mtp_layer.shared_head.head,
|
||||
mtp_layer.shared_head(hidden_states),
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
||||
class CustomDeepSeekMTP(DeepSeekMTP):
|
||||
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
|
||||
# NOTE 2.The description file generated by the current msmodelslim tool does not have
|
||||
# MTP layer info. Please manually add it and set the value to FLOAT.
|
||||
packed_modules_mapping = {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.model = CustomDeepSeekMultiTokenPredictor(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "model"))
|
||||
|
||||
self.sampler = get_sampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
previous_hidden_states: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, previous_hidden_states,
|
||||
inputs_embeds, spec_step_idx)
|
||||
return hidden_states
|
||||
997
vllm_ascend/models/deepseek_v2.py
Normal file
997
vllm_ascend/models/deepseek_v2.py
Normal file
@@ -0,0 +1,997 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# # Adapted from
|
||||
# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py
|
||||
# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
|
||||
# """Inference-only DeepseekV2/DeepseekV3 model."""
|
||||
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
||||
get_current_vllm_config)
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group, split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_reduce_scatter)
|
||||
from vllm.distributed.parallel_state import get_dp_group, get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.deepseek_v2 import \
|
||||
DeepseekV2ForCausalLM # noqa: E501
|
||||
from vllm.model_executor.models.deepseek_v2 import \
|
||||
yarn_get_mscale # noqa: E501
|
||||
from vllm.model_executor.models.deepseek_v2 import (
|
||||
DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2MLAAttention,
|
||||
get_spec_layer_idx_from_weight_name)
|
||||
from vllm.model_executor.models.utils import (
|
||||
PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
||||
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
||||
from vllm_ascend.utils import dispose_tensor
|
||||
|
||||
|
||||
class CustomDeepseekV2SiluAndMul(SiluAndMul):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
weight_scale: Optional[Callable[[], torch.Tensor]] = None):
|
||||
super().__init__()
|
||||
self.weight_scale = weight_scale
|
||||
|
||||
def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
|
||||
torch.Tensor]]):
|
||||
if isinstance(x, tuple):
|
||||
assert self.weight_scale is not None
|
||||
# For AscendW8A8DynamicLinearMethod:
|
||||
# a dynamic scale is passed along with the quantized value.
|
||||
quantized_x, dynamic_scale = x
|
||||
return torch_npu.npu_dequant_swiglu_quant(
|
||||
x=quantized_x,
|
||||
weight_scale=self.weight_scale(),
|
||||
activation_scale=dynamic_scale,
|
||||
activate_left=True,
|
||||
quant_mode=1)
|
||||
else:
|
||||
return super().forward_oot(x)
|
||||
|
||||
|
||||
class CustomDeepseekV2MergedReplicatedLinear(ReplicatedLinear):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_sizes: list[int],
|
||||
bias: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
self.output_sizes = output_sizes
|
||||
super().__init__(input_size,
|
||||
sum(output_sizes),
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
def weight_loader(self, param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor, loaded_shard_id: int):
|
||||
# With no support for GGUF format yet.
|
||||
assert not getattr(param, "is_gguf_weight", False)
|
||||
assert not getattr(param, "is_gguf_weight_type", False)
|
||||
|
||||
assert loaded_shard_id < len(self.output_sizes)
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id])
|
||||
shard_size = self.output_sizes[loaded_shard_id]
|
||||
shard = param.data.narrow(param.output_dim, shard_offset, shard_size)
|
||||
|
||||
assert shard.size() == loaded_weight.size(), (
|
||||
f"Tried to load weights of size {loaded_weight.size()}"
|
||||
f"to a parameter shard of id {loaded_shard_id} size {shard.size()}"
|
||||
)
|
||||
shard.copy_(loaded_weight)
|
||||
|
||||
|
||||
class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
is_prefill=True,
|
||||
is_force_scatter=False
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
output_parallel = self.quant_method.apply(self,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
num_tokens = output_parallel.shape[0]
|
||||
if is_force_scatter and num_tokens % self.tp_size:
|
||||
output_parallel = nn.functional.pad(
|
||||
output_parallel, (0, 0, 0, -num_tokens % self.tp_size))
|
||||
if is_force_scatter or (not is_prefill
|
||||
and output_parallel.shape[0] % self.tp_size
|
||||
== 0):
|
||||
output = tensor_model_parallel_reduce_scatter(output_parallel,
|
||||
dim=0)
|
||||
else:
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
is_prefill=True,
|
||||
is_force_scatter=False
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
output_parallel = self.quant_method.apply(self,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class CustomDeepseekV2MLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
force_replicate: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if not force_replicate:
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
else:
|
||||
self.gate_up_proj = CustomDeepseekV2MergedReplicatedLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
self.down_proj = ReplicatedLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
|
||||
quant_method = self.gate_up_proj.quant_method
|
||||
if isinstance(quant_method, UnquantizedLinearMethod):
|
||||
self.act_fn = CustomDeepseekV2SiluAndMul()
|
||||
elif (isinstance(quant_method, AscendLinearMethod) and isinstance(
|
||||
quant_method.quant_method, AscendW8A8DynamicLinearMethod)):
|
||||
# TODO(sdmyzlp): Currently preserved as before:
|
||||
# 1. The only quantization supported for silu is W8A8Dynamic
|
||||
# 2. Output dtype of gate_up/down is fixed to be int32/bfloat16
|
||||
#
|
||||
# Maybe one can implement a better and more general configuration
|
||||
# scheme, e.g. by somehow passing around the tweaked `quant_config`
|
||||
self.act_fn = CustomDeepseekV2SiluAndMul(
|
||||
# Use lazy binding, for `weight_scale_fp32` is accessible
|
||||
# only after `process_weights_after_loading`.
|
||||
weight_scale=lambda: self.gate_up_proj.weight_scale_fp32)
|
||||
# To be consumed by AscendW8A8DynamicLinearMethod.apply()
|
||||
self.gate_up_proj._ascend_quant_config = {
|
||||
"output_dtype": torch.int32,
|
||||
"pertoken_scale": False,
|
||||
"return_scale": True,
|
||||
}
|
||||
self.down_proj._ascend_quant_config = {
|
||||
"output_dtype": torch.bfloat16,
|
||||
"pertoken_scale": True,
|
||||
"return_scale": False,
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Quantization with [{type(quant_method)}] is NOT supported")
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class CustomDeepseekV2MoE(nn.Module):
|
||||
|
||||
top_k: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.n_shared_experts = config.n_shared_experts
|
||||
if self.tp_size > config.n_routed_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {config.n_routed_experts}.")
|
||||
|
||||
if config.hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
self.enable_multistream_moe = \
|
||||
ascend_config.torchair_graph_config.enable_multistream_moe and \
|
||||
self.torchair_graph_enabled
|
||||
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
config.n_routed_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate")
|
||||
if config.topk_method == "noaux_tc":
|
||||
self.gate.e_score_correction_bias = nn.Parameter(
|
||||
torch.empty(config.n_routed_experts))
|
||||
else:
|
||||
self.gate.e_score_correction_bias = None
|
||||
|
||||
self.experts = AscendFusedMoE(
|
||||
num_experts=config.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
prefix=f"{prefix}.experts",
|
||||
scoring_func=config.scoring_func,
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias)
|
||||
|
||||
if config.n_shared_experts is not None:
|
||||
self.all_reduce_merge = self.experts.all_reduce_merge
|
||||
reduce_results = not self.all_reduce_merge
|
||||
intermediate_size = (config.moe_intermediate_size *
|
||||
config.n_shared_experts)
|
||||
enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
self.shared_experts = CustomDeepseekV2MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
force_replicate=self.enable_multistream_moe
|
||||
or enable_shared_expert_dp,
|
||||
prefix=f"{prefix}.shared_experts",
|
||||
)
|
||||
else:
|
||||
self.shared_experts = None # type: ignore
|
||||
CustomDeepseekV2MoE.top_k = config.num_experts_per_tok
|
||||
|
||||
self.dp_size = get_dp_group().world_size
|
||||
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
self.ep_group = get_ep_group()
|
||||
self.kv_consumer = None
|
||||
transfer_config = get_current_vllm_config().kv_transfer_config
|
||||
if transfer_config is not None:
|
||||
self.kv_consumer = transfer_config.kv_role == "kv_consumer"
|
||||
|
||||
self.params_dtype = torch.get_default_dtype()
|
||||
self.rm_router_logits = self.experts.rm_router_logits
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
replace_allreduce: bool = False) -> torch.Tensor:
|
||||
|
||||
forward_context = get_forward_context()
|
||||
# when profile runs, force experts to load balanced tokens
|
||||
# to avoid high memory consumption on a single rank.
|
||||
|
||||
enable_force_load_balance = forward_context.in_profile_run
|
||||
|
||||
is_prefill = forward_context.with_prefill
|
||||
|
||||
# If this node is kv_consumer, we force the moe always runs in decode path to make sure
|
||||
# the behaviour aligned between dummy_run and normal model_execute.
|
||||
if self.kv_consumer:
|
||||
is_prefill = False
|
||||
enable_force_load_balance = False
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = None
|
||||
if not self.rm_router_logits and not self.enable_multistream_moe:
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
experts_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
is_prefill=is_prefill,
|
||||
top_k=CustomDeepseekV2MoE.top_k,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
shared_experts=self.shared_experts,
|
||||
gate=self.gate,
|
||||
replace_allreduce=replace_allreduce)
|
||||
|
||||
hidden_states = (
|
||||
experts_hidden_states[0] * self.routed_scaling_factor +
|
||||
experts_hidden_states[1])
|
||||
if self.all_reduce_merge:
|
||||
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
q_lora_rank: Optional[int],
|
||||
kv_lora_rank: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = hidden_size
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
assert num_heads % self.tp_size == 0
|
||||
self.num_local_heads = num_heads // self.tp_size
|
||||
self.layers = config.num_hidden_layers
|
||||
self.first_k_dense_replace = config.first_k_dense_replace
|
||||
|
||||
self.scaling = self.qk_head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.prefix = prefix
|
||||
self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
||||
self.q_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_a_proj")
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
|
||||
eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
|
||||
self.num_heads *
|
||||
self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_b_proj")
|
||||
else:
|
||||
self.q_proj = ColumnParallelLinear(self.hidden_size,
|
||||
self.num_heads *
|
||||
self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_proj")
|
||||
|
||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.kv_a_proj_with_mqa")
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
|
||||
eps=config.rms_norm_eps)
|
||||
self.kv_b_proj = ColumnParallelLinear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.kv_b_proj")
|
||||
if (config.n_routed_experts is not None
|
||||
and self.debug_layer_idx >= config.first_k_dense_replace
|
||||
and self.debug_layer_idx % config.moe_layer_freq == 0
|
||||
and self.enable_shared_expert_dp):
|
||||
self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj")
|
||||
else:
|
||||
self.o_proj = CustomDeepseekV2RowParallelLinear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj")
|
||||
|
||||
if rope_scaling:
|
||||
rope_scaling["rope_type"] = 'deepseek_yarn'
|
||||
self.rotary_emb = get_rope(qk_rope_head_dim,
|
||||
rotary_dim=qk_rope_head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
is_neox_style=False)
|
||||
if rope_scaling:
|
||||
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
self.scaling = self.scaling * mscale * mscale
|
||||
|
||||
# In the MLA backend, kv_cache includes both k_c and
|
||||
# pe (i.e. decoupled position embeddings). In particular,
|
||||
# the concat_and_cache_mla op requires
|
||||
# k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
|
||||
# i.e.
|
||||
# kv_lora_rank + qk_rope_head_dim == head_size
|
||||
self.mla_attn = Attention(
|
||||
num_heads=self.num_local_heads,
|
||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
scale=self.scaling,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_mla=True,
|
||||
# MLA Args
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
rotary_emb=self.rotary_emb,
|
||||
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
|
||||
q_a_layernorm=self.q_a_layernorm
|
||||
if self.q_lora_rank is not None else None,
|
||||
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
|
||||
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=self.kv_a_layernorm,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
o_proj=self.o_proj,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
forward_context = get_forward_context()
|
||||
if kv_cache is None:
|
||||
kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
|
||||
num_tokens = hidden_states.shape[0]
|
||||
need_gather_q_kv = False
|
||||
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
|
||||
# Simulate all gather to calculate output shape
|
||||
num_tokens = num_tokens * self.tp_size
|
||||
need_gather_q_kv = True
|
||||
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
|
||||
output_shape = hidden_states.shape
|
||||
else:
|
||||
rows = num_tokens // self.tp_size
|
||||
if num_tokens % self.tp_size:
|
||||
rows += 1
|
||||
output_shape = (rows, hidden_states.shape[1])
|
||||
output = torch.empty(output_shape,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
output = self.mla_attn.impl.forward(hidden_states, kv_cache,
|
||||
forward_context.attn_metadata,
|
||||
need_gather_q_kv, output)
|
||||
output = output.view(-1, output_shape[-1])
|
||||
return output
|
||||
|
||||
|
||||
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
model_config: ModelConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
# DecoderLayers are created with `make_layers` which passes the prefix
|
||||
# with the layer's index.
|
||||
layer_idx = int(prefix.split(sep='.')[-1])
|
||||
self.layer_idx = layer_idx
|
||||
self.layers = config.num_hidden_layers
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
ascend_config = get_ascend_config()
|
||||
# TODO: enable mla in vllm-ascend
|
||||
if model_config.use_mla:
|
||||
attn_cls = CustomDeepseekV2MLAAttention
|
||||
else:
|
||||
attn_cls = DeepseekV2Attention
|
||||
self.self_attn = attn_cls(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||
v_head_dim=config.v_head_dim,
|
||||
q_lora_rank=config.q_lora_rank
|
||||
if hasattr(config, "q_lora_rank") else None,
|
||||
kv_lora_rank=config.kv_lora_rank,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
|
||||
if (config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0):
|
||||
self.mlp = CustomDeepseekV2MoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
else:
|
||||
self.mlp = CustomDeepseekV2MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.first_k_dense_replace = config.first_k_dense_replace
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
replace_allreduce: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
previous_hidden_states, previous_residual = hidden_states, residual
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
# Dispose hidden_states and residual from the previous layer
|
||||
# to save npu memory because they're no longer used.
|
||||
dispose_tensor(previous_hidden_states)
|
||||
dispose_tensor(previous_residual)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
if hidden_states.dtype == torch.float16:
|
||||
# Fix FP16 overflow
|
||||
# We scale both hidden_states and residual before
|
||||
# rmsnorm, and rmsnorm result would not affect by scale.
|
||||
hidden_states *= 1. / self.routed_scaling_factor
|
||||
if self.layer_idx == 0:
|
||||
# The residual is shared by all layers, we only scale it on
|
||||
# first layer.
|
||||
residual *= 1. / self.routed_scaling_factor
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if self.enable_shared_expert_dp and (
|
||||
self.layer_idx == self.first_k_dense_replace
|
||||
or self.layer_idx == self.layers) and tp_size > 1:
|
||||
num_tokens, _ = residual.shape
|
||||
if num_tokens % tp_size:
|
||||
residual = nn.functional.pad(residual,
|
||||
(0, 0, 0, -num_tokens % tp_size))
|
||||
chunk_residual = torch.tensor_split(residual, tp_size, dim=0)
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
residual = chunk_residual[tp_rank]
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
if isinstance(self.mlp, CustomDeepseekV2MoE):
|
||||
hidden_states = self.mlp(hidden_states, attn_metadata)
|
||||
else:
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
if isinstance(
|
||||
self.mlp,
|
||||
CustomDeepseekV2MLP) and hidden_states.dtype == torch.float16:
|
||||
# Fix FP16 overflow
|
||||
# Scaling the DeepseekV2MLP output, it is the input of
|
||||
# input_layernorm of next decoder layer.
|
||||
# The scaling of DeepseekV2MOE output would be done in the forward
|
||||
# of DeepseekV2MOE
|
||||
hidden_states *= 1. / self.routed_scaling_factor
|
||||
|
||||
# for last layer of main model and mtp layer.
|
||||
if self.enable_shared_expert_dp and self.layer_idx >= (
|
||||
self.layers - 1) and tp_size > 1:
|
||||
hidden_states = get_tp_group().all_gather(hidden_states, 0)
|
||||
residual = get_tp_group().all_gather(residual, 0)
|
||||
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata is not None:
|
||||
num_tokens = attn_metadata.num_actual_tokens
|
||||
else:
|
||||
num_tokens = hidden_states.shape[0]
|
||||
|
||||
if num_tokens < hidden_states.shape[0]:
|
||||
hidden_states = hidden_states[:num_tokens]
|
||||
residual = residual[:num_tokens]
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class CustomDeepseekV2Model(nn.Module):
|
||||
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.embed_tokens")
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: CustomDeepseekV2DecoderLayer(
|
||||
config,
|
||||
prefix,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
kv_caches[i -
|
||||
self.start_layer] if kv_caches is not None else None,
|
||||
attn_metadata,
|
||||
replace_allreduce=replace_allreduce)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
|
||||
# add `packed_modules_mapping` in `DeepseekV2ForCausalLM` to support weight merging
|
||||
packed_modules_mapping = {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = CustomDeepseekV2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "model"))
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = get_sampler()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
# NOTE: This `load_weights` is mainly copied from
|
||||
# https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5
|
||||
# to fix CI, and it is different from the implementation in main
|
||||
# TODO: support eplb style load_weights
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
""""""
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = AscendFusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if "module" in name:
|
||||
continue
|
||||
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is not None:
|
||||
continue # skip spec decode layers for main model
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if (("mlp.experts." in name) and name not in params_dict):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=False)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
27
vllm_ascend/models/deepseek_v3.py
Normal file
27
vllm_ascend/models/deepseek_v3.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2ForCausalLM
|
||||
|
||||
|
||||
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
|
||||
pass
|
||||
1106
vllm_ascend/models/pangu_moe.py
Normal file
1106
vllm_ascend/models/pangu_moe.py
Normal file
File diff suppressed because it is too large
Load Diff
491
vllm_ascend/models/qwen2_5_vl.py
Normal file
491
vllm_ascend/models/qwen2_5_vl.py
Normal file
@@ -0,0 +1,491 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/qwen2_5_vl.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from typing import Callable, Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from einops import rearrange
|
||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.model_executor.layers.activation import get_act_and_mul_fn
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.qwen2_5_vl import (
|
||||
Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
|
||||
Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionTransformer,
|
||||
Qwen2_5_VLDummyInputsBuilder, Qwen2_5_VLForConditionalGeneration,
|
||||
Qwen2_5_VLMultiModalProcessor, Qwen2_5_VLProcessingInfo)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
MIN_PAD_SIZE = 64 # min_size to pad weight
|
||||
MAX_PAD_SIZE = 128 # max_size to pad weight
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
projection_size: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embed_dim,
|
||||
num_heads,
|
||||
projection_size,
|
||||
quant_config,
|
||||
prefix,
|
||||
)
|
||||
self.embed_dim = embed_dim
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
projection_size, num_heads)
|
||||
self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
|
||||
if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE:
|
||||
self.hidden_size_per_attention_head = MAX_PAD_SIZE
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
# [s, b, 3 * head * head_dim]
|
||||
seq_len, bs, _ = qkv.shape
|
||||
|
||||
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
|
||||
q, k, v = qkv.chunk(3, dim=2)
|
||||
|
||||
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
|
||||
new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head)
|
||||
q, k, v = (x.view(*new_shape) for x in (q, k, v))
|
||||
return q, k, v
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
||||
x, _ = self.qkv(x)
|
||||
|
||||
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
|
||||
q, k, v = self.split_qkv(x)
|
||||
batch_size = q.shape[1]
|
||||
|
||||
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
|
||||
for x in (q, k, v))
|
||||
q = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
|
||||
q, k, v = [
|
||||
rearrange(x, "b s h d -> (b s) h d").contiguous()
|
||||
for x in (q, k, v)
|
||||
]
|
||||
|
||||
context_layer = torch.empty_like(q)
|
||||
|
||||
# operator requires pta version >= 2.5.1
|
||||
torch_npu._npu_flash_attention_unpad(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
seq_len=cu_seqlens,
|
||||
scale_value=self.origin_hidden_size_per_attention_head**-0.5,
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
num_kv_heads=self.num_attention_heads_per_partition,
|
||||
out=context_layer)
|
||||
|
||||
context_layer = rearrange(context_layer,
|
||||
"(b s) h d -> s b (h d)",
|
||||
b=batch_size).contiguous()
|
||||
|
||||
output, _ = self.proj(context_layer)
|
||||
return output
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionBlock(Qwen2_5_VisionBlock):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_hidden_dim: int,
|
||||
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
|
||||
norm_layer: Optional[Callable[[int], nn.Module]] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
|
||||
quant_config, prefix)
|
||||
self.attn = AscendQwen2_5_VisionAttention(embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
projection_size=dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
|
||||
cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.attn(
|
||||
self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)
|
||||
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionPatchEmbed(Qwen2_5_VisionPatchEmbed):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.matmul(
|
||||
self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionRotaryEmbedding(Qwen2_5_VisionRotaryEmbedding):
|
||||
|
||||
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
||||
super().__init__(dim, theta)
|
||||
inv_freq = 1.0 / (theta
|
||||
**(torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
||||
self.inv_freq = inv_freq
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config: Qwen2_5_VLVisionConfig,
|
||||
norm_eps: float = 1e-6,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
interleaved=False,
|
||||
) -> None:
|
||||
super().__init__(vision_config, norm_eps, quant_config, prefix)
|
||||
norm_layer = partial(RMSNorm, eps=norm_eps)
|
||||
self.interleaved = interleaved
|
||||
self.enable_pad = False
|
||||
head_dim = self.hidden_size // self.num_heads
|
||||
self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim //
|
||||
2)
|
||||
self.patch_embed = AscendQwen2_5_VisionPatchEmbed(
|
||||
patch_size=vision_config.patch_size,
|
||||
temporal_patch_size=vision_config.temporal_patch_size,
|
||||
in_channels=vision_config.in_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
)
|
||||
|
||||
act_fn = get_act_and_mul_fn(vision_config.hidden_act)
|
||||
self.blocks = nn.ModuleList([
|
||||
AscendQwen2_5_VisionBlock(
|
||||
dim=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_hidden_dim=vision_config.intermediate_size,
|
||||
act_fn=act_fn,
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}")
|
||||
for layer_idx in range(vision_config.depth)
|
||||
])
|
||||
self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
self.hidden_size, self.num_heads)
|
||||
|
||||
if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE:
|
||||
self.enable_pad = True
|
||||
self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
|
||||
self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2
|
||||
self.half_pad_hidden_size_per_attention_head = (
|
||||
MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2
|
||||
self.hidden_size_per_attention_head = MAX_PAD_SIZE
|
||||
|
||||
def cal_cos_sin(self, rotary_pos_emb):
|
||||
cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2]
|
||||
sin = rotary_pos_emb.sin()
|
||||
if self.enable_pad:
|
||||
cos = torch.nn.functional.pad(
|
||||
cos, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
sin = torch.nn.functional.pad(
|
||||
sin, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
|
||||
if not self.interleaved:
|
||||
cos_new = torch.cat((cos, cos), dim=-1)
|
||||
sin_new = torch.cat((sin, sin), dim=-1)
|
||||
else:
|
||||
cos_new = rearrange(torch.stack((cos, cos), dim=-1),
|
||||
"... d two -> ...(d two)",
|
||||
two=2)
|
||||
sin_new = rearrange(torch.stack((sin, sin), dim=-1),
|
||||
"... d two -> ...(d two)",
|
||||
two=2)
|
||||
cos_new = cos_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
sin_new = sin_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
return cos_new, sin_new
|
||||
|
||||
def pad_qkv_bias(self, bias):
|
||||
first_half = bias.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head
|
||||
)[:, :, :self.half_origin_hidden_size_per_attention_head]
|
||||
second_half = bias.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head
|
||||
)[:, :, self.half_origin_hidden_size_per_attention_head:]
|
||||
first_half_padded = torch.nn.functional.pad(
|
||||
first_half, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
second_half_padded = torch.nn.functional.pad(
|
||||
second_half, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2)
|
||||
bias_final = bias_padded.reshape(-1)
|
||||
return bias_final
|
||||
|
||||
def pad_qkv_weight(self, data):
|
||||
qkv_weight_first_half = data.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
|
||||
)[:, :, :self.half_origin_hidden_size_per_attention_head, :]
|
||||
qkv_weight_second_half = data.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
|
||||
)[:, :, self.half_origin_hidden_size_per_attention_head:, :]
|
||||
|
||||
qkv_weight_first_half_padded = torch.nn.functional.pad(
|
||||
qkv_weight_first_half,
|
||||
(0, 0, 0, self.half_pad_hidden_size_per_attention_head))
|
||||
qkv_weight_second_half_padded = torch.nn.functional.pad(
|
||||
qkv_weight_second_half,
|
||||
(0, 0, 0, self.half_pad_hidden_size_per_attention_head))
|
||||
qkv_weight_padded = torch.cat(
|
||||
[qkv_weight_first_half_padded, qkv_weight_second_half_padded],
|
||||
dim=2)
|
||||
qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
|
||||
return qkv_weight_final
|
||||
|
||||
def pad_proj_weight(self, data):
|
||||
out_weight = torch.nn.functional.pad(
|
||||
data.reshape(self.hidden_size, -1,
|
||||
self.half_origin_hidden_size_per_attention_head),
|
||||
(0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
|
||||
self.hidden_size, -1)
|
||||
return out_weight
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("mlp.gate_up_proj.", "mlp.gate_proj.", 0),
|
||||
("mlp.gate_up_proj.", "mlp.up_proj.", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
if ("attn.proj.weight" in name) and self.enable_pad:
|
||||
param.data = self.pad_proj_weight(param.data)
|
||||
if ("attn.qkv.weight" in name) and self.enable_pad:
|
||||
param.data = self.pad_qkv_weight(param.data)
|
||||
if ("attn.qkv.bias" in name) and self.enable_pad:
|
||||
param.data = self.pad_qkv_bias(param.data)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
pos_ids = []
|
||||
for t, h, w in grid_thw:
|
||||
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
||||
hpos_ids = hpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
).permute(0, 2, 1, 3).flatten()
|
||||
wpos_ids = wpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
).permute(0, 2, 1, 3).flatten()
|
||||
pos_ids.append(
|
||||
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
||||
pos_ids = torch.cat(pos_ids, dim=0)
|
||||
max_grid_size = grid_thw[:, 1:].max()
|
||||
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
return rotary_pos_emb
|
||||
|
||||
def get_window_index(self, grid_thw):
|
||||
window_index: list = []
|
||||
cu_window_seqlens: list = [0]
|
||||
window_index_id = 0
|
||||
vit_merger_window_size = (self.window_size //
|
||||
self.spatial_merge_size // self.patch_size)
|
||||
|
||||
for grid_t, grid_h, grid_w in grid_thw:
|
||||
llm_grid_h = grid_h // self.spatial_merge_size
|
||||
llm_grid_w = grid_w // self.spatial_merge_size
|
||||
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
|
||||
grid_t, llm_grid_h, llm_grid_w)
|
||||
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
|
||||
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
|
||||
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
|
||||
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
|
||||
index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
|
||||
index_padded = index_padded.reshape(grid_t, num_windows_h,
|
||||
vit_merger_window_size,
|
||||
num_windows_w,
|
||||
vit_merger_window_size)
|
||||
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
|
||||
grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
|
||||
vit_merger_window_size)
|
||||
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
|
||||
index_padded = index_padded.reshape(-1)
|
||||
index_new = index_padded[index_padded != -100]
|
||||
window_index.append(index_new + window_index_id)
|
||||
cu_seqlens_tmp = seqlens.cumsum(
|
||||
0) * self.spatial_merge_unit + cu_window_seqlens[-1]
|
||||
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
|
||||
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
|
||||
window_index = torch.cat(window_index, dim=0)
|
||||
return window_index, cu_window_seqlens
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# compute cu_seqlens
|
||||
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
|
||||
grid_thw[:,
|
||||
0]).cpu().to(torch.int32)
|
||||
|
||||
# patchify
|
||||
x = self.patch_embed(x)
|
||||
|
||||
# compute position embedding
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
|
||||
# windows attention
|
||||
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
|
||||
cu_window_seqlens = torch.tensor(
|
||||
cu_window_seqlens,
|
||||
device=x.device,
|
||||
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
|
||||
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
||||
cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32)
|
||||
seq_len, _ = x.size()
|
||||
x = x.reshape(seq_len // self.spatial_merge_unit,
|
||||
self.spatial_merge_unit, -1)
|
||||
x = x[window_index, :, :]
|
||||
x = x.reshape(seq_len, -1)
|
||||
rotary_pos_emb = rotary_pos_emb.reshape(
|
||||
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
||||
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
|
||||
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
||||
|
||||
cos, sin = self.cal_cos_sin(rotary_pos_emb)
|
||||
|
||||
# transformers
|
||||
x = x.unsqueeze(1)
|
||||
for layer_num, blk in enumerate(self.blocks):
|
||||
if layer_num in self.fullatt_block_indexes:
|
||||
cu_seqlens_now = cu_seqlens
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin)
|
||||
|
||||
# adapter
|
||||
x = self.merger(x)
|
||||
reverse_indices = torch.argsort(window_index)
|
||||
x = x[reverse_indices, :]
|
||||
return x
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Qwen2_5_VLMultiModalProcessor,
|
||||
info=Qwen2_5_VLProcessingInfo,
|
||||
dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
|
||||
class AscendQwen2_5_VLForConditionalGeneration(
|
||||
Qwen2_5_VLForConditionalGeneration):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.visual = AscendQwen2_5_VisionTransformer(
|
||||
vision_config=config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=self._maybe_ignore_quant_config(quant_config),
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
|
||||
def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
|
||||
|
||||
grid_thw = image_input["image_grid_thw"]
|
||||
assert grid_thw.ndim == 2
|
||||
|
||||
if image_input["type"] == "image_embeds":
|
||||
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
|
||||
|
||||
# Split concatenated embeddings for each image item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||
return image_embeds.split(sizes.tolist())
|
||||
|
||||
def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
|
||||
|
||||
grid_thw = video_input["video_grid_thw"]
|
||||
assert grid_thw.ndim == 2
|
||||
|
||||
if video_input["type"] == "video_embeds":
|
||||
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values_videos = video_input["pixel_values_videos"].type(
|
||||
self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
|
||||
|
||||
# Split concatenated embeddings for each video item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||
return video_embeds.split(sizes.tolist())
|
||||
373
vllm_ascend/models/qwen2_5_vl_without_padding.py
Normal file
373
vllm_ascend/models/qwen2_5_vl_without_padding.py
Normal file
@@ -0,0 +1,373 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/qwen2_5_vl.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from einops import rearrange
|
||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.model_executor.layers.activation import get_act_and_mul_fn
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.models.qwen2_5_vl import (
|
||||
Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
|
||||
Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
|
||||
Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
|
||||
Qwen2_5_VLProcessingInfo)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
projection_size: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embed_dim,
|
||||
num_heads,
|
||||
projection_size,
|
||||
quant_config,
|
||||
prefix,
|
||||
)
|
||||
self.embed_dim = embed_dim
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
projection_size, num_heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
||||
x, _ = self.qkv(x)
|
||||
|
||||
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
|
||||
q, k, v = self.split_qkv(x)
|
||||
batch_size = q.shape[1]
|
||||
|
||||
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
|
||||
for x in (q, k, v))
|
||||
q = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
|
||||
q, k, v = [
|
||||
rearrange(x, "b s h d -> (b s) h d").contiguous()
|
||||
for x in (q, k, v)
|
||||
]
|
||||
|
||||
context_layer = torch.empty_like(q)
|
||||
|
||||
# operator requires pta version >= 2.5.1.dev20250226
|
||||
torch_npu._npu_flash_attention_unpad(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
seq_len=cu_seqlens,
|
||||
scale_value=self.hidden_size_per_attention_head**-0.5,
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
num_kv_heads=self.num_attention_heads_per_partition,
|
||||
out=context_layer)
|
||||
|
||||
context_layer = rearrange(context_layer,
|
||||
"(b s) h d -> s b (h d)",
|
||||
b=batch_size).contiguous()
|
||||
|
||||
output, _ = self.proj(context_layer)
|
||||
return output
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionBlock_Without_Padding(Qwen2_5_VisionBlock):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_hidden_dim: int,
|
||||
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
|
||||
norm_layer: Optional[Callable[[int], nn.Module]] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
|
||||
quant_config, prefix)
|
||||
self.attn = AscendQwen2_5_VisionAttention_Without_Padding(
|
||||
embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
projection_size=dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
|
||||
cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.attn(
|
||||
self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)
|
||||
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionPatchEmbed_Without_Padding(Qwen2_5_VisionPatchEmbed):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.matmul(
|
||||
self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer
|
||||
):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config: Qwen2_5_VLVisionConfig,
|
||||
norm_eps: float = 1e-6,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
interleaved=False,
|
||||
) -> None:
|
||||
super().__init__(vision_config, norm_eps, quant_config, prefix)
|
||||
norm_layer = partial(RMSNorm, eps=norm_eps)
|
||||
self.interleaved = interleaved
|
||||
head_dim = self.hidden_size // self.num_heads
|
||||
self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim //
|
||||
2)
|
||||
self.patch_embed = AscendQwen2_5_VisionPatchEmbed_Without_Padding(
|
||||
patch_size=vision_config.patch_size,
|
||||
temporal_patch_size=vision_config.temporal_patch_size,
|
||||
in_channels=vision_config.in_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
)
|
||||
|
||||
act_fn = get_act_and_mul_fn(vision_config.hidden_act)
|
||||
self.blocks = nn.ModuleList([
|
||||
AscendQwen2_5_VisionBlock_Without_Padding(
|
||||
dim=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_hidden_dim=vision_config.intermediate_size,
|
||||
act_fn=act_fn,
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}")
|
||||
for layer_idx in range(vision_config.depth)
|
||||
])
|
||||
self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
self.hidden_size, self.num_heads)
|
||||
|
||||
def cal_cos_sin(self, rotary_pos_emb):
|
||||
cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2]
|
||||
sin = rotary_pos_emb.sin()
|
||||
|
||||
if not self.interleaved:
|
||||
cos_new = torch.cat((cos, cos), dim=-1)
|
||||
sin_new = torch.cat((sin, sin), dim=-1)
|
||||
else:
|
||||
cos_new = rearrange(torch.stack((cos, cos), dim=-1),
|
||||
"... d two -> ...(d two)",
|
||||
two=2)
|
||||
sin_new = rearrange(torch.stack((sin, sin), dim=-1),
|
||||
"... d two -> ...(d two)",
|
||||
two=2)
|
||||
cos_new = cos_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
sin_new = sin_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
return cos_new, sin_new
|
||||
|
||||
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
pos_ids = []
|
||||
for t, h, w in grid_thw:
|
||||
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
||||
hpos_ids = hpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
).permute(0, 2, 1, 3).flatten()
|
||||
wpos_ids = wpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
).permute(0, 2, 1, 3).flatten()
|
||||
pos_ids.append(
|
||||
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
||||
pos_ids = torch.cat(pos_ids, dim=0)
|
||||
max_grid_size = grid_thw[:, 1:].max()
|
||||
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
return rotary_pos_emb
|
||||
|
||||
def get_window_index(self, grid_thw):
|
||||
window_index: list = []
|
||||
cu_window_seqlens: list = [0]
|
||||
window_index_id = 0
|
||||
vit_merger_window_size = (self.window_size //
|
||||
self.spatial_merge_size // self.patch_size)
|
||||
|
||||
for grid_t, grid_h, grid_w in grid_thw:
|
||||
llm_grid_h = grid_h // self.spatial_merge_size
|
||||
llm_grid_w = grid_w // self.spatial_merge_size
|
||||
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
|
||||
grid_t, llm_grid_h, llm_grid_w)
|
||||
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
|
||||
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
|
||||
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
|
||||
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
|
||||
index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
|
||||
index_padded = index_padded.reshape(grid_t, num_windows_h,
|
||||
vit_merger_window_size,
|
||||
num_windows_w,
|
||||
vit_merger_window_size)
|
||||
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
|
||||
grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
|
||||
vit_merger_window_size)
|
||||
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
|
||||
index_padded = index_padded.reshape(-1)
|
||||
index_new = index_padded[index_padded != -100]
|
||||
window_index.append(index_new + window_index_id)
|
||||
cu_seqlens_tmp = seqlens.cumsum(
|
||||
0) * self.spatial_merge_unit + cu_window_seqlens[-1]
|
||||
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
|
||||
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
|
||||
window_index = torch.cat(window_index, dim=0)
|
||||
return window_index, cu_window_seqlens
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# compute cu_seqlens
|
||||
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
|
||||
grid_thw[:,
|
||||
0]).cpu().to(torch.int32)
|
||||
|
||||
# patchify
|
||||
x = self.patch_embed(x)
|
||||
|
||||
# compute position embedding
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
|
||||
# windows attention
|
||||
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
|
||||
cu_window_seqlens = torch.tensor(
|
||||
cu_window_seqlens,
|
||||
device=x.device,
|
||||
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
|
||||
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
||||
cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32)
|
||||
seq_len, _ = x.size()
|
||||
x = x.reshape(seq_len // self.spatial_merge_unit,
|
||||
self.spatial_merge_unit, -1)
|
||||
x = x[window_index, :, :]
|
||||
x = x.reshape(seq_len, -1)
|
||||
rotary_pos_emb = rotary_pos_emb.reshape(
|
||||
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
||||
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
|
||||
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
||||
|
||||
cos, sin = self.cal_cos_sin(rotary_pos_emb)
|
||||
|
||||
# transformers
|
||||
x = x.unsqueeze(1)
|
||||
for layer_num, blk in enumerate(self.blocks):
|
||||
if layer_num in self.fullatt_block_indexes:
|
||||
cu_seqlens_now = cu_seqlens
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin)
|
||||
|
||||
# adapter
|
||||
x = self.merger(x)
|
||||
reverse_indices = torch.argsort(window_index)
|
||||
x = x[reverse_indices, :]
|
||||
return x
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Qwen2_5_VLMultiModalProcessor,
|
||||
info=Qwen2_5_VLProcessingInfo,
|
||||
dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
|
||||
class AscendQwen2_5_VLForConditionalGeneration_Without_Padding(
|
||||
Qwen2_5_VLForConditionalGeneration):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.visual = AscendQwen2_5_VisionTransformer_Without_Padding(
|
||||
vision_config=config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=self._maybe_ignore_quant_config(quant_config),
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
|
||||
def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
|
||||
|
||||
grid_thw = image_input["image_grid_thw"]
|
||||
assert grid_thw.ndim == 2
|
||||
|
||||
if image_input["type"] == "image_embeds":
|
||||
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
|
||||
|
||||
# Split concatenated embeddings for each image item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||
return image_embeds.split(sizes.tolist())
|
||||
|
||||
def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
|
||||
|
||||
grid_thw = video_input["video_grid_thw"]
|
||||
assert grid_thw.ndim == 2
|
||||
|
||||
if video_input["type"] == "video_embeds":
|
||||
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values_videos = video_input["pixel_values_videos"].type(
|
||||
self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
|
||||
|
||||
# Split concatenated embeddings for each video item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||
return video_embeds.split(sizes.tolist())
|
||||
352
vllm_ascend/models/qwen2_vl.py
Normal file
352
vllm_ascend/models/qwen2_vl.py
Normal file
@@ -0,0 +1,352 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Adapted from vllm/model_executor/models/qwen2_vl.py
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from collections.abc import Iterable
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Set, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from einops import rearrange
|
||||
from transformers.models.qwen2_vl.configuration_qwen2_vl import \
|
||||
Qwen2VLVisionConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.model_executor.layers.activation import QuickGELU
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.qwen2_vl import (
|
||||
Qwen2VisionAttention, Qwen2VisionBlock, Qwen2VisionPatchEmbed,
|
||||
Qwen2VisionTransformer, Qwen2VLDummyInputsBuilder,
|
||||
Qwen2VLForConditionalGeneration, Qwen2VLMultiModalProcessor,
|
||||
Qwen2VLProcessingInfo)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
MIN_PAD_SIZE = 64 # min_size to pad weight
|
||||
MAX_PAD_SIZE = 128 # max_size to pad weight
|
||||
|
||||
|
||||
class AscendQwen2VisionAttention(Qwen2VisionAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
projection_size: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embed_dim,
|
||||
num_heads,
|
||||
projection_size,
|
||||
quant_config,
|
||||
prefix,
|
||||
)
|
||||
self.cu_seqlens = None
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
projection_size, num_heads)
|
||||
self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
|
||||
if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE:
|
||||
self.hidden_size_per_attention_head = MAX_PAD_SIZE
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
self.cu_seqlens = cu_seqlens
|
||||
|
||||
# [s, b, c] --> [s, b, 3 * head * head_dim]
|
||||
x, _ = self.qkv(x)
|
||||
|
||||
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
|
||||
q, k, v = self.split_qkv(x)
|
||||
batch_size = q.shape[1]
|
||||
|
||||
q, k, v = [
|
||||
rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
|
||||
]
|
||||
q = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
q, k, v = [
|
||||
rearrange(x, "b s h d -> (b s) h d").contiguous()
|
||||
for x in (q, k, v)
|
||||
]
|
||||
|
||||
context_layer = torch.empty_like(q)
|
||||
|
||||
# operator requires pta version >= 2.5.1
|
||||
torch_npu._npu_flash_attention_unpad(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
seq_len=self.cu_seqlens,
|
||||
scale_value=self.origin_hidden_size_per_attention_head**-0.5,
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
num_kv_heads=self.num_attention_heads_per_partition,
|
||||
out=context_layer)
|
||||
context_layer = rearrange(context_layer,
|
||||
"(b s) h d -> s b (h d)",
|
||||
b=batch_size).contiguous()
|
||||
|
||||
output, _ = self.proj(context_layer)
|
||||
return output
|
||||
|
||||
|
||||
class AscendQwen2VisionBlock(Qwen2VisionBlock):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
act_layer: Type[nn.Module] = QuickGELU,
|
||||
norm_layer: Optional[Callable[[int], nn.Module]] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(dim, num_heads, mlp_ratio, act_layer, norm_layer,
|
||||
quant_config, prefix)
|
||||
self.attn = AscendQwen2VisionAttention(embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
projection_size=dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x = x + self.attn(
|
||||
self.norm1(x),
|
||||
cu_seqlens=cu_seqlens,
|
||||
cos=cos,
|
||||
sin=sin,
|
||||
)
|
||||
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen2VisionPatchEmbed(Qwen2VisionPatchEmbed):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.matmul(
|
||||
self.proj.weight.data.view(self.embed_dim, -1).transpose(0, 1))
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen2VisionTransformer(Qwen2VisionTransformer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config: Qwen2VLVisionConfig,
|
||||
norm_eps: float = 1e-6,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
interleaved=False,
|
||||
) -> None:
|
||||
super().__init__(vision_config, norm_eps, quant_config, prefix)
|
||||
|
||||
self.interleaved = interleaved
|
||||
self.enable_pad = False
|
||||
self.depth = vision_config.depth
|
||||
self.hidden_size = vision_config.embed_dim
|
||||
self.num_heads = vision_config.num_heads
|
||||
self.patch_embed = AscendQwen2VisionPatchEmbed(
|
||||
patch_size=vision_config.patch_size,
|
||||
temporal_patch_size=vision_config.temporal_patch_size,
|
||||
in_channels=vision_config.in_channels,
|
||||
embed_dim=vision_config.embed_dim,
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
AscendQwen2VisionBlock(dim=self.embed_dim,
|
||||
num_heads=self.num_heads,
|
||||
mlp_ratio=vision_config.mlp_ratio,
|
||||
norm_layer=partial(nn.LayerNorm,
|
||||
eps=norm_eps),
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}")
|
||||
for layer_idx in range(vision_config.depth)
|
||||
])
|
||||
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
self.hidden_size, self.num_heads)
|
||||
|
||||
if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE:
|
||||
self.enable_pad = True
|
||||
self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
|
||||
self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2
|
||||
self.half_pad_hidden_size_per_attention_head = (
|
||||
MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2
|
||||
self.hidden_size_per_attention_head = MAX_PAD_SIZE
|
||||
|
||||
def cal_cos_sin(self, rotary_pos_emb):
|
||||
cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2]
|
||||
sin = rotary_pos_emb.sin()
|
||||
if self.enable_pad:
|
||||
cos = torch.nn.functional.pad(
|
||||
cos, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
sin = torch.nn.functional.pad(
|
||||
sin, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
|
||||
if not self.interleaved:
|
||||
cos_new = torch.cat((cos, cos), dim=-1)
|
||||
sin_new = torch.cat((sin, sin), dim=-1)
|
||||
else:
|
||||
cos_new = rearrange(torch.stack((cos, cos), dim=-1),
|
||||
"... d two -> ...(d two)",
|
||||
two=2)
|
||||
sin_new = rearrange(torch.stack((sin, sin), dim=-1),
|
||||
"... d two -> ...(d two)",
|
||||
two=2)
|
||||
cos_new = cos_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
sin_new = sin_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
return cos_new, sin_new
|
||||
|
||||
def pad_qkv_bias(self, bias):
|
||||
first_half = bias.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head
|
||||
)[:, :, :self.half_origin_hidden_size_per_attention_head]
|
||||
second_half = bias.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head
|
||||
)[:, :, self.half_origin_hidden_size_per_attention_head:]
|
||||
first_half_padded = torch.nn.functional.pad(
|
||||
first_half, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
second_half_padded = torch.nn.functional.pad(
|
||||
second_half, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2)
|
||||
bias_final = bias_padded.reshape(-1)
|
||||
return bias_final
|
||||
|
||||
def pad_qkv_weight(self, data):
|
||||
qkv_weight_first_half = data.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
|
||||
)[:, :, :self.half_origin_hidden_size_per_attention_head, :]
|
||||
qkv_weight_second_half = data.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
|
||||
)[:, :, self.half_origin_hidden_size_per_attention_head:, :]
|
||||
|
||||
qkv_weight_first_half_padded = torch.nn.functional.pad(
|
||||
qkv_weight_first_half,
|
||||
(0, 0, 0, self.half_pad_hidden_size_per_attention_head))
|
||||
qkv_weight_second_half_padded = torch.nn.functional.pad(
|
||||
qkv_weight_second_half,
|
||||
(0, 0, 0, self.half_pad_hidden_size_per_attention_head))
|
||||
qkv_weight_padded = torch.cat(
|
||||
[qkv_weight_first_half_padded, qkv_weight_second_half_padded],
|
||||
dim=2)
|
||||
qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
|
||||
return qkv_weight_final
|
||||
|
||||
def pad_proj_weight(self, data):
|
||||
out_weight = torch.nn.functional.pad(
|
||||
data.reshape(self.hidden_size, -1,
|
||||
self.half_origin_hidden_size_per_attention_head),
|
||||
(0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
|
||||
self.hidden_size, -1)
|
||||
return out_weight
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
if ("attn.proj.weight" in name) and self.enable_pad:
|
||||
param.data = self.pad_proj_weight(param.data)
|
||||
if ("attn.qkv.weight" in name) and self.enable_pad:
|
||||
param.data = self.pad_qkv_weight(param.data)
|
||||
if ("attn.qkv.bias" in name) and self.enable_pad:
|
||||
param.data = self.pad_qkv_bias(param.data)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# compute cu_seqlens and avoid cumsum to fit operator unpadFA
|
||||
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
|
||||
grid_thw[:,
|
||||
0]).cpu().to(torch.int32)
|
||||
|
||||
# patchify
|
||||
x = x.to(device=self.device, dtype=self.dtype)
|
||||
x = self.patch_embed(x)
|
||||
|
||||
# compute position embedding
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
cos, sin = self.cal_cos_sin(rotary_pos_emb)
|
||||
|
||||
x = x.unsqueeze(1)
|
||||
for blk in self.blocks:
|
||||
x = blk(x, cu_seqlens=cu_seqlens, cos=cos, sin=sin)
|
||||
|
||||
# adapter
|
||||
x = self.merger(x)
|
||||
return x
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
|
||||
info=Qwen2VLProcessingInfo,
|
||||
dummy_inputs=Qwen2VLDummyInputsBuilder)
|
||||
class AscendQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
self.visual = AscendQwen2VisionTransformer(
|
||||
self.config.vision_config,
|
||||
norm_eps=getattr(self.config, "rms_norm_eps", 1e-6),
|
||||
quant_config=self._maybe_ignore_quant_config(
|
||||
vllm_config.quant_config),
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
156
vllm_ascend/models/qwen3.py
Normal file
156
vllm_ascend/models/qwen3.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Qwen3Config
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
||||
from vllm.model_executor.models.qwen2 import Qwen2Model
|
||||
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader,
|
||||
PPMissingLayer, maybe_prefix)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
|
||||
|
||||
|
||||
class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
if quant_config is None:
|
||||
return
|
||||
|
||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
|
||||
assert isinstance(quant_config, AscendQuantConfig), \
|
||||
"Expected quant_config to be an instance of AscendQuantConfig"
|
||||
|
||||
if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
|
||||
AscendW8A8LinearMethod):
|
||||
self.input_layernorm = AddRMSNormW8A8Quant(
|
||||
config.hidden_size,
|
||||
layer=self.self_attn.qkv_proj,
|
||||
eps=config.rms_norm_eps)
|
||||
if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
|
||||
AscendW8A8LinearMethod):
|
||||
self.post_attention_layernorm = AddRMSNormW8A8Quant(
|
||||
config.hidden_size,
|
||||
layer=self.mlp.gate_up_proj,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
|
||||
ALL_DECODER_LAYER_TYPES = {
|
||||
"attention": CustomQwen3DecoderLayer,
|
||||
}
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"input_ids": 0,
|
||||
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
|
||||
# otherwise (seq_len, ).
|
||||
"positions": -1,
|
||||
"intermediate_tensors": 0,
|
||||
"inputs_embeds": 0,
|
||||
})
|
||||
class CustomQwen3Model(Qwen2Model):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
decoder_layer_type=CustomQwen3DecoderLayer)
|
||||
|
||||
|
||||
class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
# add `CustomQwen3Model` to init self.model
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.model = CustomQwen3Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
393
vllm_ascend/models/qwen3_moe.py
Normal file
393
vllm_ascend/models/qwen3_moe.py
Normal file
@@ -0,0 +1,393 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Adapted from vllm/model_executor/models/qwen3_moe.py
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
||||
get_tp_group)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.interfaces import (MixtureOfExperts,
|
||||
SupportsLoRA, SupportsPP)
|
||||
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
|
||||
Qwen3MoeDecoderLayer,
|
||||
Qwen3MoeForCausalLM,
|
||||
Qwen3MoeMLP, Qwen3MoeModel,
|
||||
Qwen3MoeSparseMoeBlock)
|
||||
from vllm.model_executor.models.utils import (
|
||||
PPMissingLayer, extract_layer_index,
|
||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
|
||||
init_metadata_for_sp)
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
|
||||
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
if self.tp_size > config.num_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {config.num_experts}.")
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.num_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
self.experts = AscendFusedMoE(
|
||||
num_experts=config.num_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
)
|
||||
|
||||
self.top_k = config.num_experts_per_tok
|
||||
|
||||
self.dp_size = get_dp_group().world_size
|
||||
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
self.ep_group = get_ep_group()
|
||||
|
||||
self.params_dtype = torch.get_default_dtype()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attn_metadata=None,
|
||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
||||
):
|
||||
if attn_metadata is None:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
# when profile runs, force experts to load balanced tokens
|
||||
# to avoid high memory consumption on a single rank.
|
||||
enable_force_load_balance = get_forward_context().in_profile_run
|
||||
is_prefill = get_forward_context().with_prefill
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
is_prefill=is_prefill,
|
||||
top_k=self.top_k,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
shared_experts=None,
|
||||
_metadata_for_padding=_metadata_for_padding,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
vllm_config: Optional[VllmConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = Qwen3MoeAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
qkv_bias=getattr(config, 'attention_bias', False),
|
||||
head_dim=getattr(config, 'head_dim', None),
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
|
||||
# `mlp_only_layers` in the config.
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
|
||||
config.mlp_only_layers)
|
||||
self.use_aclgraph = (vllm_config is not None
|
||||
and vllm_config.compilation_config.level
|
||||
== CompilationLevel.PIECEWISE
|
||||
and not vllm_config.model_config.enforce_eager)
|
||||
if (layer_idx not in mlp_only_layers) and (
|
||||
config.num_experts > 0 and
|
||||
(layer_idx + 1) % config.decoder_sparse_step == 0):
|
||||
if not self.use_aclgraph:
|
||||
# FIXME: custom sparse moe block doesn't work with aclgraph.
|
||||
self.mlp = CustomSparseMoeBlock(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = Qwen3MoeSparseMoeBlock(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
self.enable_sequence_parallelism = (
|
||||
vllm_config.compilation_config.pass_config.
|
||||
enable_sequence_parallelism if vllm_config is not None else False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# To prevent precision issues during the decoder phase when only prefilling enables SP
|
||||
if not self.enable_sequence_parallelism:
|
||||
self.self_attn.o_proj.reduce_results = True
|
||||
else:
|
||||
self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
|
||||
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
residual = _metadata_for_padding.padding_slice(residual)
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
|
||||
hidden_states)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
|
||||
hidden_states)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
if not self.use_aclgraph:
|
||||
hidden_states = self.mlp(
|
||||
hidden_states, _metadata_for_padding=_metadata_for_padding)
|
||||
else:
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class CustomQwen3MoeModel(Qwen3MoeModel):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
self.num_redundant_experts = parallel_config.num_redundant_experts
|
||||
else:
|
||||
eplb_config = parallel_config.eplb_config
|
||||
self.num_redundant_experts = eplb_config.num_redundant_experts
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.config = config
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=f"{prefix}.embed_tokens")
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: CustomQwen3MoeDecoderLayer(
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
_metadata_for_padding=_metadata_for_padding)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
|
||||
hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
SupportsPP.__init__(self)
|
||||
SupportsLoRA.__init__(self)
|
||||
MixtureOfExperts.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = CustomQwen3MoeModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
|
||||
# Set MoE hyperparameters
|
||||
self.expert_weights: list[torch.Tensor] = []
|
||||
|
||||
self.moe_layers: list[FusedMoE] = []
|
||||
example_layer = None
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
|
||||
assert isinstance(layer, Qwen3MoeDecoderLayer)
|
||||
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
|
||||
example_layer = layer.mlp
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
if example_layer is None:
|
||||
raise RuntimeError("No Qwen3MoE layer found in the model.layers.")
|
||||
|
||||
self.num_moe_layers = len(self.moe_layers)
|
||||
self.num_expert_groups = 1
|
||||
self.num_shared_experts = 0
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
_metadata_for_padding = init_metadata_for_sp(
|
||||
input_ids, self.enable_sequence_parallelism)
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds, _metadata_for_padding)
|
||||
return hidden_states
|
||||
0
vllm_ascend/multistream/__init__.py
Normal file
0
vllm_ascend/multistream/__init__.py
Normal file
29
vllm_ascend/multistream/base.py
Normal file
29
vllm_ascend/multistream/base.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MSEventKey(Enum):
|
||||
ATTN_COM_FINISH = 0
|
||||
ATTN_AR_FINISH = 1
|
||||
FFN_COM_FINISH = 2
|
||||
FFN_AR_FINISH = 3
|
||||
# events for MOE dispatch and combine
|
||||
MOE_BEFORE_COMM = 4
|
||||
MOE_AFTER_COMM = 5
|
||||
# events for shared expert
|
||||
MOE_SE_COMM_FINISH = 6
|
||||
MOE_SE_COMP_FINISH = 7
|
||||
MOE_GATE_FINISH = 8
|
||||
|
||||
|
||||
@dataclass
|
||||
class MSAttentionMetadataSplitConfig:
|
||||
"""
|
||||
micro batch split config for split attention metadata
|
||||
"""
|
||||
# micro batch num
|
||||
num_micro_batches: int = 2
|
||||
# split micro batches only when total tokens >= min_total_tokens_to_split
|
||||
min_total_tokens_to_split: int = 256
|
||||
# split micro batches only when prefill tokens >= min_prefill_tokens_to_split
|
||||
min_prefill_tokens_to_split: int = 64
|
||||
67
vllm_ascend/multistream/context.py
Normal file
67
vllm_ascend/multistream/context.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
_ms_comm_context: Any = None
|
||||
_cur_micro_batch_num: int = -1
|
||||
_ms_layer_index_context: int = -1
|
||||
_ms_metadata_context: Any = None
|
||||
_ms_attn_metadata_context: Any = None
|
||||
|
||||
|
||||
def set_multistream_layer_context(start_layer: int, ms_metadata: Any,
|
||||
attn_metadata: Any):
|
||||
"""
|
||||
set multistream layer context before transformer layers
|
||||
"""
|
||||
global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
|
||||
_ms_layer_index_context = start_layer
|
||||
_ms_metadata_context = ms_metadata
|
||||
_ms_attn_metadata_context = attn_metadata
|
||||
|
||||
|
||||
def reset_multistream_layer_context():
|
||||
"""
|
||||
reset multistream layer context
|
||||
"""
|
||||
global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
|
||||
_ms_layer_index_context = -1
|
||||
_ms_metadata_context = None
|
||||
_ms_attn_metadata_context = None
|
||||
|
||||
|
||||
def get_multistream_layer_context():
|
||||
"""
|
||||
get multistream layer context
|
||||
"""
|
||||
return _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
|
||||
|
||||
|
||||
def advance_step_multistream_layer_context():
|
||||
"""
|
||||
advance multistream layer index context
|
||||
"""
|
||||
global _ms_layer_index_context
|
||||
_ms_layer_index_context += 1
|
||||
|
||||
|
||||
def get_multistream_comm_context() -> Any:
|
||||
"""Get the current comm forward context."""
|
||||
return _ms_comm_context
|
||||
|
||||
|
||||
def get_multistream_microbatch_context() -> int:
|
||||
return _cur_micro_batch_num
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_multistream_context(context: Any, micro_batch_num: int):
|
||||
"""A context manager that stores the current comm forward context,
|
||||
can be attention metadata, etc."""
|
||||
global _ms_comm_context, _cur_micro_batch_num
|
||||
_ms_comm_context = context
|
||||
_cur_micro_batch_num = micro_batch_num
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_ms_comm_context = None
|
||||
_cur_micro_batch_num = -1
|
||||
22
vllm_ascend/multistream/decorator.py
Normal file
22
vllm_ascend/multistream/decorator.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from .context import (get_multistream_layer_context,
|
||||
get_multistream_microbatch_context)
|
||||
|
||||
|
||||
# vllm v1 use get_forward_context to get the attn_metadata,
|
||||
# we can use this decorator to update the attn metadata
|
||||
def set_multistream_support():
|
||||
|
||||
def decorator(func):
|
||||
|
||||
def wrapper():
|
||||
context = func()
|
||||
layer_index, ms_metadata, attn_metadata = get_multistream_layer_context(
|
||||
)
|
||||
micro_batch_num = get_multistream_microbatch_context()
|
||||
if layer_index != -1 and micro_batch_num != -1:
|
||||
context.attn_metadata = attn_metadata[micro_batch_num]
|
||||
return context
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
61
vllm_ascend/multistream/layers.py
Normal file
61
vllm_ascend/multistream/layers.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from .base import MSEventKey
|
||||
from .context import (get_multistream_layer_context,
|
||||
reset_multistream_layer_context,
|
||||
set_multistream_layer_context)
|
||||
from .metadata import MultiStreamMetadata
|
||||
|
||||
|
||||
class MultiStreamPreTransformerLayer(torch.nn.Module):
|
||||
|
||||
def __init__(self, multistream_metadata: MultiStreamMetadata):
|
||||
super().__init__()
|
||||
self.multistream_metadata = multistream_metadata
|
||||
|
||||
def forward(
|
||||
self,
|
||||
intput_tensors: List[torch.Tensor],
|
||||
):
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if self.multistream_metadata is None or attn_metadata is None:
|
||||
set_multistream_layer_context(-1, None, None)
|
||||
return attn_metadata, intput_tensors
|
||||
# TODO add attn_metadata management
|
||||
do_ms, attn_metadata, intput_tensors, _ = self.multistream_metadata.split_micro_batch(
|
||||
attn_metadata, intput_tensors)
|
||||
if do_ms:
|
||||
set_multistream_layer_context(
|
||||
self.multistream_metadata.start_layer,
|
||||
self.multistream_metadata, attn_metadata)
|
||||
else:
|
||||
set_multistream_layer_context(-1, None, None)
|
||||
return attn_metadata, intput_tensors
|
||||
|
||||
|
||||
class MultiStreamPostTransformerLayer(torch.nn.Module):
|
||||
|
||||
def __init__(self, multistream_metadata: MultiStreamMetadata):
|
||||
super().__init__()
|
||||
self.multistream_metadata = multistream_metadata
|
||||
|
||||
def forward(self,
|
||||
input_tensors: Union[List[Tuple[torch.Tensor]],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]]],
|
||||
wait_layer_index: Optional[int] = None):
|
||||
if self.multistream_metadata is None or self.multistream_metadata.ms_config is None:
|
||||
return input_tensors
|
||||
layer_index, ms_metadata, ms_attn_metadata = get_multistream_layer_context(
|
||||
)
|
||||
if layer_index >= 0:
|
||||
true_wait_layer = self.multistream_metadata.end_layer - 1 if wait_layer_index is None else wait_layer_index
|
||||
self.multistream_metadata.try_wait_event(
|
||||
true_wait_layer,
|
||||
self.multistream_metadata.ms_config.num_micro_batches - 1,
|
||||
MSEventKey.FFN_AR_FINISH)
|
||||
reset_multistream_layer_context()
|
||||
return self.multistream_metadata.merge_micro_batches(input_tensors)
|
||||
182
vllm_ascend/multistream/metadata.py
Normal file
182
vllm_ascend/multistream/metadata.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
|
||||
|
||||
from .base import MSAttentionMetadataSplitConfig, MSEventKey
|
||||
|
||||
|
||||
def split_micro_batches_tensors(input_tensors,
|
||||
split_index: int,
|
||||
keys: Optional[List[str]] = None):
|
||||
if isinstance(input_tensors, list):
|
||||
micro_batches = []
|
||||
for tensor in input_tensors:
|
||||
if tensor is None:
|
||||
micro_batches.append([None, None])
|
||||
else:
|
||||
micro_batches.append(
|
||||
[tensor[:split_index], tensor[split_index:]])
|
||||
return micro_batches
|
||||
elif isinstance(input_tensors, torch.Tensor):
|
||||
return [input_tensors[:split_index], input_tensors[split_index:]]
|
||||
elif input_tensors is None:
|
||||
return [None, None]
|
||||
elif isinstance(input_tensors, Dict):
|
||||
assert keys is not None
|
||||
micro_batches_pre = {}
|
||||
for key in keys:
|
||||
micro_batches_pre[key] = input_tensors[key][:split_index]
|
||||
micro_batches_post = {}
|
||||
for key in keys:
|
||||
micro_batches_post[key] = input_tensors[key][split_index:]
|
||||
return [micro_batches_pre, micro_batches_post]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiStreamStepMetadata:
|
||||
comm_stream: torch.npu.Stream = None
|
||||
before_comm_event: torch.npu.Event = None
|
||||
after_comm_event: torch.npu.Event = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiStreamConfig:
|
||||
"""Controls the behavior of multi-stream models."""
|
||||
min_total_tokens_to_split: int = 256
|
||||
min_prefill_tokens_to_split: int = 64
|
||||
num_micro_batches: int = 2
|
||||
imbalance_ratio: float = 0.1
|
||||
|
||||
|
||||
class MultiStreamMetadata:
|
||||
# direct stream
|
||||
calculate_stream = None
|
||||
# delay stream
|
||||
communicate_stream = None
|
||||
# events
|
||||
ms_events: Dict[int, Dict[int, Dict[MSEventKey, torch.npu.Event]]] = {}
|
||||
# multi-stream-flag
|
||||
enable_multi_stream: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
calculate_stream: torch.npu.Stream,
|
||||
communicate_stream: torch.npu.Stream,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
event_keys: List[MSEventKey],
|
||||
multistream_config: Optional[MultiStreamConfig],
|
||||
causal_lm: bool = True,
|
||||
):
|
||||
self.calculate_stream = calculate_stream
|
||||
self.communicate_stream = communicate_stream
|
||||
self.start_layer = start_layer
|
||||
self.end_layer = end_layer
|
||||
self.ms_config = multistream_config
|
||||
self.causal_lm = causal_lm
|
||||
self._build_events(event_keys)
|
||||
self._build_ms_split_config()
|
||||
|
||||
def _build_events(self, event_keys):
|
||||
if self.ms_config is not None:
|
||||
for i in range(self.start_layer - 1, self.end_layer):
|
||||
self.ms_events[i] = {}
|
||||
for j in range(self.ms_config.num_micro_batches):
|
||||
self.ms_events[i][j] = {}
|
||||
for key in event_keys:
|
||||
self.ms_events[i][j][key] = torch.npu.Event()
|
||||
|
||||
def _build_ms_split_config(self):
|
||||
if self.ms_config is not None:
|
||||
self.ms_split_config = MSAttentionMetadataSplitConfig(
|
||||
num_micro_batches=self.ms_config.num_micro_batches,
|
||||
min_total_tokens_to_split=self.ms_config.
|
||||
min_total_tokens_to_split,
|
||||
min_prefill_tokens_to_split=self.ms_config.
|
||||
min_prefill_tokens_to_split,
|
||||
)
|
||||
|
||||
def try_wait_event(self, layer_index: int, micro_batch_index: int,
|
||||
event_key: MSEventKey):
|
||||
self.ms_events[layer_index][micro_batch_index][event_key].wait()
|
||||
|
||||
def try_record_event(self, layer_index: int, micro_batch_index: int,
|
||||
event_key: MSEventKey):
|
||||
self.ms_events[layer_index][micro_batch_index][event_key].record()
|
||||
|
||||
def split_micro_batch(
|
||||
self,
|
||||
attn_metadata: "AscendMLAMetadata",
|
||||
intput_tensors: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
intermediate_tensors_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[bool, Union[AscendMLAMetadata, List[AscendMLAMetadata]], Union[
|
||||
List[torch.Tensor], List[List[torch.Tensor]]], Union[
|
||||
IntermediateTensors, List[IntermediateTensors]]]:
|
||||
attn_metadata_list = attn_metadata.split_metadata_for_multistream(
|
||||
self.ms_split_config)
|
||||
if len(attn_metadata_list) == 1:
|
||||
return False, attn_metadata_list[
|
||||
0], intput_tensors, intermediate_tensors
|
||||
split_index = attn_metadata_list[0].slot_mapping.shape[0]
|
||||
input_tensors = split_micro_batches_tensors(intput_tensors,
|
||||
split_index)
|
||||
if intermediate_tensors is not None:
|
||||
inter_tensors_list = split_micro_batches_tensors(
|
||||
intermediate_tensors.tensors, split_index,
|
||||
intermediate_tensors_keys)
|
||||
intermediate_tensors = [
|
||||
IntermediateTensors(inter_tensors)
|
||||
for inter_tensors in inter_tensors_list
|
||||
]
|
||||
return True, attn_metadata_list, input_tensors, intermediate_tensors
|
||||
|
||||
def merge_micro_batches(
|
||||
self, input_tensors: Union[List[torch.Tensor],
|
||||
List[List[torch.Tensor]]]
|
||||
) -> List[torch.Tensor]:
|
||||
if input_tensors is None or isinstance(input_tensors[0], torch.Tensor):
|
||||
return input_tensors
|
||||
batch: List[Optional[torch.Tensor]] = []
|
||||
for tensors in input_tensors:
|
||||
if tensors is None or tensors[0] is None:
|
||||
batch.append(None)
|
||||
else:
|
||||
batch.append(torch.cat(tensors, dim=0))
|
||||
return batch
|
||||
|
||||
|
||||
def make_multistream_metadata_ds(
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
causal_lm: bool = True,
|
||||
multistream_config: Optional[MultiStreamConfig] = None,
|
||||
):
|
||||
if multistream_config is None:
|
||||
return None
|
||||
event_keylist = [
|
||||
MSEventKey.ATTN_COM_FINISH,
|
||||
MSEventKey.ATTN_AR_FINISH,
|
||||
MSEventKey.FFN_COM_FINISH,
|
||||
MSEventKey.FFN_AR_FINISH,
|
||||
MSEventKey.MOE_BEFORE_COMM,
|
||||
MSEventKey.MOE_AFTER_COMM,
|
||||
MSEventKey.MOE_SE_COMM_FINISH,
|
||||
MSEventKey.MOE_SE_COMP_FINISH,
|
||||
MSEventKey.MOE_GATE_FINISH,
|
||||
]
|
||||
return MultiStreamMetadata(
|
||||
calculate_stream=torch.npu.current_stream(),
|
||||
communicate_stream=torch.npu.Stream(),
|
||||
start_layer=start_layer,
|
||||
end_layer=end_layer,
|
||||
multistream_config=multistream_config,
|
||||
event_keys=event_keylist,
|
||||
causal_lm=causal_lm,
|
||||
)
|
||||
247
vllm_ascend/multistream/ms_split.py
Normal file
247
vllm_ascend/multistream/ms_split.py
Normal file
@@ -0,0 +1,247 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
|
||||
from .base import MSAttentionMetadataSplitConfig
|
||||
|
||||
|
||||
def compute_split_seq_index(
|
||||
query_lens: Optional[list[int]],
|
||||
attn_state: AscendAttentionState,
|
||||
num_tokens: int,
|
||||
imbalance_ratio: float = 0.1,
|
||||
) -> list[int]:
|
||||
if attn_state != AscendAttentionState.DecodeOnly:
|
||||
assert query_lens is not None
|
||||
total_tokens = sum(query_lens)
|
||||
# the first index in last split
|
||||
tokens, split_index = 0, 0
|
||||
for value in query_lens:
|
||||
tokens += value
|
||||
split_index += 1
|
||||
if tokens >= total_tokens // 2:
|
||||
# check the current split index
|
||||
if abs(tokens -
|
||||
total_tokens // 2) < total_tokens * imbalance_ratio:
|
||||
return [tokens, split_index]
|
||||
# check the previous split index
|
||||
elif abs(tokens - total_tokens // 2 -
|
||||
value) < total_tokens * imbalance_ratio:
|
||||
return [tokens - value, split_index - 1]
|
||||
# fail to split if it is imbalanced
|
||||
# TODO: split tokens in seq
|
||||
else:
|
||||
return [0, 0]
|
||||
else:
|
||||
tokens = num_tokens // 2
|
||||
return [tokens, tokens]
|
||||
return [0, 0]
|
||||
|
||||
|
||||
def split_attn_tensor_type(
|
||||
input_tensor: torch.Tensor,
|
||||
index: int,
|
||||
) -> List[torch.Tensor]:
|
||||
return [input_tensor[:index], input_tensor[index:]]
|
||||
|
||||
|
||||
def split_attn_int_type(
|
||||
var: int,
|
||||
index: int,
|
||||
) -> List[torch.Tensor]:
|
||||
return [min(var, index), max(var - index, 0)]
|
||||
|
||||
|
||||
def model_input_split_v1_mla_attn(
|
||||
attn_metadata,
|
||||
_metadata_cls,
|
||||
ms_split_config: MSAttentionMetadataSplitConfig,
|
||||
) -> List[Any]:
|
||||
assert 0 < ms_split_config.num_micro_batches < 3
|
||||
if attn_metadata is None:
|
||||
return [attn_metadata]
|
||||
[token_index,
|
||||
seq_index] = compute_split_seq_index(attn_metadata.query_lens,
|
||||
attn_metadata.attn_state,
|
||||
attn_metadata.num_decode_tokens)
|
||||
if token_index == 0 or seq_index == 0 or seq_index == len(
|
||||
attn_metadata.query_lens):
|
||||
return [attn_metadata]
|
||||
|
||||
query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ),
|
||||
dtype=int)
|
||||
np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:])
|
||||
if attn_metadata.num_prefills > 0:
|
||||
prefill_query_start_loc = np.zeros(
|
||||
shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int)
|
||||
np.cumsum(attn_metadata.prefill.query_lens,
|
||||
out=prefill_query_start_loc[1:])
|
||||
|
||||
# split attn metadata
|
||||
[slot_mapping_pre,
|
||||
slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping,
|
||||
token_index)
|
||||
[num_decodes_pre,
|
||||
num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes,
|
||||
seq_index)
|
||||
[num_decode_tokens_pre, num_decode_tokens_post
|
||||
] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index)
|
||||
[num_prefills_pre, num_prefills_post
|
||||
] = split_attn_int_type(attn_metadata.num_prefills,
|
||||
max(0, seq_index - attn_metadata.num_decodes))
|
||||
seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
|
||||
[seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
|
||||
|
||||
query_start_loc_pre = query_start_loc_post = None
|
||||
if attn_metadata.query_start_loc is not None:
|
||||
query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
|
||||
query_start_loc_post = deepcopy(
|
||||
attn_metadata.query_start_loc[seq_index:]
|
||||
) - attn_metadata.query_start_loc[seq_index]
|
||||
[block_table_pre,
|
||||
block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
|
||||
seq_index)
|
||||
assert attn_metadata.attn_mask is not None
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
# the attn_mla kernel in torch npu only accept 128*128 attn mask
|
||||
attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
|
||||
attn_state_pre = attn_state_post = attn_metadata.attn_state
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
# should be none in decode only state
|
||||
attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
|
||||
attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly
|
||||
else:
|
||||
# chunked prefill
|
||||
if num_prefills_pre > 0:
|
||||
attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill
|
||||
attn_mask_pre = attn_metadata.attn_mask[:token_index, :max(
|
||||
seq_lens_pre)].contiguous()
|
||||
attn_state_post = AscendAttentionState.ChunkedPrefill
|
||||
attn_mask_post = attn_metadata.attn_mask[
|
||||
token_index:, :max(seq_lens_post)].contiguous()
|
||||
else:
|
||||
attn_state_pre = AscendAttentionState.DecodeOnly
|
||||
attn_mask_pre = None
|
||||
attn_state_post = AscendAttentionState.ChunkedPrefill
|
||||
attn_mask_post = attn_metadata.attn_mask[
|
||||
token_index:, :max(seq_lens_post)].contiguous()
|
||||
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
|
||||
AscendMLAPrefillMetadata)
|
||||
if num_prefills_pre > 0:
|
||||
# split metadata.prefill
|
||||
[input_positions_pre, input_positions_post] = split_attn_tensor_type(
|
||||
attn_metadata.prefill.input_positions,
|
||||
token_index - attn_metadata.num_decode_tokens)
|
||||
[block_tables_pre, block_tables_post
|
||||
] = split_attn_tensor_type(attn_metadata.prefill.block_table,
|
||||
seq_index - attn_metadata.num_decodes)
|
||||
[prefill_query_lens_pre, prefill_query_lens_post
|
||||
] = split_attn_tensor_type(attn_metadata.prefill.query_lens,
|
||||
seq_index - attn_metadata.num_decodes)
|
||||
prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[:
|
||||
seq_index
|
||||
+
|
||||
1 -
|
||||
attn_metadata
|
||||
.
|
||||
num_decodes]
|
||||
prefill_query_start_loc_post = deepcopy(
|
||||
attn_metadata.prefill.query_start_loc[seq_index -
|
||||
attn_metadata.num_decodes:]
|
||||
) - attn_metadata.prefill.query_start_loc[seq_index -
|
||||
attn_metadata.num_decodes]
|
||||
context_len_pre = seq_lens_pre[attn_metadata.num_decodes:]
|
||||
context_len_post = seq_lens_post
|
||||
prefill_max_query_len_pre = max(prefill_query_lens_pre)
|
||||
prefill_max_query_len_post = max(prefill_query_lens_post)
|
||||
prefill_pre = AscendMLAPrefillMetadata(
|
||||
attn_mask=attn_mask_pre,
|
||||
query_lens=prefill_query_lens_pre,
|
||||
seq_lens=seq_lens_pre,
|
||||
query_start_loc=prefill_query_start_loc_pre,
|
||||
input_positions=input_positions_pre,
|
||||
context_lens=context_len_pre,
|
||||
block_table=block_tables_pre,
|
||||
max_query_len=prefill_max_query_len_pre,
|
||||
max_seq_lens=context_len_pre.max().item(),
|
||||
)
|
||||
prefill_post = AscendMLAPrefillMetadata(
|
||||
attn_mask=attn_mask_post,
|
||||
query_lens=prefill_query_lens_post,
|
||||
seq_lens=seq_lens_post,
|
||||
query_start_loc=prefill_query_start_loc_post,
|
||||
input_positions=input_positions_post,
|
||||
context_lens=context_len_post,
|
||||
block_table=block_tables_post,
|
||||
max_query_len=prefill_max_query_len_post,
|
||||
max_seq_lens=context_len_post.max().item(),
|
||||
)
|
||||
decode_pre = attn_metadata.decode
|
||||
decode_post = None
|
||||
else:
|
||||
# prefill is None, split metadata.decode
|
||||
[input_positions_pre, input_positions_post
|
||||
] = split_attn_tensor_type(attn_metadata.decode.input_positions,
|
||||
token_index)
|
||||
[block_tables_pre, block_tables_post
|
||||
] = split_attn_tensor_type(attn_metadata.decode.block_table,
|
||||
seq_index)
|
||||
[decode_seq_lens_pre,
|
||||
decode_seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
|
||||
decode_pre = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions_pre,
|
||||
block_table=block_tables_pre,
|
||||
seq_lens=decode_seq_lens_pre,
|
||||
max_seq_lens=max(decode_seq_lens_pre),
|
||||
seq_lens_list=decode_seq_lens_pre.tolist(),
|
||||
)
|
||||
decode_post = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions_post,
|
||||
block_table=block_tables_post,
|
||||
seq_lens=decode_seq_lens_post,
|
||||
max_seq_lens=max(decode_seq_lens_post),
|
||||
seq_lens_list=decode_seq_lens_post.tolist(),
|
||||
)
|
||||
prefill_pre = None
|
||||
prefill_post = attn_metadata.prefill
|
||||
# construct metadata
|
||||
from vllm_ascend.attention.mla_v1 import AscendMLAPrefillMetadata
|
||||
attention_metadata_pre = _metadata_cls(
|
||||
num_actual_tokens=token_index,
|
||||
num_input_tokens=token_index,
|
||||
head_dim=attn_metadata.head_dim,
|
||||
slot_mapping=slot_mapping_pre,
|
||||
seq_lens=seq_lens_pre,
|
||||
query_start_loc=query_start_loc_pre,
|
||||
block_tables=block_table_pre,
|
||||
num_decodes=num_decodes_pre,
|
||||
num_prefills=num_prefills_pre,
|
||||
num_decode_tokens=num_decode_tokens_pre,
|
||||
attn_state=attn_state_pre,
|
||||
attn_mask=attn_mask_pre,
|
||||
prefill=prefill_pre,
|
||||
decode=decode_pre,
|
||||
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
attention_metadata_post = _metadata_cls(
|
||||
num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
|
||||
num_input_tokens=attn_metadata.num_input_tokens - token_index,
|
||||
head_dim=attn_metadata.head_dim,
|
||||
slot_mapping=slot_mapping_post,
|
||||
seq_lens=seq_lens_post,
|
||||
query_start_loc=query_start_loc_post,
|
||||
block_tables=block_table_post,
|
||||
num_decodes=num_decodes_post,
|
||||
num_prefills=num_prefills_post,
|
||||
num_decode_tokens=num_decode_tokens_post,
|
||||
attn_mask=attn_mask_post,
|
||||
attn_state=attn_state_post,
|
||||
prefill=prefill_post,
|
||||
decode=decode_post,
|
||||
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
return [attention_metadata_pre, attention_metadata_post]
|
||||
56
vllm_ascend/ops/__init__.py
Normal file
56
vllm_ascend/ops/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
|
||||
import vllm_ascend.ops.common_fused_moe # noqa
|
||||
import vllm_ascend.ops.fused_moe # noqa
|
||||
import vllm_ascend.ops.layernorm # noqa
|
||||
import vllm_ascend.ops.vocab_parallel_embedding # noqa
|
||||
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
|
||||
from vllm_ascend.ops.rotary_embedding import (
|
||||
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
|
||||
|
||||
|
||||
class dummyFusionOp:
|
||||
default = None
|
||||
|
||||
def __init__(self, name=""):
|
||||
self.name = name
|
||||
|
||||
|
||||
def register_dummy_fusion_op() -> None:
|
||||
torch.ops._C.rms_norm = dummyFusionOp(name="rms_norm")
|
||||
torch.ops._C.fused_add_rms_norm = dummyFusionOp(name="fused_add_rms_norm")
|
||||
torch.ops._C.static_scaled_fp8_quant = dummyFusionOp(
|
||||
name="static_scaled_fp8_quant")
|
||||
torch.ops._C.dynamic_scaled_fp8_quant = dummyFusionOp(
|
||||
name="dynamic_scaled_fp8_quant")
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(
|
||||
name="dynamic_per_token_scaled_fp8_quant")
|
||||
torch.ops._C.rms_norm_static_fp8_quant = dummyFusionOp(
|
||||
name="rms_norm_static_fp8_quant")
|
||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(
|
||||
name="fused_add_rms_norm_static_fp8_quant")
|
||||
torch.ops._C.rms_norm_dynamic_per_token_quant = dummyFusionOp(
|
||||
name="rms_norm_dynamic_per_token_quant")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AscendQuickGELU", "AscendSiluAndMul", "AscendRotaryEmbedding",
|
||||
"AscendDeepseekScalingRotaryEmbedding"
|
||||
]
|
||||
42
vllm_ascend/ops/activation.py
Normal file
42
vllm_ascend/ops/activation.py
Normal file
@@ -0,0 +1,42 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
|
||||
|
||||
|
||||
class AscendQuickGELU(QuickGELU):
|
||||
|
||||
def forward_oot(self, x: torch.tensor) -> torch.Tensor:
|
||||
import torch_npu
|
||||
|
||||
out = torch_npu.npu_fast_gelu(x)
|
||||
return out
|
||||
|
||||
|
||||
class AscendSiluAndMul(SiluAndMul):
|
||||
|
||||
def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
|
||||
if is_310p():
|
||||
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
|
||||
else:
|
||||
out = torch_npu.npu_swiglu(x)
|
||||
return out
|
||||
309
vllm_ascend/ops/attention.py
Normal file
309
vllm_ascend/ops/attention.py
Normal file
@@ -0,0 +1,309 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/kernels/test_moe.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
|
||||
|
||||
# Implementation of vanilla chunked prefill, should be removed after the kernel is ready for
|
||||
# all the corner case
|
||||
def vanilla_chunked_prefill(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor, # (num_tokens, heads, head_size)
|
||||
key_cache: torch.Tensor, # (num_blocks, block_size, kv_heads, head_size)
|
||||
value_cache: torch.
|
||||
Tensor, # (num_blocks, block_size, kv_heads, head_size,)
|
||||
block_tables: torch.Tensor, # (num_seqs, max_num_blocks_per_seq)
|
||||
cu_seqlen_q: torch.Tensor, # (num_seqs + 1,)
|
||||
cu_seqlen_k: torch.Tensor, # (num_seqs + 1,)
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
causal: bool = True,
|
||||
) -> torch.Tensor:
|
||||
num_query_heads = query.shape[1]
|
||||
head_dim = value_cache.shape[3]
|
||||
num_kv_heads = value_cache.shape[2]
|
||||
block_size = value_cache.shape[1]
|
||||
num_batch = cu_seqlen_q.shape[0] - 1
|
||||
max_num_blocks_per_seq = block_tables.shape[1]
|
||||
|
||||
key = key_cache[block_tables].view(num_batch,
|
||||
max_num_blocks_per_seq * block_size,
|
||||
num_kv_heads, head_dim)
|
||||
|
||||
value = value_cache[block_tables].view(num_batch,
|
||||
max_num_blocks_per_seq * block_size,
|
||||
num_kv_heads, head_dim)
|
||||
key = key[:, :max_seqlen_k, :, :]
|
||||
value = value[:, :max_seqlen_k, :, :]
|
||||
|
||||
seqlen_k = cu_seqlen_k[1:] - cu_seqlen_k[:-1]
|
||||
seqlen_q = cu_seqlen_q[1:] - cu_seqlen_q[:-1]
|
||||
seqlen_q = seqlen_q.view(-1, 1)
|
||||
seqlen_k = seqlen_k.view(-1, 1)
|
||||
seqlen_diff = seqlen_k - seqlen_q
|
||||
q_idx_mask = (torch.arange(0, max_seqlen_q,
|
||||
device="npu").view(1, -1).repeat(num_batch, 1))
|
||||
k_idx_mask = (torch.arange(0, max_seqlen_k,
|
||||
device="npu").view(1, -1).repeat(num_batch, 1))
|
||||
q_mask = q_idx_mask < seqlen_q
|
||||
k_mask = k_idx_mask < seqlen_k
|
||||
|
||||
# calculate idx for causal mask of query [batch, max_seqlen_q]
|
||||
causal_mask_idx = (q_idx_mask + seqlen_diff)[q_mask]
|
||||
|
||||
# generate causal mask [batch, max_seqlen_q, max_seqlen_k]
|
||||
tril_mask = torch.tril(torch.ones(max_seqlen_k, max_seqlen_k,
|
||||
device="npu"))
|
||||
tril_mask[tril_mask == 0] = float("-inf")
|
||||
tril_mask[tril_mask == 1] = 0
|
||||
causal_mask = tril_mask[causal_mask_idx]
|
||||
causal_mask_padding = torch.empty([num_batch, max_seqlen_q, max_seqlen_k],
|
||||
device="npu").fill_(float("-inf"))
|
||||
causal_mask_padding[q_mask] = causal_mask
|
||||
# to [batch, num_heads, max_seqlen_q, max_seqlen_k]
|
||||
causal_mask_padding = causal_mask_padding.unsqueeze(1)
|
||||
|
||||
pad_q = torch.zeros(
|
||||
[num_batch, max_seqlen_q, num_query_heads, head_dim],
|
||||
device="npu",
|
||||
dtype=query.dtype,
|
||||
)
|
||||
pad_k = torch.zeros(
|
||||
[num_batch, max_seqlen_k, num_kv_heads, head_dim],
|
||||
device="npu",
|
||||
dtype=key.dtype,
|
||||
)
|
||||
pad_v = torch.zeros(
|
||||
[num_batch, max_seqlen_k, num_kv_heads, head_dim],
|
||||
device="npu",
|
||||
dtype=value.dtype,
|
||||
)
|
||||
pad_q[q_mask] = query
|
||||
pad_k[k_mask] = key[k_mask]
|
||||
pad_v[k_mask] = value[k_mask]
|
||||
|
||||
if num_query_heads > num_kv_heads:
|
||||
pad_k = pad_k.view(
|
||||
[num_batch, max_seqlen_k, num_kv_heads, 1, head_dim])
|
||||
pad_k = pad_k.repeat(1, 1, 1, num_query_heads // num_kv_heads, 1).view(
|
||||
[num_batch, max_seqlen_k, num_query_heads, head_dim])
|
||||
pad_v = pad_v.view(
|
||||
[num_batch, max_seqlen_k, num_kv_heads, 1, head_dim])
|
||||
pad_v = pad_v.repeat(1, 1, 1, num_query_heads // num_kv_heads, 1).view(
|
||||
[num_batch, max_seqlen_k, num_query_heads, head_dim])
|
||||
# permute to [b, h, n, k]
|
||||
pad_q = pad_q.permute(0, 2, 1, 3)
|
||||
pad_k = pad_k.permute(0, 2, 1, 3)
|
||||
pad_v = pad_v.permute(0, 2, 1, 3)
|
||||
attn_mask = torch.empty([num_batch, 1, 1, max_seqlen_k],
|
||||
device="npu").fill_(float("-inf"))
|
||||
attn_mask[:, :, :, :max_seqlen_k].masked_fill_(k_mask[:, None, None, :], 0)
|
||||
# [b, h, f, t]
|
||||
attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k)
|
||||
attn_weights *= scale
|
||||
attn_mask = attn_mask.float()
|
||||
attn_weights = attn_weights + attn_mask
|
||||
if causal:
|
||||
attn_weights = attn_weights + causal_mask_padding
|
||||
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1)
|
||||
attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float())
|
||||
attn_output = attn_output.permute(0, 2, 1, 3)
|
||||
|
||||
attn_output = (attn_output[q_mask].view([-1, num_query_heads,
|
||||
head_dim]).to(output.dtype))
|
||||
output.copy_(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
def vanilla_chunked_prefill_mla(
|
||||
output: torch.Tensor, # (num_tokens, num_heads, v_head_dim)
|
||||
query: torch.Tensor, # (num_tokens, num_heads, nope_dim + rope_dim)
|
||||
kv_cache: Tuple[
|
||||
torch.Tensor], # [nope, rope] (num_blocks, block_size, latent_kv)
|
||||
block_tables: torch.Tensor, # (batch_size, max_num_blocks_per_seq)
|
||||
query_lens: torch.Tensor, # (batch_size)
|
||||
context_lens: torch.Tensor, # (batch_size)
|
||||
kv_b_proj: ColumnParallelLinear, # ()
|
||||
max_query_len: int,
|
||||
max_context_len: int,
|
||||
nope_dim: int,
|
||||
rope_dim: int,
|
||||
v_head_dim: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
causal: bool = True) -> None:
|
||||
batch_size = block_tables.size(0)
|
||||
assert len(kv_cache) > 1
|
||||
assert query_lens.size(0) == batch_size
|
||||
num_heads = query.size(1)
|
||||
nope_cache = kv_cache[0]
|
||||
rope_cache = kv_cache[1]
|
||||
block_size = nope_cache.size(1)
|
||||
latent_kv_dim = nope_cache.size(-1)
|
||||
max_num_blocks_per_seq = block_tables.size(1)
|
||||
batch_size = query_lens.size(0)
|
||||
nope_cache = nope_cache.squeeze()
|
||||
# select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim] and get kv_c and k_pe
|
||||
# cached_kv_c: [batch_size, max_context_len, latent_kv]
|
||||
# cached_k_pe: [batch_size, max_context_len, rope_dim]
|
||||
cache_kv_c = nope_cache[block_tables].view(
|
||||
batch_size, max_num_blocks_per_seq * block_size,
|
||||
latent_kv_dim)[:, :max_context_len, :]
|
||||
cache_k_pe = rope_cache[block_tables].view(
|
||||
batch_size, max_num_blocks_per_seq * block_size,
|
||||
rope_dim)[:, :max_context_len, :]
|
||||
# get k_rope and v
|
||||
# k_nope: [batch_size, max_context_len, num_heads, nope_dim]
|
||||
# value: [batch_size, max_context_len, num_heads, v_head_dim]
|
||||
k_nope, value = kv_b_proj(cache_kv_c)[0].view(
|
||||
batch_size, max_context_len, num_heads,
|
||||
nope_dim + v_head_dim).split([nope_dim, v_head_dim], dim=-1)
|
||||
# key: [batch_size, max_context_len, num_hads, rope_dim + nope_dim]
|
||||
key = torch.cat(
|
||||
[k_nope, cache_k_pe.unsqueeze(2).expand(-1, -1, num_heads, -1)],
|
||||
dim=-1)
|
||||
|
||||
context_lens = context_lens.view(-1, 1).to("npu")
|
||||
query_lens = query_lens.view(-1, 1).to("npu")
|
||||
seq_diff = context_lens - query_lens
|
||||
|
||||
q_idx_mask = (torch.arange(0, max_query_len,
|
||||
device="npu").view(1, -1).repeat(batch_size, 1))
|
||||
kv_c_idx_mask = (torch.arange(0, max_context_len,
|
||||
device="npu").view(1,
|
||||
-1).repeat(batch_size, 1))
|
||||
kv_c_mask = kv_c_idx_mask < context_lens
|
||||
q_mask = q_idx_mask < query_lens
|
||||
|
||||
# calculate idx for causal mask of query [batch, max_seqlen_q]
|
||||
causal_mask_idx = (q_idx_mask + seq_diff)[q_mask]
|
||||
|
||||
# generate causal mask [batch, max_seqlen_q, max_seqlen_k]
|
||||
tril_mask = torch.tril(
|
||||
torch.ones(max_context_len, max_context_len, device="npu"))
|
||||
tril_mask[tril_mask == 0] = float("-inf")
|
||||
tril_mask[tril_mask == 1] = 0
|
||||
causal_mask = tril_mask[causal_mask_idx]
|
||||
causal_mask_padding = torch.empty(
|
||||
[batch_size, max_query_len, max_context_len],
|
||||
device="npu").fill_(float("-inf"))
|
||||
causal_mask_padding[q_mask] = causal_mask
|
||||
# to [batch, num_heads, max_seqlen_q, max_seqlen_k]
|
||||
causal_mask_padding = causal_mask_padding.unsqueeze(1)
|
||||
|
||||
pad_q = torch.zeros(
|
||||
[batch_size, max_query_len, num_heads, rope_dim + nope_dim],
|
||||
device="npu",
|
||||
dtype=query.dtype,
|
||||
)
|
||||
pad_k = torch.zeros(
|
||||
[batch_size, max_context_len, num_heads, rope_dim + nope_dim],
|
||||
device="npu",
|
||||
dtype=key.dtype,
|
||||
)
|
||||
pad_v = torch.zeros(
|
||||
[batch_size, max_context_len, num_heads, v_head_dim],
|
||||
device="npu",
|
||||
dtype=value.dtype,
|
||||
)
|
||||
num_query = torch.sum(q_mask).item()
|
||||
num_add_query = num_query - query.size(0)
|
||||
# mtp will come in
|
||||
if num_add_query > 0:
|
||||
add_query_size = query.size()
|
||||
add_query_size = list(add_query_size)
|
||||
add_query_size[0] = num_add_query
|
||||
pad_tensor = torch.zeros(add_query_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
query = torch.cat([query, pad_tensor], dim=0)
|
||||
pad_q[q_mask] = query
|
||||
pad_k[kv_c_mask] = key[kv_c_mask]
|
||||
pad_v[kv_c_mask] = value[kv_c_mask]
|
||||
|
||||
pad_q = pad_q.permute(0, 2, 1, 3)
|
||||
pad_k = pad_k.permute(0, 2, 1, 3)
|
||||
pad_v = pad_v.permute(0, 2, 1, 3)
|
||||
attn_mask = torch.empty([batch_size, 1, 1, max_context_len],
|
||||
device="npu").fill_(float("-inf"))
|
||||
attn_mask[:, :, :, :max_context_len].masked_fill_(
|
||||
kv_c_mask[:, None, None, :], 0)
|
||||
# [b, h, f, t]
|
||||
attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k)
|
||||
attn_weights *= scale
|
||||
attn_mask = attn_mask.float()
|
||||
attn_weights = attn_weights + attn_mask
|
||||
if causal:
|
||||
attn_weights = attn_weights + causal_mask_padding
|
||||
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1)
|
||||
attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float())
|
||||
attn_output = attn_output.permute(0, 2, 1, 3)
|
||||
|
||||
attn_output = (attn_output[q_mask].view([-1, num_heads,
|
||||
v_head_dim]).to(output.dtype))
|
||||
attn_output = attn_output.view_as(output)
|
||||
output.copy_(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
def vanilla_decode_mla(
|
||||
query: torch.Tensor, # [num_tokens, num_heads, latent_dim + rope_dim]
|
||||
key_cache: torch.
|
||||
Tensor, # [num_blocks, block_size, num_kv_heads, latent_dim + rope_dim]
|
||||
num_kv_heads: int,
|
||||
num_heads: int,
|
||||
scale: float,
|
||||
block_table: torch.Tensor, # [batch_size, max_block_size]
|
||||
context_lens: List[int],
|
||||
mla_vhead_size: int,
|
||||
rope_dim: int,
|
||||
output: torch.Tensor):
|
||||
batch_size = block_table.size()[0]
|
||||
max_block_size = block_table.size()[1]
|
||||
reduce_dim = key_cache.size()[-1]
|
||||
block_size = key_cache.size()[1]
|
||||
latent_dim = reduce_dim - rope_dim
|
||||
kv_c_and_pe = key_cache[block_table].view(
|
||||
[batch_size, max_block_size * block_size, num_kv_heads, reduce_dim])
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, device="npu").view(batch_size, 1)
|
||||
# [batch_size, max_context_len, num_kv_heads, latent_dim + rope_dim]
|
||||
# since the kv head is 1 in deepseek, we use expand here for perf
|
||||
kv_c_and_pe = kv_c_and_pe[:, :max_context_len, :, :].expand(
|
||||
-1, -1, num_heads, 1)
|
||||
kv_c = kv_c_and_pe[..., :latent_dim]
|
||||
kv_idx_mask = (torch.arange(0, max_context_len,
|
||||
device="npu").view(1,
|
||||
-1).repeat(batch_size, 1))
|
||||
# [batch_size, max_context_len]
|
||||
kv_idx_mask = kv_idx_mask < context_lens
|
||||
query = query.unsqueeze(1)
|
||||
attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, kv_c_and_pe)
|
||||
attn_weights *= scale
|
||||
attn_weights = attn_weights + kv_idx_mask[:, -1, -1, :].float()
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1)
|
||||
attn_output = torch.einsum("bhqk,bkhd->bqhd", attn_weights,
|
||||
kv_c.float()).view(-1, num_heads, latent_dim)
|
||||
output.copy_(attn_output)
|
||||
return output
|
||||
62
vllm_ascend/ops/comm_utils.py
Normal file
62
vllm_ascend/ops/comm_utils.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
|
||||
COMM_STREAM = None
|
||||
|
||||
|
||||
def async_all_to_all(input_,
|
||||
output_split_sizes,
|
||||
input_split_sizes,
|
||||
group,
|
||||
event=None):
|
||||
if output_split_sizes is None:
|
||||
# Equal split (all2all)
|
||||
a2a_out = torch.empty_like(input_)
|
||||
else:
|
||||
# Unequal split (all2all-v)
|
||||
a2a_out = input_.new_empty(
|
||||
size=[sum(output_split_sizes)] + list(input_.size()[1:]),
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device(),
|
||||
)
|
||||
|
||||
if event:
|
||||
# multi stream wait event
|
||||
global COMM_STREAM
|
||||
if COMM_STREAM is None:
|
||||
COMM_STREAM = torch_npu.npu.Stream(
|
||||
device=torch.npu.current_device())
|
||||
with torch_npu.npu.stream(COMM_STREAM):
|
||||
event.wait()
|
||||
handle = dist.all_to_all_single(
|
||||
a2a_out,
|
||||
input_.contiguous(),
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=True)
|
||||
else:
|
||||
handle = dist.all_to_all_single(a2a_out,
|
||||
input_.contiguous(),
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=True)
|
||||
return input_, a2a_out, handle
|
||||
531
vllm_ascend/ops/common_fused_moe.py
Normal file
531
vllm_ascend/ops/common_fused_moe.py
Normal file
@@ -0,0 +1,531 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import CompilationLevel, get_current_vllm_config
|
||||
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import \
|
||||
FusedMoEParallelConfig # isort: skip
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod)
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
|
||||
AlltoAllCommImpl,
|
||||
MC2CommImpl)
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
||||
setup_token_dispatchers
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, vllm_version_is
|
||||
|
||||
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
||||
|
||||
|
||||
def fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
global_num_experts: Optional[int] = None,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
# For TorchAir graph
|
||||
is_torchair: bool = False,
|
||||
# For Cube/Vector parallel
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
# For load balance
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
) -> torch.Tensor:
|
||||
# Check constraints
|
||||
assert hidden_states.shape[1] == w1.shape[1], (
|
||||
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}")
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32, torch.float16, torch.bfloat16
|
||||
]
|
||||
if (use_int8_w8a8 or use_int4_w4a8):
|
||||
assert w1_scale is not None and w2_scale is not None, \
|
||||
"INT8 quantization requires weight scales."
|
||||
|
||||
w1_scale = w1_scale.to(torch.float32)
|
||||
down_scale = [w2_scale]
|
||||
down_output_dtype = w2_scale.dtype
|
||||
else:
|
||||
down_scale = None
|
||||
down_output_dtype = None
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
assert moe_comm_method is not None, "Missing communication context"
|
||||
|
||||
num_experts = w1.shape[0]
|
||||
|
||||
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = moe_comm_method.permute(
|
||||
hidden_states, topk_ids, topk_weights, expert_map, num_experts,
|
||||
use_int8_w8a8 or use_int4_w4a8)
|
||||
|
||||
gate_up_output = torch_npu.npu_grouped_matmul(
|
||||
x=[permuted_hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=expert_tokens,
|
||||
output_dtype=torch.int32 if use_int8_w8a8 else None,
|
||||
)[0]
|
||||
|
||||
if (use_int8_w8a8 or use_int4_w4a8):
|
||||
activated_output, activated_output_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=gate_up_output,
|
||||
weight_scale=w1_scale,
|
||||
activation_scale=dynamic_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=expert_tokens,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
activated_output_scale = [activated_output_scale]
|
||||
else:
|
||||
activated_output = torch_npu.npu_swiglu(gate_up_output)
|
||||
activated_output_scale = None
|
||||
|
||||
down_output = torch_npu.npu_grouped_matmul(
|
||||
x=[activated_output],
|
||||
weight=[w2],
|
||||
scale=down_scale,
|
||||
per_token_scale=activated_output_scale,
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=expert_tokens,
|
||||
output_dtype=down_output_dtype,
|
||||
)[0]
|
||||
|
||||
moe_comm_method.unpermute(down_output, hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def fused_experts_moge(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
||||
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
||||
topk_weights: Routing weights of shape (num_tokens, top_k).
|
||||
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
||||
top_k: Number of experts to select.
|
||||
expert_map: Expert mapping of shape (num_experts,).
|
||||
|
||||
Returns:
|
||||
hidden_states: Hidden states after routing.
|
||||
"""
|
||||
ep_size = moe_parallel_config.ep_size
|
||||
local_num_experts = global_num_experts // ep_size
|
||||
local_num_group = top_k // ep_size
|
||||
|
||||
bsz, _ = hidden_states.shape
|
||||
flatten_topk_ids = topk_ids.view(-1)
|
||||
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
||||
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
|
||||
sorted_hidden_states = hidden_states.index_select(
|
||||
0, sorted_topk_ids // local_num_group)
|
||||
|
||||
experts_id = torch.arange(0,
|
||||
local_num_experts,
|
||||
dtype=topk_ids.dtype,
|
||||
device=topk_ids.device)
|
||||
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
||||
torch.float32).sum(0)
|
||||
topk_scales = topk_weights.view(-1).index_select(
|
||||
0, sorted_topk_ids).unsqueeze(-1)
|
||||
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
||||
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||
x=[sorted_hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
if is_310p():
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
||||
torch.float16)
|
||||
else:
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
gate_up_out *= topk_scales
|
||||
|
||||
down_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
|
||||
unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
|
||||
final_hidden_states = unsorted_hidden_states.reshape(
|
||||
bsz, top_k // ep_size, -1).sum(1)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def unquantized_fused_moe_init_func(self, *args, **kwargs):
|
||||
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
|
||||
|
||||
# NOTE: Currently, this self.use_aclgraph is only used in
|
||||
# UnquantizedFusedMoEMethod.forward_oot to decide whether to use in
|
||||
# ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue.
|
||||
# Once torch.randint_like is supported or removed, this flag can be removed.
|
||||
vllm_config = get_current_vllm_config()
|
||||
ascend_config = get_ascend_config()
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
self.use_aclgraph = False
|
||||
else:
|
||||
self.use_aclgraph = (vllm_config.compilation_config.level
|
||||
== CompilationLevel.PIECEWISE
|
||||
and not vllm_config.model_config.enforce_eager)
|
||||
|
||||
|
||||
def forward_oot_v01011(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
topk_weights, topk_ids, _ = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
if topk_ids.shape[1] < top_k or is_310p():
|
||||
assert global_num_experts is not None
|
||||
return fused_experts_moge(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
moe_parallel_config=self.moe.moe_parallel_config,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
topk_weights, topk_ids, _ = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
if topk_ids.shape[1] < top_k or is_310p():
|
||||
assert global_num_experts is not None
|
||||
return fused_experts_moge(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
moe_parallel_config=self.moe.moe_parallel_config,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
|
||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
|
||||
|
||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||
|
||||
if not is_310p():
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
|
||||
class AscendFusedMoE(FusedMoE):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_experts,
|
||||
top_k,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
params_dtype=None,
|
||||
reduce_results=False,
|
||||
renormalize=True,
|
||||
use_grouped_topk=False,
|
||||
num_expert_group=None,
|
||||
topk_group=None,
|
||||
quant_config=None,
|
||||
tp_size=None,
|
||||
ep_size=None,
|
||||
dp_size=None,
|
||||
prefix="",
|
||||
custom_routing_function=None,
|
||||
scoring_func="softmax",
|
||||
routed_scaling_fator: float = 1.0,
|
||||
e_score_correction_bias=None,
|
||||
apply_router_weight_on_input=False,
|
||||
activation="silu",
|
||||
enable_eplb=False,
|
||||
num_redundant_experts=0,
|
||||
has_bias=False,
|
||||
):
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
super().__init__(
|
||||
num_experts,
|
||||
top_k,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
params_dtype,
|
||||
reduce_results,
|
||||
renormalize,
|
||||
use_grouped_topk,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
quant_config,
|
||||
tp_size,
|
||||
ep_size,
|
||||
dp_size,
|
||||
prefix,
|
||||
custom_routing_function,
|
||||
scoring_func,
|
||||
e_score_correction_bias,
|
||||
apply_router_weight_on_input,
|
||||
activation,
|
||||
enable_eplb,
|
||||
num_redundant_experts,
|
||||
has_bias,
|
||||
)
|
||||
else:
|
||||
super().__init__(
|
||||
num_experts,
|
||||
top_k,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
params_dtype,
|
||||
reduce_results,
|
||||
renormalize,
|
||||
use_grouped_topk,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
quant_config,
|
||||
tp_size,
|
||||
ep_size,
|
||||
dp_size,
|
||||
prefix,
|
||||
custom_routing_function,
|
||||
scoring_func,
|
||||
routed_scaling_fator,
|
||||
e_score_correction_bias,
|
||||
apply_router_weight_on_input,
|
||||
activation,
|
||||
enable_eplb,
|
||||
num_redundant_experts,
|
||||
has_bias,
|
||||
)
|
||||
|
||||
setup_token_dispatchers(self.moe_config.ep_size,
|
||||
top_k=self.top_k,
|
||||
num_experts=self.global_num_experts,
|
||||
num_local_experts=self.local_num_experts)
|
||||
|
||||
self.moe_config.tp_group = get_tp_group()
|
||||
self.moe_config.dp_group = get_dp_group()
|
||||
self.moe_config.ep_group = get_ep_group()
|
||||
self.moe_config.mc2_group = get_mc2_group()
|
||||
|
||||
for method in {AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl}:
|
||||
setattr(
|
||||
self, method.__name__.lower(),
|
||||
method(moe_config=self.moe_config)) # type: ignore[abstract]
|
||||
|
||||
def forward_impl(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
assert self.quant_method is not None
|
||||
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_method_name = forward_context.moe_comm_method_name
|
||||
|
||||
# TODO: Can we refactor this logic to model_runner?
|
||||
# TODO: Adjusted logic to differentiate between A2 and A3, we check ep_size here since mc2 only support ep_size >= 16 on A3 now
|
||||
if self.moe_config.ep_size < 16:
|
||||
moe_comm_method_name = "allgathercommimpl"
|
||||
|
||||
forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
|
||||
|
||||
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
||||
hidden_states=hidden_states, router_logits=router_logits)
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
enable_eplb=self.enable_eplb,
|
||||
expert_load_view=self.expert_load_view,
|
||||
logical_to_physical_map=self.logical_to_physical_map,
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
)
|
||||
|
||||
final_hidden_states = forward_context.moe_comm_method.finalize(
|
||||
hidden_states=final_hidden_states,
|
||||
reduce_results=self.reduce_results)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
|
||||
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
UnquantizedFusedMoEMethod.forward_oot = forward_oot_v01011
|
||||
else:
|
||||
UnquantizedFusedMoEMethod.forward_oot = forward_oot
|
||||
99
vllm_ascend/ops/expert_load_balancer.py
Normal file
99
vllm_ascend/ops/expert_load_balancer.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import json
|
||||
import random
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class ExpertLoadBalancer(object):
|
||||
|
||||
def __init__(self, expert_map_path, global_expert_num):
|
||||
self.expert_map_path = expert_map_path
|
||||
self.global_expert_num = global_expert_num
|
||||
self.expert_map_tensor, self.layers_num, self.ranks_num = (
|
||||
self._expert_file_to_tensor())
|
||||
|
||||
def _expert_file_to_tensor(self):
|
||||
with open(self.expert_map_path, "r") as f:
|
||||
data = json.load(f)
|
||||
layers_num = data["moe_layer_count"]
|
||||
gpus_num = data["layer_list"][0]["device_count"]
|
||||
|
||||
tensor_data = []
|
||||
for layer in data["layer_list"]:
|
||||
device_data = []
|
||||
for device in layer["device_list"]:
|
||||
device_data.append(device["device_expert"])
|
||||
tensor_data.append(device_data)
|
||||
expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32)
|
||||
return expert_map_tensor, layers_num, gpus_num
|
||||
|
||||
def generate_index_dicts(self, tensor_2d):
|
||||
dict_list = []
|
||||
current_idx = 0
|
||||
|
||||
for row in tensor_2d:
|
||||
value_to_index = {}
|
||||
for i in range(row.size(0)):
|
||||
value = row[i].item()
|
||||
value_to_index[value] = current_idx + i
|
||||
dict_list.append(value_to_index)
|
||||
current_idx += row.size(0)
|
||||
|
||||
return dict_list
|
||||
|
||||
def generate_expert_placement_map(self):
|
||||
expert_placement_map = torch.full(
|
||||
(self.layers_num, self.ranks_num, self.global_expert_num),
|
||||
-1,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
for layer_id in range(self.layers_num):
|
||||
for gpu_id in range(self.ranks_num):
|
||||
e_ids = self.expert_map_tensor[layer_id, gpu_id]
|
||||
expert_placement_map[layer_id, gpu_id,
|
||||
e_ids] = torch.arange(len(e_ids),
|
||||
dtype=torch.int32)
|
||||
return expert_placement_map
|
||||
|
||||
def generate_log2phy_expert_map(self, layer_id):
|
||||
concatenated = torch.flatten(self.expert_map_tensor[layer_id])
|
||||
rank_expert_to_global = self.generate_index_dicts(
|
||||
self.expert_map_tensor[layer_id])
|
||||
result_dict: Dict[int, List[int]] = {}
|
||||
for idx, value in enumerate(concatenated):
|
||||
key = value.item()
|
||||
if key not in result_dict:
|
||||
result_dict[key] = []
|
||||
result_dict[key].append(idx)
|
||||
|
||||
log2phy_map = torch.full((self.ranks_num, self.global_expert_num),
|
||||
-1,
|
||||
dtype=torch.int32)
|
||||
for rank in range(self.ranks_num):
|
||||
for key in result_dict:
|
||||
indices_in_concat = result_dict[key]
|
||||
if key in rank_expert_to_global[rank]:
|
||||
log2phy_map[rank][key] = rank_expert_to_global[rank][key]
|
||||
else:
|
||||
chosen_index = random.choice(indices_in_concat)
|
||||
log2phy_map[rank][key] = chosen_index
|
||||
return log2phy_map
|
||||
|
||||
def get_rank_placement_map(self, layer_id, rank_id):
|
||||
expert_placement_map = self.generate_expert_placement_map()
|
||||
layer_expert_map = expert_placement_map[layer_id]
|
||||
rank_expert_map = layer_expert_map[rank_id].to(
|
||||
torch.npu.current_device())
|
||||
rank_local_expert_num = torch.sum(torch.ne(rank_expert_map, -1)).item()
|
||||
return rank_local_expert_num, rank_expert_map
|
||||
|
||||
def get_rank_log2phy_map(self, layer_id, rank_id):
|
||||
layer_log2phy_map = self.generate_log2phy_expert_map(layer_id)
|
||||
return layer_log2phy_map[rank_id]
|
||||
|
||||
def get_global_redundant_expert_num(self):
|
||||
global_redundant_expert_num = (
|
||||
len(self.expert_map_tensor[0][0]) * self.ranks_num -
|
||||
self.global_expert_num)
|
||||
return global_redundant_expert_num
|
||||
587
vllm_ascend/ops/fused_moe.py
Normal file
587
vllm_ascend/ops/fused_moe.py
Normal file
@@ -0,0 +1,587 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/kernels/test_moe.py
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
||||
get_tp_group)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import \
|
||||
FusedMoEConfig # isort: skip
|
||||
from vllm.model_executor.layers.fused_moe.config import \
|
||||
FusedMoEParallelConfig # isort: skip
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizationConfig
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.distributed.communication_op import \
|
||||
data_parallel_reduce_scatter
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dispose_tensor,
|
||||
get_all_reduce_merge_state,
|
||||
get_rm_router_logits_state, is_310p)
|
||||
|
||||
|
||||
def unified_fused_experts_eager(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w1_scale_bias: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale_bias: Optional[torch.Tensor] = None,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[Any] = None,
|
||||
shared_dequant_scale: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
token_dispatcher = get_forward_context().token_dispatcher
|
||||
|
||||
results = token_dispatcher.token_dispatch(
|
||||
hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
shared_experts=shared_experts,
|
||||
shared_gate_up=shared_gate_up,
|
||||
shared_dequant_scale=shared_dequant_scale,
|
||||
mc2_mask=mc2_mask,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
with_quant=with_quant)
|
||||
|
||||
expert_output = unified_apply_mlp(
|
||||
hidden_states=results["hidden_states"],
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=results["group_list"],
|
||||
dynamic_scale=results.get("dynamic_scale"),
|
||||
group_list_type=results.get("group_list_type"),
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias,
|
||||
topk_scales=results.get("topk_scales"),
|
||||
with_quant=with_quant)
|
||||
final_hidden_states = token_dispatcher.token_combine(expert_output)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
|
||||
def __init__(self, moe: FusedMoEConfig = None):
|
||||
|
||||
super().__init__(moe=moe)
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
get_ascend_config()
|
||||
|
||||
try:
|
||||
device_group = get_mc2_group().device_group
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=device_group)
|
||||
backend = device_group._get_backend(torch.device("npu"))
|
||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
||||
local_rank)
|
||||
except AttributeError:
|
||||
self.moe_all_to_all_group_name = None
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
super(UnquantizedFusedMoEMethod,
|
||||
self).process_weights_after_loading(layer)
|
||||
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
||||
layer.w13_weight.data),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
||||
layer.w2_weight.data),
|
||||
requires_grad=False)
|
||||
if not is_310p():
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = False,
|
||||
enable_force_load_balance: bool = False,
|
||||
shared_experts: Optional[Any] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
topk_weights, topk_ids, row_idx = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
# this is a naive implementation for experts load balance so as
|
||||
# to avoid accumulating too much tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
if enable_force_load_balance and not self.use_aclgraph:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
return unified_fused_experts_eager(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
expert_map=expert_map,
|
||||
shared_experts=shared_experts,
|
||||
mc2_mask=kwargs.get(
|
||||
"mc2_mask", None),
|
||||
with_quant=False)
|
||||
|
||||
|
||||
class AscendFusedMoE(FusedMoE):
|
||||
|
||||
# The moe_counter parameter is required during the initialization of EPLB
|
||||
# to identify the current layer index within the MOE model.
|
||||
moe_counter = -1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int, # Global number of experts
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = False,
|
||||
renormalize: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
ep_size: Optional[int] = None,
|
||||
dp_size: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
):
|
||||
# TODO: This could not initialize FusedMoE baseclass,
|
||||
# fixme and make __init__() of AscendFusedMoE more clear
|
||||
super().__init__(
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
params_dtype=params_dtype,
|
||||
reduce_results=reduce_results,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
ep_size=ep_size,
|
||||
dp_size=dp_size,
|
||||
prefix=prefix,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
AscendFusedMoE.moe_counter += 1
|
||||
self.moe_instance_id = AscendFusedMoE.moe_counter
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
self.moe_parallel_config = FusedMoEParallelConfig.make(
|
||||
tp_size_=(tp_size if tp_size is not None else
|
||||
get_tensor_model_parallel_world_size()),
|
||||
dp_size_=(dp_size
|
||||
if dp_size is not None else get_dp_group().world_size),
|
||||
vllm_parallel_config=vllm_config.parallel_config)
|
||||
|
||||
self.top_k = top_k
|
||||
self.num_experts = num_experts
|
||||
self.global_num_experts = num_experts
|
||||
assert intermediate_size % self.tp_size == 0
|
||||
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
||||
self.reduce_results = reduce_results
|
||||
self.renormalize = renormalize
|
||||
self.use_grouped_topk = use_grouped_topk
|
||||
if self.use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
self.num_expert_group = num_expert_group
|
||||
self.topk_group = topk_group
|
||||
self.custom_routing_function = custom_routing_function
|
||||
self.scoring_func = scoring_func
|
||||
self.e_score_correction_bias = e_score_correction_bias
|
||||
self.expert_map = None
|
||||
self.activation = activation
|
||||
self.log2phy = None
|
||||
self.global_redundant_expert_num = 0
|
||||
|
||||
is_deepseek_v3_r1 = self.global_num_experts == 256
|
||||
self.rm_router_logits = get_rm_router_logits_state(
|
||||
self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)
|
||||
self.all_reduce_merge = get_all_reduce_merge_state(
|
||||
self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
expert_map_path = ascend_config.expert_map_path
|
||||
if expert_map_path and os.path.exists(expert_map_path):
|
||||
# moe expert load balance
|
||||
expert_load_balancer = ExpertLoadBalancer(expert_map_path,
|
||||
self.global_num_experts)
|
||||
self.local_num_experts, self.expert_map = \
|
||||
expert_load_balancer.get_rank_placement_map(
|
||||
self.moe_instance_id,
|
||||
get_ep_group().rank_in_group)
|
||||
self.log2phy = expert_load_balancer.get_rank_log2phy_map(
|
||||
self.moe_instance_id,
|
||||
get_ep_group().rank_in_group)
|
||||
self.global_redundant_expert_num = \
|
||||
expert_load_balancer.get_global_redundant_expert_num()
|
||||
else:
|
||||
# Create a tensor of size num_experts filled with -1
|
||||
self.local_num_experts, self.expert_map = determine_expert_map(
|
||||
self.ep_size,
|
||||
get_ep_group().rank_in_group, self.global_num_experts)
|
||||
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
|
||||
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
||||
raise ValueError("Only softmax scoring function is supported for "
|
||||
"non-grouped topk.")
|
||||
moe = FusedMoEConfig.make(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
# TODO (bnell): this needs to be fixed for quantized types.
|
||||
in_dtype=params_dtype,
|
||||
quant_config=quant_config)
|
||||
|
||||
self.moe_config = moe
|
||||
|
||||
if quant_config is None:
|
||||
self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||
|
||||
assert self.quant_method is not None
|
||||
|
||||
local_num_experts = torch.sum(self.expert_map != -1) \
|
||||
if self.expert_map is not None else num_experts
|
||||
|
||||
moe_quant_params = {
|
||||
"num_experts": local_num_experts,
|
||||
"hidden_size": hidden_size,
|
||||
"intermediate_size_per_partition":
|
||||
self.intermediate_size_per_partition,
|
||||
"params_dtype": params_dtype,
|
||||
"weight_loader": self.weight_loader,
|
||||
}
|
||||
# need full intermediate size pre-sharding for WNA16 act order
|
||||
if (self.quant_method.__class__.__name__
|
||||
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||
|
||||
self.ep_group = get_ep_group()
|
||||
# NOTE: self.tp_group is not expert_tp_group
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
self.token_dispatcher = None
|
||||
|
||||
ep_size = (get_ep_group().world_size if
|
||||
vllm_config.parallel_config.enable_expert_parallel else 1)
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
||||
setup_token_dispatchers
|
||||
setup_token_dispatchers(
|
||||
ep_size,
|
||||
top_k=self.top_k,
|
||||
num_experts=self.global_num_experts,
|
||||
num_global_redundant_experts=self.global_redundant_expert_num,
|
||||
num_local_experts=self.local_num_experts)
|
||||
|
||||
def naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
assert (len(x.shape) == 2)
|
||||
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
||||
buffer[start:end, :].copy_(x)
|
||||
for idx in range(self.dp_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
||||
end = cu_tokens_across_dp_cpu[idx]
|
||||
get_dp_group().broadcast(buffer[start:end, :], idx)
|
||||
return buffer
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_prefill: bool,
|
||||
enable_force_load_balance: bool = False,
|
||||
top_k: Optional[int] = None,
|
||||
shared_experts: Optional[Any] = None,
|
||||
gate=None,
|
||||
replace_allreduce: bool = False,
|
||||
_metadata_for_padding: Optional[MetadataForPadding] = None):
|
||||
|
||||
assert self.quant_method is not None
|
||||
|
||||
if top_k:
|
||||
real_top_k = top_k
|
||||
else:
|
||||
real_top_k = self.top_k
|
||||
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
|
||||
forward_context = get_forward_context()
|
||||
fused_moe_state = forward_context.fused_moe_state
|
||||
mc2_mask = forward_context.mc2_mask
|
||||
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
|
||||
quantized_x_for_share, dynamic_scale_for_share = None, None
|
||||
|
||||
if shared_experts:
|
||||
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
||||
shared_hidden_states = shared_experts(hidden_states)
|
||||
|
||||
mc2_mask = forward_context.mc2_mask
|
||||
|
||||
enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if enable_sp:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask
|
||||
chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
|
||||
mc2_mask = chunk_mc2_mask[tp_rank]
|
||||
replace_allreduce = True
|
||||
|
||||
if (fused_moe_state not in [
|
||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||
FusedMoEState.NaiveMulticast
|
||||
] and not replace_allreduce):
|
||||
if fused_moe_state in {FusedMoEState.MC2}:
|
||||
padding_size = forward_context.padded_num_tokens
|
||||
else:
|
||||
# TODO: Determine if we can remove the padding
|
||||
padding_size = tp_size
|
||||
if num_tokens < padding_size and not self.enable_shared_expert_dp:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states, (0, 0, 0, padding_size - num_tokens))
|
||||
router_logits = nn.functional.pad(
|
||||
router_logits, (0, 0, 0, padding_size - num_tokens))
|
||||
if tp_size > 1:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
if not self.enable_shared_expert_dp:
|
||||
chunk_hidden_states = torch.tensor_split(hidden_states,
|
||||
tp_size,
|
||||
dim=0)
|
||||
chunk_router_logits = torch.tensor_split(router_logits,
|
||||
tp_size,
|
||||
dim=0)
|
||||
hidden_states = chunk_hidden_states[tp_rank]
|
||||
router_logits = chunk_router_logits[tp_rank]
|
||||
|
||||
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
|
||||
mc2_mask = chunk_mc2_mask[tp_rank]
|
||||
|
||||
if self.dp_size > 1:
|
||||
if fused_moe_state == FusedMoEState.AllGather:
|
||||
# NOTE: When in torchair graph, it has been padded in model_runner_v1
|
||||
max_tokens_across_dp = forward_context.max_tokens_across_dp
|
||||
if num_tokens < max_tokens_across_dp:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states,
|
||||
(0, 0, 0, max_tokens_across_dp - num_tokens))
|
||||
if not self.rm_router_logits:
|
||||
router_logits = nn.functional.pad(
|
||||
router_logits,
|
||||
(0, 0, 0, max_tokens_across_dp - num_tokens))
|
||||
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
||||
if self.rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states)
|
||||
else:
|
||||
router_logits = get_dp_group().all_gather(router_logits, 0)
|
||||
|
||||
elif fused_moe_state == FusedMoEState.NaiveMulticast:
|
||||
cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_dp_cpu
|
||||
hidden_states = self.naive_multicast(hidden_states,
|
||||
cu_tokens_across_dp_cpu)
|
||||
if self.rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states)
|
||||
else:
|
||||
router_logits = self.naive_multicast(
|
||||
router_logits, cu_tokens_across_dp_cpu)
|
||||
|
||||
# Matrix multiply.
|
||||
e_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=real_top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
is_prefill=is_prefill,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
log2phy=self.log2phy,
|
||||
global_redundant_expert_num=self.global_redundant_expert_num,
|
||||
shared_experts=None,
|
||||
mc2_mask=mc2_mask,
|
||||
token_dispatcher=self.token_dispatcher,
|
||||
quantized_x_for_share=quantized_x_for_share,
|
||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||
)
|
||||
|
||||
if shared_experts:
|
||||
if isinstance(e_hidden_states, tuple):
|
||||
e_hidden_states, shared_hidden_states = e_hidden_states
|
||||
|
||||
if (fused_moe_state not in [
|
||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||
FusedMoEState.NaiveMulticast
|
||||
] and not replace_allreduce and not self.enable_shared_expert_dp):
|
||||
if tp_size > 1:
|
||||
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
||||
self.tp_group)
|
||||
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
dispose_tensor(e_hidden_states)
|
||||
else:
|
||||
final_hidden_states = e_hidden_states
|
||||
if num_tokens < padding_size:
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
elif self.dp_size > 1 and not self.enable_shared_expert_dp:
|
||||
if fused_moe_state == FusedMoEState.NaiveMulticast:
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
||||
final_hidden_states = get_dp_group().all_reduce(
|
||||
e_hidden_states)
|
||||
final_hidden_states = final_hidden_states[start:end, :]
|
||||
dispose_tensor(e_hidden_states)
|
||||
elif fused_moe_state == FusedMoEState.AllGather:
|
||||
final_hidden_states = data_parallel_reduce_scatter(
|
||||
e_hidden_states, dim=0)
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
dispose_tensor(e_hidden_states)
|
||||
else:
|
||||
final_hidden_states = e_hidden_states
|
||||
else:
|
||||
final_hidden_states = e_hidden_states
|
||||
|
||||
if tp_size > 1 and not self.all_reduce_merge and fused_moe_state in [
|
||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||
FusedMoEState.NaiveMulticast
|
||||
]:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
if shared_experts:
|
||||
return final_hidden_states, shared_hidden_states
|
||||
else:
|
||||
return final_hidden_states
|
||||
|
||||
# ----------------------------------------- TBO-related --------------------------------------------
|
||||
|
||||
def _forward_ms_fused_moe_comp(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_prefill: bool,
|
||||
real_top_k,
|
||||
enable_force_load_balance: bool = False,
|
||||
):
|
||||
hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=real_top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
is_prefill=is_prefill,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
85
vllm_ascend/ops/layernorm.py
Normal file
85
vllm_ascend/ops/layernorm.py
Normal file
@@ -0,0 +1,85 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
|
||||
class AddRMSNormW8A8Quant(RMSNorm):
|
||||
# Fuse AddRmsNorm and W8A8 quantization ops together
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
layer: torch.nn.Module,
|
||||
eps: float = 1e-6,
|
||||
var_hidden_size: Optional[int] = None,
|
||||
has_weight: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
|
||||
self.layer = layer
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
import torch_npu
|
||||
|
||||
if residual is not None:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||
x,
|
||||
residual,
|
||||
self.weight,
|
||||
self.layer.aclnn_input_scale,
|
||||
self.layer.aclnn_input_offset,
|
||||
epsilon=self.variance_epsilon)
|
||||
return x, residual
|
||||
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
return x
|
||||
|
||||
|
||||
class AscendRMSNorm(RMSNorm):
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
if residual is not None:
|
||||
if is_310p():
|
||||
orig_dtype = residual.dtype
|
||||
x = x + residual.to(x.dtype)
|
||||
residual = x.to(orig_dtype)
|
||||
x, _ = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||
x, residual, self.weight, self.variance_epsilon)
|
||||
return x, residual
|
||||
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
return x
|
||||
0
vllm_ascend/ops/layers/__init__.py
Normal file
0
vllm_ascend/ops/layers/__init__.py
Normal file
283
vllm_ascend/ops/layers/experts_selector.py
Normal file
283
vllm_ascend/ops/layers/experts_selector.py
Normal file
@@ -0,0 +1,283 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
|
||||
def return_row_idx(hidden_states, top_k):
|
||||
num_tokens = hidden_states.shape[0]
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = (torch.arange(0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device).view(
|
||||
top_k, -1).permute(1, 0).contiguous())
|
||||
return row_idx
|
||||
|
||||
|
||||
def select_experts(hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
indices_type: Optional[torch.dtype] = None,
|
||||
global_num_experts: int = -1):
|
||||
"""
|
||||
Fused experts with select experts.
|
||||
|
||||
Args:
|
||||
router_logits: router logits of shape (num_tokens, hidden_size).
|
||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||
top_k: number of top k experts.
|
||||
use_grouped_topk: Whether to group experts before selecting top-k.
|
||||
renormalize: Whether to renormalize the routing weights.
|
||||
topk_group: Number of expert groups to select from.
|
||||
num_expert_group: Number of experts in each group.
|
||||
custom_routing_function: Custom routing function.
|
||||
scoring_func: Scoring function to use.
|
||||
e_score_correction_bias: Correction bias to apply to expert scores.
|
||||
indices_type: dtype of indices
|
||||
global_num_experts: Global number of experts.
|
||||
|
||||
Returns:
|
||||
topk_weights: router weights of shape (num_tokens, top_k).
|
||||
topk_ids: selected expert IDs of shape (num_tokens, top_k).
|
||||
"""
|
||||
|
||||
topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
renormalize=renormalize,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
if topk_weights is None:
|
||||
topk_weights, topk_ids = _native_select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts,
|
||||
)
|
||||
if row_idx is None:
|
||||
row_idx = return_row_idx(hidden_states, top_k)
|
||||
return topk_weights, topk_ids, row_idx
|
||||
|
||||
|
||||
def _native_grouped_topk(
|
||||
topk_weights: torch.Tensor,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
):
|
||||
topk_group = 0 if topk_group is None else topk_group
|
||||
num_expert_group = 0 if num_expert_group is None else num_expert_group
|
||||
|
||||
num_token = topk_weights.shape[0]
|
||||
grouped_weights = topk_weights.view(num_token, num_expert_group,
|
||||
-1).max(dim=-1).values
|
||||
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
|
||||
k=topk_group,
|
||||
dim=-1,
|
||||
sorted=False)[1]
|
||||
topk_group_mask = torch.zeros_like(grouped_weights)
|
||||
topk_group_mask.scatter_(1, topk_group_indices, 1)
|
||||
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
|
||||
num_token, num_expert_group,
|
||||
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
|
||||
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
|
||||
|
||||
return topk_weights
|
||||
|
||||
|
||||
def _renormalize_topk_weights(
|
||||
topk_weights: torch.Tensor,
|
||||
renormalize: bool,
|
||||
):
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
return topk_weights
|
||||
|
||||
|
||||
def _select_expert_use_group_topk(
|
||||
topk_weights: torch.Tensor, topk_group: Optional[int],
|
||||
renormalize: bool, top_k: int, num_expert_group: Optional[int],
|
||||
e_score_correction_bias: Optional[torch.Tensor]):
|
||||
assert topk_group is not None
|
||||
assert num_expert_group is not None
|
||||
|
||||
if e_score_correction_bias is not None:
|
||||
# Store original scores before applying correction bias. We use biased
|
||||
# scores for expert selection but original scores for routing weights
|
||||
original_weights = topk_weights
|
||||
topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
|
||||
|
||||
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
|
||||
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
|
||||
topk_weights = _native_grouped_topk(topk_weights, num_expert_group,
|
||||
topk_group)
|
||||
# TODO bfloat16 is not supported in torch.topk with ge graph.
|
||||
if e_score_correction_bias is not None:
|
||||
topk_ids = torch.topk(topk_weights.to(torch.float32),
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)[1]
|
||||
# Use original unbiased scores for the routing weights
|
||||
topk_weights = original_weights.gather(1, topk_ids)
|
||||
else:
|
||||
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
def _select_experts_with_fusion_ops(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
e_score_correction_bias: Optional[torch.Tensor],
|
||||
topk_group: Optional[int],
|
||||
num_expert_group: Optional[int],
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor=1.0,
|
||||
global_num_experts: int = -1):
|
||||
|
||||
topk_weights, topk_ids, row_idx = None, None, None
|
||||
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
if is_deepseek_v3_r1:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k, # topk currently 8
|
||||
bias=e_score_correction_bias,
|
||||
k_group=topk_group, # fix: 4
|
||||
group_count=num_expert_group, # fix 8
|
||||
group_select_mode=
|
||||
1, # 0: the maximum in the group; 1: topk2.sum(fix)
|
||||
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
||||
# out_flag=False, # todo new api; should the third output be output
|
||||
# y2_flag=False, # old api; should the third output be output
|
||||
routed_scaling_factor=1,
|
||||
eps=float(1e-20))
|
||||
row_idx = return_row_idx(hidden_states, top_k)
|
||||
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
|
||||
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
|
||||
x=router_logits, finished=None, k=top_k)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||
|
||||
return topk_weights, topk_ids, row_idx
|
||||
|
||||
|
||||
def _native_select_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
global_num_experts: Optional[torch.Tensor] = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Select top-k experts based on router logits.
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||
router_logits: Router logits of shape (num_tokens, num_experts).
|
||||
top_k: Number of experts to select.
|
||||
use_grouped_topk: Whether to group experts before selecting top-k.
|
||||
renormalize: Whether to renormalize the routing weights.
|
||||
topk_group: Number of expert groups to select from.
|
||||
num_expert_group: Number of experts in each group.
|
||||
custom_routing_function: Custom routing function.
|
||||
scoring_func: Scoring function to use.
|
||||
e_score_correction_bias: Correction bias to apply to expert scores.
|
||||
|
||||
Returns:
|
||||
topk_weights: Routing weights of shape (num_tokens, top_k).
|
||||
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
||||
|
||||
Raises:
|
||||
ValueError: If an unsupported scoring function is provided.
|
||||
"""
|
||||
|
||||
if scoring_func == "softmax":
|
||||
topk_weights = router_logits.softmax(dim=-1)
|
||||
elif scoring_func == "sigmoid":
|
||||
topk_weights = router_logits.sigmoid()
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||
|
||||
if use_grouped_topk:
|
||||
return _select_expert_use_group_topk(
|
||||
topk_weights=topk_weights,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
if custom_routing_function is not None:
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
global_num_experts=global_num_experts)
|
||||
# Required by npu_moe_init_routing
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
|
||||
topk_weights = topk_weights.to(hidden_states.dtype)
|
||||
|
||||
# Required by npu_moe_init_routing
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
199
vllm_ascend/ops/layers/moe_mlp.py
Normal file
199
vllm_ascend/ops/layers/moe_mlp.py
Normal file
@@ -0,0 +1,199 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.utils import dispose_tensor, is_310p
|
||||
|
||||
|
||||
def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
|
||||
if dynamic_scale is None:
|
||||
unquantized_hidden_states = hidden_states
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
# Dispose the original unquantized hidden states
|
||||
# to save npu memory because they're no longer used.
|
||||
dispose_tensor(unquantized_hidden_states)
|
||||
else:
|
||||
pertoken_scale = dynamic_scale
|
||||
|
||||
bias1, bias2 = None, None
|
||||
_output_dtype = w2_scale.dtype
|
||||
|
||||
is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
|
||||
if w1_scale_bias is None and is_mc2:
|
||||
w1_scale = w1_scale.to(torch.float32)
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=3,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=w1_scale,
|
||||
activation_scale=pertoken_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=group_list,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=w2_scale.dtype)[0]
|
||||
else:
|
||||
if w1_scale_bias is not None:
|
||||
if group_list_type == 0:
|
||||
group_list = torch.cat(
|
||||
[group_list[:1],
|
||||
torch.diff(group_list, dim=0)])
|
||||
group_list_type = 1
|
||||
bias1 = [w1_scale_bias]
|
||||
bias2 = [w2_scale_bias]
|
||||
# TODO w4a8 scene: dynamic acquisition of dtype in the future
|
||||
_output_dtype = torch.bfloat16
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
scale=[w1_scale],
|
||||
bias=bias1,
|
||||
per_token_scale=[pertoken_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
bias=bias2,
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unquant_apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
w1 = w1.transpose(1, 2)
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
if is_310p():
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
||||
torch.float16)
|
||||
else:
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
|
||||
if topk_scales is not None:
|
||||
gate_up_out *= topk_scales
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unified_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
topk_scales: Optional[torch.Tensor] = None,
|
||||
with_quant: bool = False) -> torch.Tensor:
|
||||
if with_quant:
|
||||
return quant_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
dynamic_scale=dynamic_scale,
|
||||
group_list_type=group_list_type,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias)
|
||||
else:
|
||||
return unquant_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type,
|
||||
topk_scales=topk_scales)
|
||||
309
vllm_ascend/ops/linear.py
Normal file
309
vllm_ascend/ops/linear.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
This file is a part of the vllm-ascend project.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
|
||||
ColumnParallelLinear,
|
||||
LinearBase,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizationConfig
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import (
|
||||
get_mlp_tensor_model_parallel_rank,
|
||||
get_mlp_tensor_model_parallel_world_size, get_mlp_tp_group)
|
||||
|
||||
|
||||
class AscendMlpColumnParallelLinear(ColumnParallelLinear):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
Use the MLP tensor parallelism group in the MLP module,
|
||||
and the original TP group in other modules.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
output_sizes: Optional[list[int]] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
):
|
||||
# Divide the weight matrix along the last dimension.
|
||||
if prefix.find("gate_up_proj") != -1:
|
||||
self.tp_size = get_mlp_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_mlp_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = True
|
||||
else:
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = False
|
||||
self.input_size_per_partition = input_size
|
||||
self.output_size_per_partition = divide(output_size, self.tp_size)
|
||||
self.output_partition_sizes = [self.output_size_per_partition]
|
||||
# If QKV or MergedColumn, use output size of each partition.
|
||||
if hasattr(self, "output_sizes"):
|
||||
self.output_partition_sizes = [
|
||||
divide(output_size, self.tp_size)
|
||||
for output_size in self.output_sizes
|
||||
]
|
||||
LinearBase.__init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias)
|
||||
|
||||
self.gather_output = gather_output
|
||||
|
||||
if output_sizes is None:
|
||||
output_sizes = [output_size]
|
||||
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size_per_partition,
|
||||
output_partition_sizes=self.output_partition_sizes,
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=(
|
||||
self.weight_loader_v2 if self.quant_method.__class__.__name__
|
||||
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
|
||||
class AscendMlpRowParallelLinear(RowParallelLinear):
|
||||
"""Linear layer with row parallelism.
|
||||
Use the MLP tensor parallelism group in the MLP module,
|
||||
and the original TP group in other modules.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
input_is_parallel: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
):
|
||||
if prefix.find("down_proj") != -1:
|
||||
self.tp_size = get_mlp_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_mlp_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = True
|
||||
else:
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = False
|
||||
# Divide the weight matrix along the first dimension.
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
self.output_size_per_partition = output_size
|
||||
self.output_partition_sizes = [output_size]
|
||||
|
||||
LinearBase.__init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias)
|
||||
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size_per_partition,
|
||||
output_partition_sizes=self.output_partition_sizes,
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=(
|
||||
self.weight_loader_v2 if self.quant_method.__class__.__name__
|
||||
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
"results can lead to incorrect results")
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size, dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.enable_mlp_optimze:
|
||||
tp_rank = get_mlp_tensor_model_parallel_rank()
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = get_mlp_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
bias_ = None if (self.tp_rank > 0
|
||||
or self.skip_bias_add) else self.bias
|
||||
output_parallel = self.quant_method.apply(self,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
output = get_mlp_tp_group().reduce_scatter(output_parallel, 0)
|
||||
# output = output[:num_tokens,:]
|
||||
# dispose_tensor(output_parallel)
|
||||
else:
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
bias_ = None if (self.tp_rank > 0
|
||||
or self.skip_bias_add) else self.bias
|
||||
output_parallel = self.quant_method.apply(self,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class AscendMlpMergedColumnParallelLinear(MergedColumnParallelLinear):
|
||||
"""Packed linear layers with column parallelism.
|
||||
|
||||
Similar to ColumnParallelLinear, but the weight matrix is concatenated
|
||||
along the output dimension. When the weight matrix is loaded, the
|
||||
different partitions are sharded separately.
|
||||
|
||||
Use the MLP tensor parallelism group in the MLP module,
|
||||
and the original TP group in other modules.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_sizes: list[int],
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
):
|
||||
self.output_sizes = output_sizes
|
||||
if prefix.find("gate_up_proj") != -1:
|
||||
self.tp_size = get_mlp_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_mlp_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = True
|
||||
else:
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = False
|
||||
assert all(output_size % self.tp_size == 0
|
||||
for output_size in output_sizes)
|
||||
AscendMlpColumnParallelLinear.__init__(self,
|
||||
input_size=input_size,
|
||||
output_size=sum(output_sizes),
|
||||
bias=bias,
|
||||
gather_output=gather_output,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
# self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
if self.enable_mlp_optimze:
|
||||
input2_ = get_mlp_tp_group().all_gather(input_, 0)
|
||||
output = self.quant_method.apply(self, input2_, bias)
|
||||
else:
|
||||
output_parallel = self.quant_method.apply(self, input_, bias)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
0
vllm_ascend/ops/moe_dispatcher/__init__.py
Normal file
0
vllm_ascend/ops/moe_dispatcher/__init__.py
Normal file
809
vllm_ascend/ops/moe_dispatcher/token_dispatcher.py
Normal file
809
vllm_ascend/ops/moe_dispatcher/token_dispatcher.py
Normal file
@@ -0,0 +1,809 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.distributed.tensor_parallel import \
|
||||
gather_from_sequence_parallel_region
|
||||
from vllm_ascend.ops.comm_utils import async_all_to_all
|
||||
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
||||
|
||||
_Dispatchers: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def _register_token_dispatcher(dispatcher: Any):
|
||||
_Dispatchers[dispatcher.__class__.__name__] = dispatcher
|
||||
|
||||
|
||||
def get_token_dispatcher(name: str):
|
||||
return _Dispatchers.get(name)
|
||||
|
||||
|
||||
def setup_token_dispatchers(ep_size: int, **kwargs):
|
||||
existing_dispatchers = set(_Dispatchers.keys())
|
||||
|
||||
if ep_size == 1 and "TokenDispatcherWithAllGather" not in existing_dispatchers:
|
||||
_register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
|
||||
elif ep_size < 16 and "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
|
||||
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
|
||||
elif ep_size >= 16:
|
||||
if "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
|
||||
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
|
||||
if "TokenDispatcherWithMC2" not in existing_dispatchers:
|
||||
_register_token_dispatcher(TokenDispatcherWithMC2(**kwargs))
|
||||
|
||||
|
||||
class MoETokenDispatcher(ABC):
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
"""
|
||||
Initialize the MoE Token Dispatcher.
|
||||
"""
|
||||
self.top_k = kwargs.get("top_k", 0)
|
||||
self.num_experts = kwargs.get("num_experts", 0)
|
||||
|
||||
@property
|
||||
def ep_group(self):
|
||||
"""Get expert model parallel group."""
|
||||
return get_ep_group().device_group
|
||||
|
||||
@property
|
||||
def ep_rank(self):
|
||||
return get_ep_group().rank_in_group
|
||||
|
||||
@property
|
||||
def ep_size(self):
|
||||
return get_ep_group().world_size
|
||||
|
||||
@abstractmethod
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
raise NotImplementedError("Dispatch function not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
raise NotImplementedError("Combine function not implemented.")
|
||||
|
||||
|
||||
class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
device_group = get_mc2_group().device_group
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=device_group)
|
||||
backend = device_group._get_backend(torch.device("npu"))
|
||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
|
||||
self.ep_rank_id = get_mc2_group().rank_in_group
|
||||
self.ep_world_size = get_mc2_group().world_size
|
||||
self.enable_dispatch_v2 = hasattr(torch_npu,
|
||||
"npu_moe_distribute_dispatch_v2")
|
||||
self.need_extra_args = (
|
||||
get_ascend_soc_version() == AscendSocVersion.A3)
|
||||
|
||||
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||
self.a3_need_extra_args = \
|
||||
get_ascend_soc_version() == AscendSocVersion.A3
|
||||
self.output = None
|
||||
self.assist_info_for_combine = None
|
||||
self.ep_recv_counts = None
|
||||
self.shared_act = None
|
||||
self.topk_ids = None
|
||||
self.topk_weights = None
|
||||
self.shared_experts = None
|
||||
self.mc2_mask = None
|
||||
self.with_quant = False
|
||||
|
||||
def get_dispatch_mc2_kwargs(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
global_redundant_expert_num: int = 0,
|
||||
):
|
||||
if self.with_quant:
|
||||
quant_mode = 2
|
||||
if (expert_map is not None):
|
||||
moe_expert_num = len(expert_map) + global_redundant_expert_num
|
||||
else:
|
||||
moe_expert_num = global_redundant_expert_num
|
||||
else:
|
||||
quant_mode = 0
|
||||
moe_expert_num = len(expert_map)
|
||||
kwargs_mc2 = {
|
||||
"x": hidden_states,
|
||||
"expert_ids": topk_ids,
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": 0,
|
||||
}
|
||||
|
||||
stage1_kwargs = {
|
||||
"scales": None,
|
||||
"quant_mode": quant_mode,
|
||||
"group_ep": self.moe_all_to_all_group_name,
|
||||
"ep_world_size": self.ep_world_size,
|
||||
"ep_rank_id": self.ep_rank_id,
|
||||
}
|
||||
if self.need_extra_args:
|
||||
stage1_kwargs.update({
|
||||
"group_tp": self.moe_all_to_all_group_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if self.a3_need_extra_args and self.enable_dispatch_v2:
|
||||
stage1_kwargs.update({
|
||||
"x_active_mask": self.mc2_mask,
|
||||
})
|
||||
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
return kwargs_mc2
|
||||
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
self.with_quant = with_quant
|
||||
self.expert_map = expert_map
|
||||
self.topk_ids = topk_ids
|
||||
self.topk_weights = topk_weights
|
||||
self.shared_experts = shared_experts
|
||||
self.mc2_mask = mc2_mask
|
||||
|
||||
kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
|
||||
topk_ids, expert_map,
|
||||
global_redundant_expert_num)
|
||||
self.output = torch_npu.npu_moe_distribute_dispatch_v2(
|
||||
**kwargs_mc2
|
||||
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
|
||||
**kwargs_mc2)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, dynamic_scale, self.assist_info_for_combine, \
|
||||
expert_token_nums, self.ep_recv_counts = self.output[0:5]
|
||||
|
||||
if self.with_quant:
|
||||
if shared_experts is not None:
|
||||
shared_act_out = shared_experts.act_fn(
|
||||
(shared_gate_up, shared_dequant_scale))
|
||||
self.shared_act, self.swiglu_out_scale = \
|
||||
shared_act_out[0], shared_act_out[1]
|
||||
|
||||
else:
|
||||
if shared_experts is not None:
|
||||
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
|
||||
self.shared_act = shared_experts.act_fn(shared_gate_up)
|
||||
group_list_type = 1
|
||||
return {
|
||||
"group_list_type": group_list_type,
|
||||
"hidden_states": expand_x,
|
||||
"group_list": expert_token_nums,
|
||||
"dynamic_scale": dynamic_scale,
|
||||
}
|
||||
|
||||
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor):
|
||||
assert self.expert_map is not None
|
||||
assert self.topk_weights is not None
|
||||
assert self.topk_ids is not None
|
||||
assert self.output is not None
|
||||
moe_expert_num = len(self.expert_map)
|
||||
# moeCombine
|
||||
kwargs_mc2 = {
|
||||
"expand_x": hidden_states,
|
||||
"expert_ids": self.topk_ids,
|
||||
"expert_scales": self.topk_weights.to(torch.float32),
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": 0,
|
||||
}
|
||||
if self.with_quant:
|
||||
tp_recv_counts = torch.empty(1,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
else:
|
||||
tp_recv_counts = self.output[5]
|
||||
stage3_kwargs = {
|
||||
"ep_send_counts": self.ep_recv_counts,
|
||||
"group_ep": self.moe_all_to_all_group_name,
|
||||
"ep_world_size": self.ep_world_size,
|
||||
"ep_rank_id": self.ep_rank_id,
|
||||
}
|
||||
if self.enable_dispatch_v2:
|
||||
stage3_kwargs.update({
|
||||
"assist_info_for_combine":
|
||||
self.assist_info_for_combine,
|
||||
})
|
||||
else:
|
||||
stage3_kwargs.update({
|
||||
"expand_idx": self.assist_info_for_combine,
|
||||
})
|
||||
if self.need_extra_args:
|
||||
stage3_kwargs.update({
|
||||
"tp_send_counts": tp_recv_counts,
|
||||
"group_tp": self.moe_all_to_all_group_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if self.a3_need_extra_args and self.enable_dispatch_v2:
|
||||
stage3_kwargs.update({
|
||||
"x_active_mask": self.mc2_mask,
|
||||
})
|
||||
kwargs_mc2.update(stage3_kwargs)
|
||||
return kwargs_mc2
|
||||
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states)
|
||||
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
|
||||
**kwargs_mc2
|
||||
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
|
||||
**kwargs_mc2)
|
||||
if self.shared_experts is None:
|
||||
return hidden_states
|
||||
else:
|
||||
if self.with_quant:
|
||||
shared_hidden_states, _ = self.shared_experts.down_proj(
|
||||
(self.shared_act, self.swiglu_out_scale))
|
||||
else:
|
||||
shared_hidden_states, _ = self.shared_experts.down_proj(
|
||||
self.shared_act)
|
||||
return hidden_states, shared_hidden_states
|
||||
|
||||
|
||||
class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.apply_router_weight_on_input = False
|
||||
self.max_num_tokens = kwargs.get("max_num_tokens")
|
||||
self.num_experts_local = kwargs.get("num_local_experts", 0)
|
||||
self.sorted_weights = None
|
||||
self.expanded_row_idx = None
|
||||
self.sorted_token_indices = None
|
||||
self.original_shape = None
|
||||
self.mask = None
|
||||
self.expert_map = None
|
||||
self.topk_weights = None
|
||||
self.topk_ids = None
|
||||
self.with_quant = False
|
||||
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
self.with_quant = with_quant
|
||||
self.original_shape = hidden_states.shape
|
||||
|
||||
num_tokens = hidden_states.shape[:-1].numel()
|
||||
dtype = hidden_states.dtype
|
||||
device = hidden_states.device
|
||||
self.expert_map = expert_map
|
||||
self.topk_weights = topk_weights
|
||||
self.topk_ids = topk_ids
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
if self.apply_router_weight_on_input:
|
||||
assert (topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
hidden_states = hidden_states * \
|
||||
topk_weights.to(hidden_states.dtype)
|
||||
|
||||
if expert_map is not None:
|
||||
# Generate token indices and flatten
|
||||
token_indices = (torch.arange(
|
||||
num_tokens, device=device,
|
||||
dtype=torch.int64).unsqueeze(1).expand(-1,
|
||||
self.top_k).reshape(-1))
|
||||
|
||||
# Flatten token-to-expert mappings and map to local experts
|
||||
weights_flat = topk_weights.view(-1)
|
||||
experts_flat = topk_ids.view(-1)
|
||||
local_experts_flat = expert_map[experts_flat]
|
||||
|
||||
# Filter valid token-expert pairs
|
||||
self.mask = local_experts_flat != -1
|
||||
filtered_weights = torch.where(
|
||||
self.mask, weights_flat,
|
||||
torch.zeros_like(weights_flat)).to(dtype)
|
||||
filtered_experts = torch.where(
|
||||
self.mask, local_experts_flat,
|
||||
torch.full_like(local_experts_flat,
|
||||
self.num_experts_local)).to(topk_ids.dtype)
|
||||
|
||||
# Sort by local expert IDs
|
||||
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
|
||||
self.sorted_token_indices = token_indices[sort_indices]
|
||||
self.sorted_weights = filtered_weights[sort_indices]
|
||||
|
||||
# Compute token counts with minlength of num_experts
|
||||
# This is equivalent to but faster than:
|
||||
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
|
||||
token_counts = torch.zeros(self.num_experts_local + 1,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
|
||||
token_counts.scatter_add_(0, filtered_experts.to(torch.int64),
|
||||
ones)
|
||||
token_counts = token_counts[:self.num_experts_local]
|
||||
|
||||
# Rearrange hidden_states
|
||||
sorted_hidden_states = hidden_states[self.sorted_token_indices]
|
||||
if self.with_quant:
|
||||
group_list_type = 1
|
||||
expert_tokens = token_counts
|
||||
else:
|
||||
expert_tokens = torch.cumsum(token_counts,
|
||||
dim=0,
|
||||
dtype=torch.int64)
|
||||
group_list_type = 0
|
||||
else:
|
||||
active_num = self.max_num_tokens if self.max_num_tokens is not None else num_tokens
|
||||
sorted_hidden_states, self.expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
expert_idx=topk_ids,
|
||||
active_num=active_num)
|
||||
|
||||
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||
expanded_expert_idx, self.num_experts_local)
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 0
|
||||
return {
|
||||
"group_list_type": group_list_type,
|
||||
"hidden_states": sorted_hidden_states,
|
||||
"group_list": expert_tokens,
|
||||
}
|
||||
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
assert self.original_shape is not None
|
||||
dtype = hidden_states.dtype
|
||||
device = hidden_states.device
|
||||
if self.expert_map is not None:
|
||||
assert self.mask is not None
|
||||
assert self.sorted_token_indices is not None
|
||||
assert self.sorted_weights is not None
|
||||
|
||||
weighted_down_out = hidden_states * \
|
||||
self.sorted_weights.unsqueeze(1)
|
||||
|
||||
final_hidden_states = torch.zeros(*self.original_shape,
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
# TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
||||
# This created multiple NaN and index_add_ will mix them up which harms accuracy
|
||||
# remove this mask and filter after it being fixed
|
||||
num_valid_tokens = self.mask.sum()
|
||||
valid_token_mask = torch.arange(
|
||||
0, self.sorted_token_indices.shape[0],
|
||||
device=device).unsqueeze(1) < num_valid_tokens
|
||||
valid_output = torch.where(
|
||||
valid_token_mask, weighted_down_out,
|
||||
torch.zeros_like(weighted_down_out)).to(dtype)
|
||||
final_hidden_states.index_add_(0, self.sorted_token_indices,
|
||||
valid_output)
|
||||
else:
|
||||
if self.with_quant:
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
hidden_states,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=self.topk_weights,
|
||||
expanded_src_to_dst_row=self.expanded_row_idx,
|
||||
export_for_source_row=self.topk_ids,
|
||||
)
|
||||
if len(self.original_shape) == 3:
|
||||
final_hidden_states = final_hidden_states.view(
|
||||
self.original_shape)
|
||||
else:
|
||||
scales = torch.ones_like(
|
||||
self.topk_weights
|
||||
) if self.apply_router_weight_on_input else self.topk_weights
|
||||
# TODO: Reorder device memory 2 times here, replace the current
|
||||
# implementation here when suitable operators become available.
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
hidden_states,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=scales,
|
||||
expanded_src_to_dst_row=self.expanded_row_idx,
|
||||
export_for_source_row=self.topk_ids,
|
||||
)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
# mypy: disable-error-code="override"
|
||||
class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.apply_router_weight_on_input = False
|
||||
self.local_ep = 1
|
||||
self.local_num_experts = self.num_experts // self.local_ep
|
||||
self.local_num_group = self.top_k // self.local_ep
|
||||
self.bsz = None
|
||||
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
if self.apply_router_weight_on_input:
|
||||
assert (topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
hidden_states = hidden_states * \
|
||||
topk_weights.to(hidden_states.dtype)
|
||||
|
||||
self.bsz, _ = hidden_states.shape
|
||||
flatten_topk_ids = topk_ids.view(-1)
|
||||
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
||||
self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
|
||||
sorted_hidden_states = hidden_states.index_select(
|
||||
0, self.sorted_topk_ids // self.local_num_group)
|
||||
|
||||
experts_id = torch.arange(0,
|
||||
self.local_num_experts,
|
||||
dtype=topk_ids.dtype,
|
||||
device=topk_ids.device)
|
||||
num_tokens_per_expert = (
|
||||
flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
||||
torch.float32).sum(0)
|
||||
topk_scales = topk_weights.view(-1).index_select(
|
||||
0, self.sorted_topk_ids).unsqueeze(-1)
|
||||
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
||||
group_list_type = 0
|
||||
return {
|
||||
"group_list_type": group_list_type,
|
||||
"hidden_states": sorted_hidden_states,
|
||||
"group_list": group_list,
|
||||
"topk_scales": topk_scales,
|
||||
}
|
||||
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
|
||||
torch.int32)
|
||||
unsorted_hidden_states = hidden_states.index_select(
|
||||
0, unsorted_topk_ids)
|
||||
final_hidden_states = unsorted_hidden_states.reshape(
|
||||
self.bsz, self.top_k // self.local_ep, -1).sum(1)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
"""
|
||||
The implementation of the AlltoAll-based token dispatcher, which handles token
|
||||
dispatching on the sequence level instead of token level. The core of this implementation
|
||||
lies in each device dispatching on the entire sequence, with the hidden state being partitioned.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.with_quant = False
|
||||
self.num_local_experts = kwargs.get("num_local_experts", 0)
|
||||
self.num_global_redundant_experts = kwargs.get(
|
||||
"num_global_redundant_experts", 0)
|
||||
self.num_experts = self.num_experts + self.num_global_redundant_experts
|
||||
|
||||
self.hidden_shape = None
|
||||
self.topk_weights = None
|
||||
self.input_splits = None
|
||||
self.output_splits = None
|
||||
self.hidden_shape_before_permute = None
|
||||
|
||||
# [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent
|
||||
# to each local expert by all ranks.
|
||||
self.num_global_tokens_per_local_expert = None
|
||||
|
||||
# cached intermediate tensors.
|
||||
self.tokens_per_expert = None
|
||||
self.global_input_tokens_local_experts_indices = None
|
||||
|
||||
assert self.num_local_experts > 0, "Expected at least one expert"
|
||||
if self.num_local_experts > 1:
|
||||
self.expert_ids_per_ep_rank = torch.tensor(
|
||||
[i % self.num_local_experts for i in range(self.num_experts)],
|
||||
dtype=torch.int32,
|
||||
device=torch.npu.current_device(),
|
||||
)
|
||||
|
||||
local_expert_indices_offset = (self.ep_rank * self.num_local_experts)
|
||||
|
||||
self.local_expert_indices = [
|
||||
local_expert_indices_offset + i
|
||||
for i in range(self.num_local_experts)
|
||||
]
|
||||
assert (len(self.local_expert_indices) == self.num_local_experts
|
||||
), "Invalid local expert indices"
|
||||
for i in range(len(self.local_expert_indices) - 1):
|
||||
assert (self.local_expert_indices[i] ==
|
||||
self.local_expert_indices[i + 1] -
|
||||
1), "local_expert_indices must be continuous"
|
||||
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
self.with_quant = with_quant
|
||||
self.hidden_shape = hidden_states.shape
|
||||
self.topk_weights = topk_weights
|
||||
assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"
|
||||
assert topk_ids.dim() == 2, "Expected 2D tensor for routing map"
|
||||
|
||||
if log2phy is not None:
|
||||
topk_ids = log2phy[topk_ids]
|
||||
|
||||
permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = self._dispatch_preprocess(
|
||||
hidden_states, topk_ids)
|
||||
self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping
|
||||
|
||||
dynamic_scale_after_all2all = None
|
||||
if self.with_quant:
|
||||
permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(
|
||||
permutated_local_input_tokens)
|
||||
|
||||
_, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
|
||||
dynamic_scale,
|
||||
self.output_splits,
|
||||
self.input_splits,
|
||||
self.ep_group,
|
||||
)
|
||||
permute2_ep_all_to_all_handle.wait()
|
||||
dynamic_scale.untyped_storage().resize_(0)
|
||||
|
||||
_, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
|
||||
permutated_local_input_tokens,
|
||||
self.output_splits,
|
||||
self.input_splits,
|
||||
self.ep_group,
|
||||
)
|
||||
permute1_ep_all_to_all_handle.wait()
|
||||
permutated_local_input_tokens.untyped_storage().resize_(0)
|
||||
|
||||
global_input_tokens, dynamic_scale = self._dispatch_postprocess(
|
||||
global_input_tokens, dynamic_scale_after_all2all)
|
||||
return {
|
||||
"hidden_states": global_input_tokens,
|
||||
"group_list": tokens_per_expert,
|
||||
"dynamic_scale": dynamic_scale,
|
||||
"group_list_type": 1
|
||||
}
|
||||
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
|
||||
|
||||
hidden_states = self._combine_preprocess(hidden_states)
|
||||
|
||||
# Perform expert parallel AlltoAll communication
|
||||
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
|
||||
_, permutated_local_input_tokens, handle = async_all_to_all(
|
||||
hidden_states, self.input_splits, self.output_splits,
|
||||
self.ep_group)
|
||||
handle.wait()
|
||||
hidden_states.untyped_storage().resize_(0)
|
||||
|
||||
output = self._combine_postprocess(permutated_local_input_tokens)
|
||||
|
||||
self.input_splits = None
|
||||
self.output_splits = None
|
||||
self.num_global_tokens_per_local_expert = None
|
||||
|
||||
return output
|
||||
|
||||
def _dispatch_preprocess(self, hidden_states, topk_ids):
|
||||
assert self.hidden_shape is not None
|
||||
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
|
||||
tokens_per_expert = self._preprocess(topk_ids)
|
||||
|
||||
self.hidden_shape_before_permute = hidden_states.shape
|
||||
|
||||
permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute(
|
||||
tokens=hidden_states,
|
||||
indices=topk_ids,
|
||||
num_out_tokens=self.num_out_tokens,
|
||||
)
|
||||
return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert
|
||||
|
||||
def _preprocess(self, topk_ids: torch.Tensor) -> torch.Tensor:
|
||||
num_local_tokens_per_expert = torch.histc(topk_ids,
|
||||
bins=self.num_experts,
|
||||
min=0,
|
||||
max=self.num_experts)
|
||||
|
||||
ep_size = self.ep_size
|
||||
|
||||
# Dropless
|
||||
self.num_out_tokens = topk_ids.numel()
|
||||
|
||||
# ===================================================
|
||||
# Calculate input_splits, output_splits for alltoall-v.
|
||||
# ===================================================
|
||||
self.input_splits = (num_local_tokens_per_expert.reshape(
|
||||
ep_size,
|
||||
self.num_local_experts).sum(axis=1).to(torch.device("cpu"),
|
||||
non_blocking=True).numpy())
|
||||
num_global_tokens_per_expert = gather_from_sequence_parallel_region(
|
||||
num_local_tokens_per_expert,
|
||||
group=self.ep_group).reshape(ep_size, self.num_experts)
|
||||
self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
|
||||
0]:self.local_expert_indices[-1] + 1]
|
||||
if self.num_global_tokens_per_local_expert is None:
|
||||
raise ValueError(
|
||||
"num_global_tokens_per_local_expert must be set before sum.")
|
||||
self.output_splits = (self.num_global_tokens_per_local_expert.sum(
|
||||
axis=-1).to(torch.device("cpu"), non_blocking=True).numpy())
|
||||
num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(
|
||||
axis=0)
|
||||
# ===================================================
|
||||
# num_global_tokens_per_expert: [ep_size, num_experts]
|
||||
# num_global_tokens_per_local_expert: [ep_size, num_local_experts]
|
||||
# num_tokens_per_local_expert: [num_local_experts]
|
||||
# ===================================================
|
||||
|
||||
if self.num_local_experts > 1:
|
||||
if self.num_global_tokens_per_local_expert is None:
|
||||
raise ValueError(
|
||||
"num_global_tokens_per_local_expert must be set before operations."
|
||||
)
|
||||
self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
|
||||
self.expert_ids_per_ep_rank,
|
||||
self.num_global_tokens_per_local_expert.ravel())
|
||||
|
||||
return num_tokens_per_local_expert
|
||||
|
||||
def _dispatch_postprocess(self, global_input_tokens, dynamic_scale=None):
|
||||
# Early return if no local experts or no tokens
|
||||
if self.num_local_experts <= 1:
|
||||
return global_input_tokens, None
|
||||
|
||||
# Handle quantized case
|
||||
if self.with_quant:
|
||||
assert self.global_input_tokens_local_experts_indices is not None, \
|
||||
"global_input_tokens_local_experts_indices must be initialized before calling _dispatch_postprocess"
|
||||
expert_idx_2d = self.global_input_tokens_local_experts_indices.unsqueeze(
|
||||
-1)
|
||||
active_num = self.global_input_tokens_local_experts_indices.numel()
|
||||
|
||||
# Handle case with no active tokens
|
||||
if active_num <= 0:
|
||||
self.reversed_global_input_permutation_mapping = self.global_input_tokens_local_experts_indices
|
||||
return global_input_tokens, dynamic_scale
|
||||
|
||||
# Process with active tokens
|
||||
global_input_tokens, self.reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
|
||||
global_input_tokens,
|
||||
expert_idx_2d,
|
||||
scale=dynamic_scale,
|
||||
active_num=active_num,
|
||||
expert_capacity=0,
|
||||
expert_num=self.num_local_experts,
|
||||
expert_tokens_num_type=1,
|
||||
expert_tokens_num_flag=True,
|
||||
active_expert_range=[0, self.num_local_experts],
|
||||
quant_mode=-1,
|
||||
row_idx_type=0)
|
||||
return global_input_tokens, expanded_scale
|
||||
|
||||
# Handle non-quantized case
|
||||
global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
|
||||
global_input_tokens,
|
||||
self.global_input_tokens_local_experts_indices)
|
||||
return global_input_tokens, None
|
||||
|
||||
def _combine_preprocess(self, hidden_states):
|
||||
# Unpermutation 2: expert output to AlltoAll input
|
||||
if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
|
||||
hidden_states = torch_npu.npu_moe_token_unpermute(
|
||||
hidden_states, self.reversed_global_input_permutation_mapping)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def _combine_postprocess(self, permutated_local_input_tokens):
|
||||
# Unpermutation 1: AlltoAll output to output
|
||||
output = torch_npu.npu_moe_token_unpermute(
|
||||
permuted_tokens=permutated_local_input_tokens,
|
||||
sorted_indices=self.reversed_local_input_permutation_mapping.to(
|
||||
torch.int32),
|
||||
probs=self.topk_weights,
|
||||
restore_shape=self.hidden_shape_before_permute)
|
||||
|
||||
# Reshape the output tensor
|
||||
output = output.view(self.hidden_shape)
|
||||
return output
|
||||
339
vllm_ascend/ops/rotary_embedding.py
Normal file
339
vllm_ascend/ops/rotary_embedding.py
Normal file
@@ -0,0 +1,339 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import enable_custom_op, is_310p
|
||||
|
||||
|
||||
def _custom_rotary_embedding_enabled(query, neox_style, head_size):
|
||||
return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
|
||||
)
|
||||
|
||||
|
||||
def _rope_forward_oot(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
query_shape, key_shape = query.shape, key.shape
|
||||
if self.cos_sin_cache.device != query.device:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
|
||||
if self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||
neox_style = self.is_neox_style
|
||||
if is_neox_style_override is not None:
|
||||
neox_style = is_neox_style_override
|
||||
# adopt custom kernel path for rotary_embedding
|
||||
if _custom_rotary_embedding_enabled(query, neox_style,
|
||||
self.head_size) and not is_310p():
|
||||
query, key = torch.ops._C.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
if offsets is not None:
|
||||
raise NotImplementedError(
|
||||
"Batched rotary embedding is currently not supported on NPU.")
|
||||
else:
|
||||
if self.rotary_dim < self.head_size:
|
||||
num_tokens = query.shape[0]
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
q_rot = query[..., :self.rotary_dim]
|
||||
q_pass = query[..., self.rotary_dim:]
|
||||
k_rot = key[..., :self.rotary_dim]
|
||||
k_pass = key[..., self.rotary_dim:]
|
||||
q_rot = q_rot.contiguous().view(num_tokens, -1)
|
||||
k_rot = k_rot.contiguous().view(num_tokens, -1)
|
||||
torch_npu._npu_rotary_embedding(
|
||||
positions,
|
||||
q_rot,
|
||||
k_rot,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
)
|
||||
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
|
||||
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
|
||||
q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
|
||||
k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
|
||||
return q, k
|
||||
# TODO: Remove the contiguous in the future.
|
||||
query = query.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(key.shape[0], -1)
|
||||
torch_npu._npu_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
|
||||
|
||||
class AscendRotaryEmbedding(RotaryEmbedding):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: float,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style, dtype)
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None,
|
||||
):
|
||||
return _rope_forward_oot(
|
||||
self,
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
offsets,
|
||||
is_neox_style_override,
|
||||
)
|
||||
|
||||
|
||||
class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
extrapolation_factor: float = 1,
|
||||
attn_factor: float = 1,
|
||||
beta_fast: int = 32,
|
||||
beta_slow: int = 1,
|
||||
mscale: float = 1,
|
||||
mscale_all_dim: float = 0,
|
||||
) -> None:
|
||||
# Note: we adopt the native huggingface deepseek rope initialization code from
|
||||
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for
|
||||
# its more ascend compute friendly
|
||||
self.scaling_factor = scaling_factor
|
||||
self.extrapolation_factor = extrapolation_factor
|
||||
self.attn_factor = attn_factor
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
# Get n-d magnitude scaling corrected for interpolation.
|
||||
self.mscale = float(
|
||||
self._yarn_get_mscale(self.scaling_factor, float(mscale)) /
|
||||
self._yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
|
||||
attn_factor)
|
||||
super(DeepseekScalingRotaryEmbedding,
|
||||
self).__init__(head_size, rotary_dim, max_position_embeddings,
|
||||
base, is_neox_style, dtype)
|
||||
self.max_seq_len = max_position_embeddings
|
||||
self._set_cos_sin_cache(seq_len=max_position_embeddings,
|
||||
device=NPUPlatform.device_type,
|
||||
dtype=dtype)
|
||||
|
||||
def _yarn_get_mscale(self, scale: float = 1, mscale: float = 1) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
def _rotate_half(self, x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., :x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def _yarn_linear_ramp_mask(self, min_value, max_value, dim):
|
||||
# Note: The if conditional branch is not used here
|
||||
# to solve MTP compilation error.
|
||||
max_value += (min_value == max_value).float() * 0.001
|
||||
linear_func = (torch.arange(dim, dtype=torch.float32) -
|
||||
min_value) / (max_value - min_value)
|
||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
# Inverse dim formula to find dim based on number of rotations
|
||||
def _yarn_find_correction_dim(self,
|
||||
num_rotations,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
# Note: use torch instead of math to solve MTP compilation error.
|
||||
return (dim * torch.log(
|
||||
torch.tensor(max_position_embeddings) /
|
||||
(num_rotations * 2 * torch.pi))) / (2 *
|
||||
torch.log(torch.tensor(base)))
|
||||
|
||||
# Find dim range bounds based on rotations
|
||||
def _yarn_find_correction_range(self,
|
||||
low_rot,
|
||||
high_rot,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
# Note: use torch instead of math to solve MTP compilation error.
|
||||
low = torch.floor(
|
||||
self._yarn_find_correction_dim(low_rot, dim, base,
|
||||
max_position_embeddings))
|
||||
high = torch.ceil(
|
||||
self._yarn_find_correction_dim(high_rot, dim, base,
|
||||
max_position_embeddings))
|
||||
# Note: use torch instead of max/min to solve MTP compilation error.
|
||||
return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def _apply_rotary_pos_emb(self,
|
||||
q,
|
||||
k,
|
||||
cos,
|
||||
sin,
|
||||
position_ids,
|
||||
unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids]
|
||||
sin = sin[position_ids]
|
||||
cos = cos[:, None, None, :]
|
||||
sin = sin[:, None, None, :]
|
||||
|
||||
if len(q.shape) == 3:
|
||||
q = q[:, :, None, :]
|
||||
if len(k.shape) == 2:
|
||||
k = k[:, None, None, :]
|
||||
elif len(k.shape) == 3:
|
||||
k = k[:, :, None, :]
|
||||
|
||||
b, h_q, s, d = q.shape
|
||||
q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d)
|
||||
|
||||
b, h_k, s, d = k.shape
|
||||
k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d)
|
||||
|
||||
q_embed = (q * cos) + (self._rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (self._rotate_half(k) * sin)
|
||||
|
||||
q_embed = q_embed.view(b, h_q, d)
|
||||
k_embed = k_embed.view(b, h_k, d)
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
dim = self.rotary_dim
|
||||
|
||||
freq_extra = 1.0 / (self.base**(
|
||||
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
||||
freq_inter = 1.0 / (self.scaling_factor * self.base**(
|
||||
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
||||
|
||||
low, high = self._yarn_find_correction_range(
|
||||
self.beta_fast,
|
||||
self.beta_slow,
|
||||
dim,
|
||||
self.base,
|
||||
self.max_position_embeddings,
|
||||
)
|
||||
inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(
|
||||
low, high, dim // 2).to(device=device, dtype=torch.float32)
|
||||
inv_freq = freq_inter * (1 -
|
||||
inv_freq_mask) + freq_extra * inv_freq_mask
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(seq_len * self.scaling_factor,
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
|
||||
sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale
|
||||
cos_cached = cos_cached.to(dtype)
|
||||
sin_cached = sin_cached.to(dtype)
|
||||
cache = torch.cat(
|
||||
[freqs.cos() * self.mscale,
|
||||
freqs.sin() * self.mscale], dim=-1).to(dtype)
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||
|
||||
def forward(self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
max_seq_len: Optional[int] = None):
|
||||
if max_seq_len is not None and max_seq_len > self.max_seq_len:
|
||||
self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
|
||||
if len(key.shape) == 2:
|
||||
key = key[:, None, :]
|
||||
# Note: we implement the non neox_style method with shuffle the last dim and neox style
|
||||
# calculation method which is also more compute friendly to the ascend machine
|
||||
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
|
||||
neox_style = True
|
||||
if self.is_neox_style is False:
|
||||
b, h_q, d = query.shape
|
||||
query = query.view(b, h_q, d // 2,
|
||||
2).transpose(3, 2).reshape(b, h_q, d)
|
||||
b, h_k, d = key.shape
|
||||
key = key.view(b, h_k, d // 2, 2).transpose(3,
|
||||
2).reshape(b, h_k, d)
|
||||
q_pe, k_pe = _rope_forward_oot(self, positions, query, key, offsets,
|
||||
neox_style)
|
||||
return q_pe, k_pe
|
||||
120
vllm_ascend/ops/sequence_parallel.py
Normal file
120
vllm_ascend/ops/sequence_parallel.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
get_tp_group, tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_reduce_scatter)
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
|
||||
|
||||
class MetadataForPadding:
|
||||
|
||||
def __init__(self,
|
||||
padding_flag=False,
|
||||
lengths_sum_padding=0,
|
||||
lengths_sum_unpadding=0,
|
||||
pad_size=0,
|
||||
not_dummy_and_is_prefill=False):
|
||||
self.padding_flag = padding_flag
|
||||
self.not_dummy_and_is_prefill = not_dummy_and_is_prefill
|
||||
|
||||
self.lengths_sum_padding = lengths_sum_padding
|
||||
self.lengths_sum_unpadding = lengths_sum_unpadding
|
||||
self.pad_size = pad_size
|
||||
|
||||
self.tp_size = get_tp_group().world_size
|
||||
self.tp_rank_in_group = get_tp_group().rank_in_group
|
||||
|
||||
assert self.lengths_sum_padding % self.tp_size == 0
|
||||
self.slice_size = self.lengths_sum_padding // self.tp_size
|
||||
|
||||
self.mc2_mask = torch.zeros(
|
||||
self.lengths_sum_padding,
|
||||
dtype=torch.bool,
|
||||
device=NPUPlatform.device_type,
|
||||
)
|
||||
self.mc2_mask[:lengths_sum_unpadding] = True
|
||||
|
||||
def padding_aligned_reduce_scatter(self,
|
||||
data: torch.Tensor) -> torch.Tensor:
|
||||
if self.padding_flag:
|
||||
pad_size = self.pad_size
|
||||
padded_data = F.pad(data, (0, 0, 0, pad_size))
|
||||
else:
|
||||
padded_data = data
|
||||
padded_data_reduce_scatter = tensor_model_parallel_reduce_scatter(
|
||||
padded_data, 0)
|
||||
|
||||
return padded_data_reduce_scatter
|
||||
|
||||
def allgather_unpadding_aligned(self,
|
||||
padded_data: torch.Tensor) -> torch.Tensor:
|
||||
padded_data_allgather = tensor_model_parallel_all_gather(
|
||||
padded_data, 0)
|
||||
if self.padding_flag:
|
||||
lengths_sum_unpadding = self.lengths_sum_unpadding
|
||||
unpadding_data = padded_data_allgather[:lengths_sum_unpadding]
|
||||
else:
|
||||
unpadding_data = padded_data_allgather
|
||||
return unpadding_data
|
||||
|
||||
def padding_slice(self, data: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
padded_data = F.pad(data, (0, 0, 0, self.pad_size))
|
||||
start = self.tp_rank_in_group * self.slice_size
|
||||
end = start + self.slice_size
|
||||
slice_data = padded_data[start:end]
|
||||
|
||||
return slice_data
|
||||
|
||||
def padding_aligned_scatter(self, data: torch.Tensor) -> torch.Tensor:
|
||||
if self.padding_flag:
|
||||
pad_size = self.pad_size
|
||||
padded_data = F.pad(data, (0, 0, 0, pad_size))
|
||||
else:
|
||||
padded_data = data
|
||||
# padded_data = data
|
||||
padded_data = torch.tensor_split(padded_data, self.tp_size, dim=0)
|
||||
|
||||
padded_data_reduce_scatter = padded_data[self.tp_rank_in_group]
|
||||
|
||||
return padded_data_reduce_scatter
|
||||
|
||||
|
||||
def init_metadata_for_sp(input_ids, enable_sequence_parallelism):
|
||||
if not enable_sequence_parallelism:
|
||||
return MetadataForPadding(padding_flag=False,
|
||||
not_dummy_and_is_prefill=False)
|
||||
|
||||
is_perifll = 0
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if attn_metadata is not None:
|
||||
if hasattr(attn_metadata,
|
||||
'is_only_prefill') and attn_metadata.is_only_prefill:
|
||||
is_perifll = 1
|
||||
if hasattr(attn_metadata,
|
||||
'num_prefills') and attn_metadata.num_prefills > 0:
|
||||
is_perifll = 1
|
||||
|
||||
if is_perifll:
|
||||
lengths_sum_unpadding = input_ids.shape[0]
|
||||
lengths_sum_padding = (
|
||||
(lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
|
||||
if lengths_sum_unpadding == lengths_sum_padding:
|
||||
padding_flag = False
|
||||
else:
|
||||
padding_flag = True
|
||||
pad_size = lengths_sum_padding - lengths_sum_unpadding
|
||||
_metadata_for_padding = MetadataForPadding(
|
||||
lengths_sum_unpadding=lengths_sum_unpadding,
|
||||
lengths_sum_padding=lengths_sum_padding,
|
||||
padding_flag=padding_flag,
|
||||
pad_size=pad_size,
|
||||
not_dummy_and_is_prefill=True)
|
||||
|
||||
return _metadata_for_padding
|
||||
|
||||
return MetadataForPadding(padding_flag=False,
|
||||
not_dummy_and_is_prefill=False)
|
||||
254
vllm_ascend/ops/vocab_parallel_embedding.py
Normal file
254
vllm_ascend/ops/vocab_parallel_embedding.py
Normal file
@@ -0,0 +1,254 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.distributed import divide, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
|
||||
VocabParallelEmbedding, pad_vocab_size)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import get_lmhead_tp_group
|
||||
from vllm_ascend.utils import lmhead_tp_enable
|
||||
|
||||
|
||||
class AscendVocabParallelEmbedding(VocabParallelEmbedding):
|
||||
"""
|
||||
Register VocabParallelEmbedding as a custom op for Ascend.
|
||||
AscendVocabParallelEmbedding support different communication parallel groups
|
||||
Added the feature of lmheadTP in pure dp scenario
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
org_num_embeddings: Optional[int] = None,
|
||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
|
||||
if lmhead_tp_enable() and prefix.find("lm_head") != -1:
|
||||
self.comm_group = get_lmhead_tp_group()
|
||||
else:
|
||||
self.comm_group = get_tp_group()
|
||||
|
||||
self.tp_size = self.comm_group.world_size
|
||||
self.tp_rank = self.comm_group.rank_in_group
|
||||
|
||||
self.num_embeddings = num_embeddings
|
||||
self.padding_size = padding_size
|
||||
self.org_vocab_size = org_num_embeddings or num_embeddings
|
||||
num_added_embeddings = num_embeddings - self.org_vocab_size
|
||||
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
|
||||
self.padding_size)
|
||||
self.num_embeddings_padded = pad_vocab_size(
|
||||
self.org_vocab_size_padded + num_added_embeddings,
|
||||
self.padding_size)
|
||||
assert self.org_vocab_size_padded <= self.num_embeddings_padded
|
||||
|
||||
self.shard_indices = self._get_indices(self.num_embeddings_padded,
|
||||
self.org_vocab_size_padded,
|
||||
self.num_embeddings,
|
||||
self.org_vocab_size,
|
||||
self.tp_rank, self.tp_size)
|
||||
self.embedding_dim = embedding_dim
|
||||
quant_method = None
|
||||
if quant_config is not None:
|
||||
quant_method = quant_config.get_quant_method(self, prefix=prefix)
|
||||
if quant_method is None:
|
||||
quant_method = UnquantizedEmbeddingMethod()
|
||||
|
||||
# If we are making an embedding layer, then our quantization linear
|
||||
# method must implement the embedding operation. If we are another
|
||||
# layer type like ParallelLMHead, this is not important.
|
||||
is_embedding_layer = type(self) is VocabParallelEmbedding
|
||||
quant_method_implements_embedding = method_has_implemented_embedding(
|
||||
type(quant_method))
|
||||
if is_embedding_layer and not quant_method_implements_embedding:
|
||||
raise NotImplementedError(
|
||||
f"The class {type(quant_method).__name__} must implement "
|
||||
"the 'embedding' method, see UnquantizedEmbeddingMethod.")
|
||||
|
||||
self.quant_method: QuantizeMethodBase = quant_method
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
# Divide the weight matrix along the vocaburaly dimension.
|
||||
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
|
||||
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
|
||||
self.tp_size)
|
||||
assert (self.shard_indices.num_elements_padded ==
|
||||
self.num_embeddings_per_partition)
|
||||
self.num_org_embeddings_per_partition = (
|
||||
self.shard_indices.org_vocab_end_index -
|
||||
self.shard_indices.org_vocab_start_index)
|
||||
self.num_added_embeddings_per_partition = (
|
||||
self.shard_indices.added_vocab_end_index -
|
||||
self.shard_indices.added_vocab_start_index)
|
||||
|
||||
self.quant_method.create_weights(self,
|
||||
self.embedding_dim,
|
||||
[self.num_embeddings_per_partition],
|
||||
self.embedding_dim,
|
||||
self.num_embeddings_padded,
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
|
||||
def _get_masked_input_and_mask(
|
||||
self, input_: torch.Tensor, org_vocab_start_index: int,
|
||||
org_vocab_end_index: int, num_org_vocab_padding: int,
|
||||
added_vocab_start_index: int,
|
||||
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# torch.compile will fuse all of the pointwise ops below
|
||||
# into a single kernel, making it very fast
|
||||
org_vocab_mask = (input_ >= org_vocab_start_index) & (
|
||||
input_ < org_vocab_end_index)
|
||||
# Adapt: avoid create added_vocab_mask when added_vocab_start_index == added_vocab_end_index.
|
||||
if added_vocab_start_index == added_vocab_end_index:
|
||||
valid_offset = (org_vocab_start_index * org_vocab_mask)
|
||||
vocab_mask = org_vocab_mask
|
||||
else:
|
||||
added_vocab_mask = (input_ >= added_vocab_start_index) & (
|
||||
input_ < added_vocab_end_index)
|
||||
added_offset = added_vocab_start_index - (
|
||||
org_vocab_end_index -
|
||||
org_vocab_start_index) - num_org_vocab_padding
|
||||
valid_offset = (org_vocab_start_index *
|
||||
org_vocab_mask) + (added_offset * added_vocab_mask)
|
||||
vocab_mask = org_vocab_mask | added_vocab_mask
|
||||
# Adapt end.
|
||||
input_ = vocab_mask * (input_ - valid_offset)
|
||||
return input_, ~vocab_mask
|
||||
|
||||
def forward(self, input_):
|
||||
if self.tp_size > 1:
|
||||
# Build the mask.
|
||||
masked_input, input_mask = self._get_masked_input_and_mask(
|
||||
input_, self.shard_indices.org_vocab_start_index,
|
||||
self.shard_indices.org_vocab_end_index,
|
||||
self.shard_indices.num_org_vocab_padding,
|
||||
self.shard_indices.added_vocab_start_index,
|
||||
self.shard_indices.added_vocab_end_index)
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = self.quant_method.embedding(self,
|
||||
masked_input.long())
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
return output
|
||||
|
||||
|
||||
class AscendParallelLMHead(ParallelLMHead):
|
||||
"""
|
||||
Register ParallelLMHead as a custom op for Ascend."""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
bias: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
org_num_embeddings: Optional[int] = None,
|
||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
AscendVocabParallelEmbedding.__init__(self, num_embeddings,
|
||||
embedding_dim, params_dtype,
|
||||
org_num_embeddings, padding_size,
|
||||
quant_config, prefix)
|
||||
|
||||
self.quant_config = quant_config
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.num_embeddings_per_partition,
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
|
||||
class AscendLogitsProcessor(LogitsProcessor):
|
||||
"""
|
||||
Register LogitsProcessor as a custom op for Ascend.
|
||||
Added the feature of lmheadTP in pure dp scenario
|
||||
"""
|
||||
|
||||
def _get_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
lm_head: AscendParallelLMHead,
|
||||
embedding_bias: Optional[torch.Tensor] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
if lmhead_tp_enable():
|
||||
return self._get_logits_lmheadtp(hidden_states, lm_head,
|
||||
embedding_bias)
|
||||
else:
|
||||
return self._get_logits_normal(hidden_states, lm_head,
|
||||
embedding_bias)
|
||||
|
||||
def _get_logits_lmheadtp(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
lm_head: AscendParallelLMHead,
|
||||
embedding_bias: Optional[torch.Tensor],
|
||||
) -> Optional[torch.Tensor]:
|
||||
# Gather hidden states from all devices in tensor parallel group
|
||||
gathered_hidden_states = get_lmhead_tp_group().all_gather(
|
||||
hidden_states, dim=0)
|
||||
local_logits = lm_head.quant_method.apply(lm_head,
|
||||
gathered_hidden_states,
|
||||
bias=embedding_bias)
|
||||
# Gather logits for tensor parallel
|
||||
logits = get_lmhead_tp_group().all_to_all(local_logits)
|
||||
# Remove paddings in vocab (if any)
|
||||
if logits is not None:
|
||||
logits = logits[..., :self.org_vocab_size]
|
||||
return logits
|
||||
|
||||
def _get_logits_normal(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
lm_head: AscendParallelLMHead,
|
||||
embedding_bias: Optional[torch.Tensor],
|
||||
) -> Optional[torch.Tensor]:
|
||||
local_logits = lm_head.quant_method.apply(lm_head,
|
||||
hidden_states,
|
||||
bias=embedding_bias)
|
||||
# Gather logits for tensor parallel
|
||||
logits = self._gather_logits(local_logits)
|
||||
|
||||
# Remove paddings in vocab (if any)
|
||||
if logits is not None:
|
||||
logits = logits[..., :self.org_vocab_size]
|
||||
|
||||
return logits
|
||||
104
vllm_ascend/patch/__init__.py
Normal file
104
vllm_ascend/patch/__init__.py
Normal file
@@ -0,0 +1,104 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ----------------------------------------------------------------------------------
|
||||
# This module manage the patch for vllm. There are two folders in this module:
|
||||
# - platform: contains the patches applied before worker starts. It's called by
|
||||
# `vllm_ascend.utils.adapt_patch(is_global_patch=True)` in
|
||||
# `vllm_ascend.platform.NPUPlatform.pre_register_and_update()` function.
|
||||
# - worker: contains the patches applied when worker starts. It's called by
|
||||
# `vllm_ascend.utils.adapt_patch(is_global_patch=False)` in
|
||||
# each worker's `__init__` function.
|
||||
#
|
||||
# Then in each kind of patch, there are three folders:
|
||||
# - patch_0_10_0: contains the patches applied when vllm version is 0.10.0.
|
||||
# - patch_main: contains the patches applied when vllm version is main branch.
|
||||
# - patch_common: contains the patches applied in both 0.10.0 and main branch.
|
||||
#
|
||||
# Once a new patch is added in vllm-ascend, please add the patch description into this file as well.
|
||||
# ----------------------------------------------------------------------------------
|
||||
|
||||
# What's Patched and how it works:
|
||||
# --------------------------------
|
||||
# * Platform Patch:
|
||||
# =================
|
||||
# ** File: platform/patch_common/patch_distributed.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.config.ParallelConfig.get_next_dp_init_port`
|
||||
# Why:
|
||||
# vllm doesn't support get port from environment.
|
||||
# How:
|
||||
# Add the logic to get port from environment.
|
||||
# Related PR (if no, explain why):
|
||||
# Need a PR to vllm to support get port from environment.
|
||||
# Future Plan:
|
||||
# Remove those patch when vllm merged them
|
||||
#
|
||||
# * Worker Patch:
|
||||
# ===============
|
||||
# ** File: worker/patch_common/patch_minicpm.py **
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.models.minicpm.MiniCPMAttention.forward`
|
||||
# Why:
|
||||
# The forward func of MiniCPMAttention in vllm do a datatype convert
|
||||
# (original datatype --> float32) to ensure the precision on cuda.
|
||||
# However float32 is not supported in cann rope op, thus we keep this patch
|
||||
# How:
|
||||
# Removed the dtype convert operations in forward
|
||||
# Related PR (if no, explain why):
|
||||
# NO, only for npu due to rope op.
|
||||
# Future Plan:
|
||||
# Keep this patch in vllm-ascend.
|
||||
#
|
||||
# ** File: worker/patch_common/patch_distributed.py **
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.distributed.parallel_state.GroupCoordinator`
|
||||
# Why:
|
||||
# vllm doesn't support all_to_all for GroupCoordinator.
|
||||
# How:
|
||||
# Add all_to_all implementation for GroupCoordinator.
|
||||
# Related PR (if no, explain why):
|
||||
# Need a PR to vllm to support all_to_all for GroupCoordinator.
|
||||
# Future Plan:
|
||||
# Remove this patch when vllm merged them.
|
||||
#
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.v1.sample.sampler.Sampler.gather_logprobs`
|
||||
# Why:
|
||||
# We need to patch gather_logprobs to make sure call batched_count_greater_than
|
||||
# with backend=current_platform.simple_compile_backend
|
||||
# How:
|
||||
# Patch gather_logprobs call new batched_count_greater_than
|
||||
# Related PR (if no, explain why):
|
||||
# - https://github.com/vllm-project/vllm/pull/21591
|
||||
# Future Plan:
|
||||
# Revert it when vLLM merge #21591 and release new version
|
||||
# ** File: worker/patch_common/patch_linear.py **
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.layers.linear.RowParallelLinear`
|
||||
# Why:
|
||||
# We need to fuse matmul and allreuce in `RowParallelLinear`
|
||||
# to improve performance.
|
||||
# How:
|
||||
# Create a new class `AscendRowParallelLinear` that inherits from `RowParallelLinear`.
|
||||
# In this class, we override the `forward` method to use
|
||||
# torch_npu.npu_mm_all_reduce_base to replace matmul and allreduce.
|
||||
# Related PR (if no, explain why):
|
||||
# - https://github.com/vllm-project/vllm-ascend/pull/1926
|
||||
# Future Plan:
|
||||
# Validate more models in all kinds of scenario,
|
||||
# if performance is always improved, we can enable this patch by default and remove the env
|
||||
# variable `VLLM_ASCEND_ENABLE_FUSE_MATMUL_ALLREDUCE` in the future.
|
||||
18
vllm_ascend/patch/platform/__init__.py
Normal file
18
vllm_ascend/patch/platform/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from vllm_ascend.patch.platform import patch_common # noqa: F401
|
||||
from vllm_ascend.patch.platform import patch_main # noqa: F401
|
||||
18
vllm_ascend/patch/platform/patch_common/__init__.py
Normal file
18
vllm_ascend/patch/platform/patch_common/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa
|
||||
115
vllm_ascend/patch/platform/patch_common/patch_distributed.py
Normal file
115
vllm_ascend/patch/platform/patch_common/patch_distributed.py
Normal file
@@ -0,0 +1,115 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Adapted from vllm/model_executor/models/qwen2_vl.py
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
import torch
|
||||
import vllm.envs as envs_vllm
|
||||
from vllm.config import ParallelConfig
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
|
||||
|
||||
def parallel_config_get_dp_port(self) -> int:
|
||||
"""
|
||||
We might need to initialize process groups in multiple
|
||||
processes that is related to data parallelism,
|
||||
e.g. both in the worker and in the engine, which
|
||||
can live in different processes. To avoid port conflicts, we
|
||||
increment the port number each time we need to initialize a
|
||||
new process group related to data parallelism.
|
||||
"""
|
||||
answer = self.data_parallel_master_port
|
||||
self.data_parallel_master_port += 1
|
||||
|
||||
# NOTE: Get port from envs directly when using torchrun
|
||||
port = envs_vllm.VLLM_DP_MASTER_PORT if envs_vllm.VLLM_DP_MASTER_PORT else answer
|
||||
return port
|
||||
|
||||
|
||||
ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
|
||||
|
||||
|
||||
class NullHandle:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def wait(self):
|
||||
pass
|
||||
|
||||
|
||||
def communication_adaptation_310p():
|
||||
|
||||
def broadcast310p_wrapper(fn):
|
||||
|
||||
def broadcast310p(tensor, src, group=None, async_op=False):
|
||||
if tensor.device == torch.device('cpu'):
|
||||
return fn(tensor, src, group, async_op)
|
||||
rank = torch.distributed.get_rank(group)
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
|
||||
tensor_list[rank] = tensor
|
||||
torch.distributed.all_gather(tensor_list, tensor, group=group)
|
||||
tensor[...] = tensor_list[src]
|
||||
if async_op:
|
||||
return NullHandle()
|
||||
else:
|
||||
return None
|
||||
|
||||
return broadcast310p
|
||||
|
||||
torch.distributed.broadcast = broadcast310p_wrapper(
|
||||
torch.distributed.broadcast)
|
||||
torch.distributed.distributed_c10d.broadcast = broadcast310p_wrapper(
|
||||
torch.distributed.distributed_c10d.broadcast)
|
||||
|
||||
def all_reduce_wrapper_310p(fn):
|
||||
|
||||
def all_reduce(
|
||||
tensor,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=None,
|
||||
async_op=False,
|
||||
):
|
||||
if tensor.dtype != torch.int64:
|
||||
return fn(tensor, op, group, async_op)
|
||||
rank = torch.distributed.get_rank(group)
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
|
||||
tensor_list[rank] = tensor
|
||||
torch.distributed.all_gather(tensor_list, tensor, group=group)
|
||||
if op == torch.distributed.ReduceOp.SUM:
|
||||
return torch.stack(tensor_list).sum(0)
|
||||
elif op == torch.distributed.ReduceOp.MAX:
|
||||
return torch.tensor(
|
||||
torch.stack(tensor_list).cpu().numpy().max(0),
|
||||
device=tensor.device,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"not implement op {op}")
|
||||
|
||||
return all_reduce
|
||||
|
||||
torch.distributed.all_reduce = all_reduce_wrapper_310p(
|
||||
torch.distributed.all_reduce)
|
||||
torch.distributed.distributed_c10d.all_reduce = all_reduce_wrapper_310p(
|
||||
torch.distributed.distributed_c10d.all_reduce)
|
||||
|
||||
|
||||
if is_310p():
|
||||
communication_adaptation_310p()
|
||||
16
vllm_ascend/patch/platform/patch_main/__init__.py
Normal file
16
vllm_ascend/patch/platform/patch_main/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
19
vllm_ascend/patch/worker/__init__.py
Normal file
19
vllm_ascend/patch/worker/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from vllm_ascend.patch.worker import patch_common # noqa: F401
|
||||
from vllm_ascend.patch.worker import patch_main # noqa: F401
|
||||
22
vllm_ascend/patch/worker/patch_common/__init__.py
Normal file
22
vllm_ascend/patch/worker/patch_common/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_linear # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_logits # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_lora_embedding # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
|
||||
49
vllm_ascend/patch/worker/patch_common/patch_distributed.py
Normal file
49
vllm_ascend/patch/worker/patch_common/patch_distributed.py
Normal file
@@ -0,0 +1,49 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import vllm
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
|
||||
|
||||
class GroupCoordinatorPatch(GroupCoordinator):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def all_to_all(self,
|
||||
input_: torch.Tensor,
|
||||
scatter_dim: int = 0,
|
||||
gather_dim: int = -1,
|
||||
scatter_sizes: Optional[List[int]] = None,
|
||||
gather_sizes: Optional[List[int]] = None) -> torch.Tensor:
|
||||
if self.world_size == 1:
|
||||
return input_
|
||||
assert -input_.dim() <= scatter_dim < input_.dim(), (
|
||||
f"Invalid scatter dim ({scatter_dim}) for input tensor with shape {input_.size()}"
|
||||
)
|
||||
assert -input_.dim() <= gather_dim < input_.dim(), (
|
||||
f"Invalid gather dim ({gather_dim}) for input tensor with shape {input_.size()}"
|
||||
)
|
||||
return self.device_communicator.all_to_all(input_, scatter_dim,
|
||||
gather_dim, scatter_sizes,
|
||||
gather_sizes)
|
||||
|
||||
|
||||
vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch # Note: check the GroupCoordinator with online serving
|
||||
147
vllm_ascend/patch/worker/patch_common/patch_linear.py
Normal file
147
vllm_ascend/patch/worker/patch_common/patch_linear.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
This file is a part of the vllm-ascend project.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
import vllm
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
split_tensor_along_last_dim)
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
|
||||
_HCOMM_INFO = None
|
||||
|
||||
|
||||
class AscendRowParallelLinear(RowParallelLinear):
|
||||
"""
|
||||
AscendRowParallelLinear is a custom implementation of RowParallelLinear
|
||||
that overrides the forward method to handle Ascend-specific operations.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize the AscendRowParallelLinear layer.
|
||||
|
||||
Args:
|
||||
*args: Variable length argument list.
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
"""
|
||||
tp_group = get_tp_group().device_group
|
||||
hcomm_info = self.get_hcomm_info(tp_group)
|
||||
self.hcomm_info = hcomm_info
|
||||
super().__init__(*args, **kwargs)
|
||||
self.weight_t = self.weight.t()
|
||||
|
||||
@staticmethod
|
||||
def get_hcomm_info(group: ProcessGroup) -> str:
|
||||
"""Get the HCCL communication information for the given group.
|
||||
|
||||
Args:
|
||||
group (ProcessGroup): The process group for which to get the HCCL communication info.
|
||||
|
||||
Returns:
|
||||
str: The HCCL communication name for the given group.
|
||||
"""
|
||||
global _HCOMM_INFO
|
||||
if _HCOMM_INFO is not None:
|
||||
return _HCOMM_INFO
|
||||
|
||||
rank = torch.distributed.get_rank(group)
|
||||
if torch.__version__ > "2.0":
|
||||
global_rank = torch.distributed.get_global_rank(group, rank)
|
||||
_HCOMM_INFO = group._get_backend(
|
||||
torch.device("npu")).get_hccl_comm_name(global_rank)
|
||||
|
||||
else:
|
||||
_HCOMM_INFO = group.get_hccl_comm_name(rank)
|
||||
return _HCOMM_INFO
|
||||
|
||||
def forward(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
"""Forward pass for the AscendRowParallelLinear layer.
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor): the input tensor to the layer.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
The output tensor after applying the linear transformation,
|
||||
and optionally the bias if `return_bias` is True.
|
||||
"""
|
||||
input_parallel = self.calc_input(input_)
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
output = self.calc_output(input_parallel)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
def calc_input(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculate the input tensor for parallel processing.
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor): the input tensor to be processed.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The input tensor split along the last dimension
|
||||
for tensor model parallelism, or the original input if not parallel.
|
||||
"""
|
||||
if self.input_is_parallel:
|
||||
return input_
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
return splitted_input[tp_rank].contiguous()
|
||||
|
||||
def calc_output(self, input_parallel: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculate the output tensor of forward by considering
|
||||
fusing communication and computation.
|
||||
|
||||
Args:
|
||||
input_parallel (_type_): the input tensor to be processed in parallel.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the output tensor after applying the linear transformation
|
||||
and optionally handle communication between tensor model parallel ranks.
|
||||
"""
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output = torch_npu.npu_mm_all_reduce_base(input_parallel,
|
||||
self.weight_t,
|
||||
self.hcomm_info,
|
||||
bias=bias_)
|
||||
else:
|
||||
output = self.quant_method.apply(self, input_parallel, bias=bias_)
|
||||
return output
|
||||
|
||||
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE:
|
||||
logger.info("AscendRowParallelLinear: Matmul all-reduce is enabled. ")
|
||||
vllm.model_executor.layers.linear.RowParallelLinear = AscendRowParallelLinear
|
||||
26
vllm_ascend/patch/worker/patch_common/patch_logits.py
Normal file
26
vllm_ascend/patch/worker/patch_common/patch_logits.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import torch
|
||||
import vllm
|
||||
from vllm._custom_ops import apply_repetition_penalties_torch
|
||||
|
||||
|
||||
def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
|
||||
output_mask: torch.Tensor,
|
||||
repetition_penalties: torch.Tensor) -> None:
|
||||
"""Apply repetition penalties to logits in-place.
|
||||
|
||||
Args:
|
||||
logits: The logits tensor of shape [num_seqs, vocab_size].
|
||||
prompt_mask: A boolean tensor indicating which tokens appear in the prompt.
|
||||
output_mask: A boolean tensor indicating which tokens appear in the output.
|
||||
repetition_penalties: The repetition penalties of shape (num_seqs, ).
|
||||
"""
|
||||
apply_repetition_penalties_torch(logits, prompt_mask, output_mask,
|
||||
repetition_penalties)
|
||||
|
||||
|
||||
# NPU device type tensors have attributes is_cuda=True and is_npu=True, according to its implementation in
|
||||
# https://github.com/Ascend/pytorch/blob/863b9071cbdf47023c12c246e3efa9c6e2285fc6/torch_npu/npu/_stream_check.py#L74
|
||||
# This causes that vLLM's apply_repetition_penalties function will run into the branch of "if logits.is_cuda" and
|
||||
# call the custom op implemented in CUDA, which is not compatible with NPU.
|
||||
# Reference: https://github.com/vllm-project/vllm/blob/f66673a39d9f364194c249f28098cad8a5584ccb/vllm/_custom_ops.py#L314
|
||||
vllm._custom_ops.apply_repetition_penalties = apply_repetition_penalties
|
||||
@@ -0,0 +1,29 @@
|
||||
from typing import Optional
|
||||
|
||||
import vllm
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.lora.layers import VocabParallelEmbeddingWithLoRA
|
||||
from vllm.lora.utils import _all_lora_classes
|
||||
|
||||
from vllm_ascend.ops.vocab_parallel_embedding import \
|
||||
AscendVocabParallelEmbedding
|
||||
|
||||
|
||||
class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendVocabParallelEmbedding
|
||||
|
||||
|
||||
# Patch for lora register_model issue after overriding VocabParallelEmbedding class (#2515)
|
||||
_all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes = _all_lora_classes
|
||||
36
vllm_ascend/patch/worker/patch_common/patch_minicpm.py
Normal file
36
vllm_ascend/patch/worker/patch_common/patch_minicpm.py
Normal file
@@ -0,0 +1,36 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.models.minicpm import MiniCPMAttention
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
# The type conversion in the forward function is deleted to support the rope operator.
|
||||
MiniCPMAttention.forward = forward
|
||||
16
vllm_ascend/patch/worker/patch_main/__init__.py
Normal file
16
vllm_ascend/patch/worker/patch_main/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
345
vllm_ascend/platform.py
Normal file
345
vllm_ascend/platform.py
Normal file
@@ -0,0 +1,345 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import gc
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import vllm.envs as envs_vllm
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import PrefixStore
|
||||
from vllm.logger import logger
|
||||
from vllm.platforms import Platform, PlatformEnum
|
||||
|
||||
from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
|
||||
init_ascend_config)
|
||||
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
|
||||
delete_torchair_cache_file)
|
||||
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, is_310p,
|
||||
update_aclgraph_sizes)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
else:
|
||||
ModelConfig = None
|
||||
VllmConfig = None
|
||||
FlexibleArgumentParser = None
|
||||
|
||||
|
||||
class NPUPlatform(Platform):
|
||||
|
||||
_enum = PlatformEnum.OOT
|
||||
device_name: str = "npu"
|
||||
device_type: str = "npu"
|
||||
simple_compile_backend: str = "eager" # Disable torch.compile()
|
||||
ray_device_key: str = "NPU"
|
||||
device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
|
||||
dispatch_key: str = "PrivateUse1"
|
||||
|
||||
supported_quantization: list[str] = [ASCEND_QUANTIZATION_METHOD]
|
||||
|
||||
def is_sleep_mode_available(self) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def pre_register_and_update(cls,
|
||||
parser: Optional[FlexibleArgumentParser] = None
|
||||
) -> None:
|
||||
# Adapt the global patch here.
|
||||
from vllm_ascend.utils import adapt_patch
|
||||
adapt_patch(is_global_patch=True)
|
||||
|
||||
# For online serving, "ascend" quantization method is not a choice natively,
|
||||
# so we need to add "ascend" quantization method to quantization methods list
|
||||
# and the user can enable quantization using "vllm serve --quantization ascend".
|
||||
if parser is not None:
|
||||
quant_action = parser._option_string_actions.get('--quantization')
|
||||
if quant_action and hasattr(quant_action,
|
||||
'choices') and quant_action.choices:
|
||||
if ASCEND_QUANTIZATION_METHOD not in quant_action.choices:
|
||||
quant_action.choices.append(ASCEND_QUANTIZATION_METHOD)
|
||||
|
||||
from vllm_ascend.quantization.quant_config import \
|
||||
AscendQuantConfig # noqa: F401
|
||||
|
||||
@classmethod
|
||||
def get_device_capability(cls, device_id: int = 0):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
return torch.npu.get_device_name(device_id)
|
||||
|
||||
@classmethod
|
||||
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def inference_mode(cls):
|
||||
return torch.inference_mode()
|
||||
|
||||
@classmethod
|
||||
def set_device(cls, device: torch.device):
|
||||
torch.npu.set_device(device)
|
||||
|
||||
@classmethod
|
||||
def empty_cache(cls):
|
||||
torch.npu.empty_cache()
|
||||
|
||||
@classmethod
|
||||
def synchronize(cls):
|
||||
torch.npu.synchronize()
|
||||
|
||||
@classmethod
|
||||
def mem_get_info(cls) -> Tuple[int, int]:
|
||||
return torch.npu.mem_get_info()
|
||||
|
||||
@classmethod
|
||||
def clear_npu_memory(cls):
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
torch.npu.reset_peak_memory_stats()
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
if not envs_vllm.VLLM_USE_V1:
|
||||
raise ValueError("vLLM Ascend does not support V0 engine.")
|
||||
# initialize ascend config from vllm additional_config
|
||||
ascend_config = init_ascend_config(vllm_config)
|
||||
|
||||
from vllm.config import CompilationLevel # noqa: E402
|
||||
compilation_config = vllm_config.compilation_config
|
||||
model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
cache_config = vllm_config.cache_config
|
||||
kv_cache_dtype = vllm_config.additional_config.get(
|
||||
"kv_cache_dtype", None)
|
||||
if kv_cache_dtype is not None:
|
||||
vllm_config.cache_config.cache_dtype = kv_cache_dtype
|
||||
|
||||
if model_config is None:
|
||||
logger.warning("Model config is missing. This may indicate "
|
||||
"that we are running a test case")
|
||||
enforce_eager = False
|
||||
else:
|
||||
enforce_eager = getattr(model_config, "enforce_eager", False)
|
||||
|
||||
check_ascend_config(vllm_config, enforce_eager)
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
if enforce_eager:
|
||||
logger.info("Compilation disabled, using eager mode by default")
|
||||
compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
|
||||
compilation_config.cudagraph_num_of_warmups = 1
|
||||
|
||||
# TODO: make vllm support oot platform to set `compilation_config.cudagraph_mode`
|
||||
# if cudagraph_mode is not explicitly set by users, set default value
|
||||
if compilation_config.level == CompilationLevel.PIECEWISE:
|
||||
compilation_config.cudagraph_mode = \
|
||||
CUDAGraphMode.PIECEWISE
|
||||
elif compilation_config.level not in [
|
||||
CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE
|
||||
]:
|
||||
logger.warning(
|
||||
"NPU does not support %s compilation level. Setting CUDAGraphMode to NONE",
|
||||
compilation_config.level)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
else:
|
||||
logger.warning(
|
||||
"compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE"
|
||||
)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
# set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is.
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
logger.info(
|
||||
"Torchair compilation enabled on NPU. Setting CUDAGraphMode to NONE"
|
||||
)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
# Note: We delete the torchair cache folder here to prevent runtime issues caused by dimension
|
||||
# mismatches or configuration inconsistencies when users reuse cached computation graphs. Though
|
||||
# this will increase graph compilation duration, it significantly enhances robustness and decreases
|
||||
# graph launching time during inference.
|
||||
if check_torchair_cache_exist(
|
||||
) and not ascend_config.torchair_graph_config.use_cached_kv_cache_bytes:
|
||||
logger.warning(
|
||||
"Torchair cache folder is deleted here to prevent runtime issues caused by dimension "
|
||||
"mismatches or configuration inconsistencies when users reuse cached computation graphs. "
|
||||
"In order to decrease torchair graph compilation time, users can enable both use_cached_graph "
|
||||
"and use_cached_kv_cache_bytes in torchair_graph_config.")
|
||||
delete_torchair_cache_file()
|
||||
|
||||
if parallel_config.distributed_executor_backend == "ray":
|
||||
logger.warning(
|
||||
"Ray distributed executor backend is not compatible with ACL Graph mode "
|
||||
"right now. Setting CUDAGraphMode to NONE")
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
# set cudaprah sizes before extending `compilation_config.splitting_ops`
|
||||
vllm_config._set_cudagraph_sizes()
|
||||
|
||||
if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
|
||||
compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
|
||||
logger.info(
|
||||
"PIECEWISE compilation enabled on NPU. use_inductor not supported - "
|
||||
"using only ACL Graph mode")
|
||||
assert compilation_config.level == CompilationLevel.PIECEWISE, \
|
||||
"When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE"
|
||||
compilation_config.set_splitting_ops_for_v1()
|
||||
compilation_config.use_inductor = False
|
||||
compilation_config.splitting_ops.extend(
|
||||
["vllm.unified_ascend_attention_with_output"])
|
||||
update_aclgraph_sizes(vllm_config)
|
||||
else:
|
||||
logger.info(
|
||||
"%s cudagraph_mode is not support on NPU. falling back to NONE",
|
||||
compilation_config.cudagraph_mode)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
|
||||
if parallel_config and parallel_config.worker_cls == "auto":
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
|
||||
|
||||
if cache_config:
|
||||
if cache_config.block_size is None:
|
||||
cache_config.block_size = 128
|
||||
if cache_config.enable_prefix_caching and cache_config.block_size != 128:
|
||||
logger.warning(
|
||||
"If prefix caching is enabled, block size must be set to 128."
|
||||
)
|
||||
cache_config.block_size = 128
|
||||
|
||||
# Activate custom ops for v1, except on 310P
|
||||
if not is_310p():
|
||||
compilation_config.custom_ops = ["all"]
|
||||
|
||||
# If ascend_scheduler_config is enabled,
|
||||
# extents original scheduler_config to use AscendScheduler.
|
||||
if ascend_config.ascend_scheduler_config.enabled:
|
||||
from vllm_ascend.core.schedule_config import AscendSchedulerConfig
|
||||
ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config(
|
||||
vllm_config.scheduler_config,
|
||||
ascend_config.ascend_scheduler_config)
|
||||
vllm_config.scheduler_config = ascend_scheduler_config
|
||||
|
||||
if compilation_config.pass_config.enable_sequence_parallelism:
|
||||
if not parallel_config.enable_expert_parallel or vllm_config.model_config.hf_config.model_type != "qwen3_moe":
|
||||
raise NotImplementedError(
|
||||
"For better performance in Qwen3 MoE, SP only works exclusively with MC2, AllToAll, and AllToAllV."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls,
|
||||
selected_backend,
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_v1,
|
||||
use_mla,
|
||||
has_sink=False):
|
||||
if not use_v1:
|
||||
raise ValueError("vLLM Ascend does not support V0 engine.")
|
||||
|
||||
use_torchair = get_ascend_config().torchair_graph_config.enabled
|
||||
# choose attention backend based on use_mla and use_torchair
|
||||
backend_map = {
|
||||
(True, True):
|
||||
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend",
|
||||
(True, False):
|
||||
"vllm_ascend.attention.mla_v1.AscendMLABackend",
|
||||
(False, True):
|
||||
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend",
|
||||
(False, False):
|
||||
"vllm_ascend.attention.attention_v1.AscendAttentionBackend"
|
||||
}
|
||||
return backend_map[(use_mla, use_torchair)]
|
||||
|
||||
@classmethod
|
||||
def get_punica_wrapper(cls) -> str:
|
||||
return "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU"
|
||||
|
||||
@classmethod
|
||||
def get_current_memory_usage(cls,
|
||||
device: Optional[torch.types.Device] = None
|
||||
) -> float:
|
||||
torch.npu.reset_peak_memory_stats(device)
|
||||
return torch.npu.max_memory_allocated(device)
|
||||
|
||||
@classmethod
|
||||
def get_device_communicator_cls(cls) -> str:
|
||||
return "vllm_ascend.distributed.communicator.NPUCommunicator"
|
||||
|
||||
@classmethod
|
||||
def is_pin_memory_available(cls):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_v1(cls, model_config: ModelConfig) -> bool:
|
||||
"""Returns whether the current platform can support v1 for the supplied
|
||||
model configuration.
|
||||
"""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_static_graph_wrapper_cls(cls) -> str:
|
||||
"""
|
||||
Get piecewise backend class for piecewise graph.
|
||||
"""
|
||||
return "vllm_ascend.compilation.acl_graph.ACLGraphWrapper" # noqa
|
||||
|
||||
@classmethod
|
||||
def stateless_init_device_torch_dist_pg(
|
||||
cls,
|
||||
backend: str,
|
||||
prefix_store: PrefixStore,
|
||||
group_rank: int,
|
||||
group_size: int,
|
||||
timeout: timedelta,
|
||||
) -> ProcessGroup:
|
||||
from torch.distributed import is_hccl_available
|
||||
from torch_npu._C._distributed_c10d import ProcessGroupHCCL
|
||||
|
||||
assert is_hccl_available()
|
||||
|
||||
pg: ProcessGroup = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
)
|
||||
|
||||
backend_options = ProcessGroupHCCL.Options()
|
||||
backend_options._timeout = timeout
|
||||
|
||||
backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
|
||||
backend_options)
|
||||
device = torch.device("npu")
|
||||
# TODO(Yizhou): Like we mentioned above, _set_default_backend is not
|
||||
# implemented in the 2.5.1 version of PyTorch. But we need to set it
|
||||
# after the latest version is released.
|
||||
# pg._set_default_backend(backend_type)
|
||||
backend_class._set_sequence_number_for_group()
|
||||
backend_type = ProcessGroup.BackendType.CUSTOM
|
||||
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
return pg
|
||||
0
vllm_ascend/quantization/__init__.py
Normal file
0
vllm_ascend/quantization/__init__.py
Normal file
184
vllm_ascend/quantization/func_wrapper.py
Normal file
184
vllm_ascend/quantization/func_wrapper.py
Normal file
@@ -0,0 +1,184 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, QuantizationConfig)
|
||||
|
||||
|
||||
# func refers to vocabParallelEmbedding.__init__
|
||||
def wrapper_vocab_parallel_embedding_init(func):
|
||||
|
||||
def init(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
org_num_embeddings: Optional[int] = None,
|
||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
func(
|
||||
self,
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
params_dtype,
|
||||
org_num_embeddings,
|
||||
padding_size,
|
||||
quant_config,
|
||||
prefix,
|
||||
)
|
||||
# TODO: Contact vLLM maintainers to add a `params_dtype` attribute to the `VocabParallelEmbedding` class.
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
return init
|
||||
|
||||
|
||||
# func refers to RMSNorm.__init__
|
||||
def wrapper_rmsnorm_init(func):
|
||||
|
||||
def init(self, hidden_size: int, **extra_args) -> None:
|
||||
func(self, hidden_size, **extra_args)
|
||||
self.ignore_anti = True
|
||||
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
|
||||
requires_grad=False)
|
||||
|
||||
return init
|
||||
|
||||
|
||||
# func refers to RMSNorm.forward_oot
|
||||
def wrapper_rmsnorm_forward_oot(func):
|
||||
|
||||
def _rmsnorm_forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if not self.ignore_anti:
|
||||
if residual is not None:
|
||||
residual += x
|
||||
out = torch_npu._npu_quant_rms_norm(
|
||||
residual,
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.input_scale,
|
||||
self.input_offset,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return out, residual
|
||||
out = torch_npu._npu_quant_rms_norm(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.input_scale,
|
||||
self.input_offset,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return out
|
||||
|
||||
if residual is not None:
|
||||
x, residual = func(self, x, residual)
|
||||
return x.add_(self.bias), residual
|
||||
|
||||
return func(self, x).add_(self.bias)
|
||||
|
||||
return _rmsnorm_forward_oot
|
||||
|
||||
|
||||
MODEL_LAYER_MAPPING = {
|
||||
"LlamaModel": {
|
||||
"attn": {
|
||||
"layer_attr": "self_attn",
|
||||
"proj_attr": "qkv_proj",
|
||||
"norm_attr": "input_layernorm",
|
||||
"unquantized_type": UnquantizedLinearMethod,
|
||||
},
|
||||
"mlp": {
|
||||
"layer_attr": "mlp",
|
||||
"proj_attr": "gate_up_proj",
|
||||
"norm_attr": "post_attention_layernorm",
|
||||
"unquantized_type": UnquantizedLinearMethod,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def wrapper_load_model(func):
|
||||
|
||||
def postprocess_loading(self) -> None:
|
||||
func(self)
|
||||
|
||||
def process_layer(layer, idx, mapping):
|
||||
|
||||
def process_module(module_cfg, layer_obj):
|
||||
if module_cfg is None:
|
||||
return
|
||||
|
||||
module_obj = getattr(layer_obj, module_cfg["layer_attr"], None)
|
||||
if module_obj is None:
|
||||
return
|
||||
|
||||
proj_attr = module_cfg["proj_attr"]
|
||||
if callable(proj_attr):
|
||||
proj = proj_attr(module_obj, idx)
|
||||
else:
|
||||
proj = getattr(module_obj, proj_attr, None)
|
||||
|
||||
norm = getattr(layer_obj, module_cfg["norm_attr"], None)
|
||||
|
||||
if proj is None or norm is None:
|
||||
return
|
||||
|
||||
norm.ignore_anti = isinstance(proj.quant_method,
|
||||
module_cfg["unquantized_type"])
|
||||
if not norm.ignore_anti:
|
||||
for param_name in ["input_scale", "input_offset"]:
|
||||
if hasattr(proj, param_name):
|
||||
param = getattr(proj, param_name)
|
||||
norm.register_parameter(
|
||||
param_name,
|
||||
torch.nn.Parameter(param.clone(),
|
||||
requires_grad=False))
|
||||
|
||||
process_module(mapping.get("attn"), layer)
|
||||
process_module(mapping.get("mlp"), layer)
|
||||
|
||||
model_type = self.model.model.__class__.__name__
|
||||
mapping = MODEL_LAYER_MAPPING.get(model_type)
|
||||
|
||||
if not mapping:
|
||||
logger.info(
|
||||
f"Warning: Model type '{model_type}' not found in MODEL_LAYER_MAPPING. Skipping layer mapping."
|
||||
)
|
||||
return
|
||||
|
||||
for idx, layer in enumerate(self.model.model.layers):
|
||||
process_layer(layer, idx, mapping)
|
||||
|
||||
if isinstance(self.model.model.norm, RMSNorm):
|
||||
self.model.model.norm.ignore_anti = True
|
||||
|
||||
return postprocess_loading
|
||||
357
vllm_ascend/quantization/quant_config.py
Normal file
357
vllm_ascend/quantization/quant_config.py
Normal file
@@ -0,0 +1,357 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional
|
||||
|
||||
import torch
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import \
|
||||
register_quantization_config
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
|
||||
from vllm.model_executor.parameter import PerTensorScaleParameter
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
|
||||
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
|
||||
|
||||
from .quantizer import AscendQuantizer
|
||||
|
||||
|
||||
@register_quantization_config(ASCEND_QUANTIZATION_METHOD)
|
||||
class AscendQuantConfig(QuantizationConfig):
|
||||
"""Config class for Ascend
|
||||
|
||||
This class is a general class that parse quantization configs
|
||||
that are supported on ascend hardware.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Dict[str, Any]):
|
||||
self.quant_description = quant_config
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "AscendQuantConfig:\n" + super().__repr__()
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return ASCEND_QUANTIZATION_METHOD
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.int8, torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError(
|
||||
"Ascend hardware dose not support \"get_min_capability\" feature.")
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["quant_model_description.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig":
|
||||
return cls(config)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
if torch.npu.is_available():
|
||||
return ASCEND_QUANTIZATION_METHOD
|
||||
return None
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention
|
||||
if isinstance(layer, LinearBase):
|
||||
if self.is_layer_skipped_ascend(prefix,
|
||||
self.packed_modules_mapping):
|
||||
return UnquantizedLinearMethod()
|
||||
return AscendLinearMethod(self, prefix,
|
||||
self.packed_modules_mapping)
|
||||
elif isinstance(layer, Attention) and \
|
||||
'fa_quant_type' in self.quant_description.keys() and \
|
||||
self.quant_description['fa_quant_type'] is not None:
|
||||
return AscendKVCacheMethod(self, prefix)
|
||||
elif isinstance(layer, Attention) and self.quant_description.get(
|
||||
'kv_quant_type') == 'C8':
|
||||
return AscendKVCacheMethod(self, prefix)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
if self.is_layer_skipped_ascend(prefix,
|
||||
self.packed_modules_mapping):
|
||||
return AscendUnquantizedFusedMoEMethod(layer.moe_config)
|
||||
return AscendFusedMoEMethod(self, prefix,
|
||||
self.packed_modules_mapping)
|
||||
elif isinstance(layer, VocabParallelEmbedding):
|
||||
if self.is_layer_skipped_ascend(prefix,
|
||||
self.packed_modules_mapping):
|
||||
return UnquantizedEmbeddingMethod()
|
||||
return AscendEmbeddingMethod(self, prefix,
|
||||
self.packed_modules_mapping)
|
||||
return None
|
||||
|
||||
def is_layer_skipped_ascend(
|
||||
self,
|
||||
prefix: str,
|
||||
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})):
|
||||
# adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
|
||||
proj_name = prefix.split(".")[-1]
|
||||
if proj_name in fused_mapping:
|
||||
shard_prefixes = [
|
||||
prefix.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in fused_mapping[proj_name]
|
||||
]
|
||||
|
||||
is_skipped = None
|
||||
for shard_prefix in shard_prefixes:
|
||||
is_shard_skipped = self.quant_description[shard_prefix +
|
||||
'.weight'] == "FLOAT"
|
||||
|
||||
if is_skipped is None:
|
||||
is_skipped = is_shard_skipped
|
||||
elif is_shard_skipped != is_skipped:
|
||||
raise ValueError(
|
||||
f"Detected some but not all shards of {prefix} "
|
||||
"are quantized. All shards of fused layers "
|
||||
"to have the same precision.")
|
||||
else:
|
||||
is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT"
|
||||
|
||||
assert is_skipped is not None
|
||||
return is_skipped
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
class AscendLinearMethod(LinearMethodBase):
|
||||
"""Linear method for Ascend quantization.
|
||||
|
||||
This class calls AscendQuantizer to search a specific quantization
|
||||
implementations supported on ascend hardware for linear methods.
|
||||
|
||||
Args:
|
||||
quant_config: The Ascend quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
|
||||
packed_modules_mapping: Dict[str, Any]) -> None:
|
||||
self.quantizer = AscendQuantizer.get_quantizer(
|
||||
quant_config.quant_description, prefix, packed_modules_mapping)
|
||||
self.quant_method = self.quantizer.build_linear_method()
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
weight_dict = self.quant_method.get_weight(input_size_per_partition,
|
||||
output_size_per_partition,
|
||||
params_dtype)
|
||||
for weight_name, weight_param in weight_dict.items():
|
||||
param = torch.nn.Parameter(weight_param, requires_grad=False)
|
||||
set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
|
||||
layer.register_parameter(weight_name, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
|
||||
for pertensor_name, pertensor_param in pertensor_dict.items():
|
||||
param = PerTensorScaleParameter(data=pertensor_param,
|
||||
weight_loader=weight_loader)
|
||||
# disable warning
|
||||
param.ignore_warning = True
|
||||
layer.register_parameter(pertensor_name, param)
|
||||
|
||||
perchannel_dict = self.quant_method.get_perchannel_param(
|
||||
output_size_per_partition, params_dtype)
|
||||
for perchannel_name, perchannel_param in perchannel_dict.items():
|
||||
param = torch.nn.Parameter(perchannel_param, requires_grad=False)
|
||||
set_weight_attrs(param, {"output_dim": 0})
|
||||
layer.register_parameter(perchannel_name, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
pergroup_dict = self.quant_method.get_pergroup_param(
|
||||
input_size_per_partition, output_size_per_partition, params_dtype)
|
||||
for pergroup_name, pergroup_param in pergroup_dict.items():
|
||||
param = torch.nn.Parameter(pergroup_param, requires_grad=False)
|
||||
set_weight_attrs(param, {"output_dim": 0})
|
||||
layer.register_parameter(pergroup_name, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name:
|
||||
setattr(param, "input_dim", 1)
|
||||
param.input_dim = 1
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if isinstance(layer, RowParallelLinear):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
return self.quant_method.apply(layer, x, bias, tp_rank)
|
||||
return self.quant_method.apply(layer, x, bias)
|
||||
|
||||
|
||||
class AscendKVCacheMethod(BaseKVCacheMethod):
|
||||
"""KVCache method for Ascend quantization.
|
||||
|
||||
This class calls AscendQuantizer to search a specific quantization
|
||||
implementations supported on ascend hardware for kvcache methods.
|
||||
|
||||
Args:
|
||||
quant_config: The Ascend quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
|
||||
self.quantizer = AscendQuantizer.get_quantizer(
|
||||
quant_config.quant_description, prefix)
|
||||
self.quant_method = self.quantizer.build_attention_method()
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module) -> None:
|
||||
# Different from linear method, there are no weight processing/slicing
|
||||
# steps for attention in vllm. So the whole process of create weights
|
||||
# is hidden into the specific quant method.
|
||||
self.quant_method.create_weights(layer)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
|
||||
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
|
||||
attn_type, scale, output) -> torch.Tensor:
|
||||
return self.quant_method.apply(layer, query, key, value, kv_cache,
|
||||
attn_metadata, attn_type, scale, output)
|
||||
|
||||
|
||||
class AscendFusedMoEMethod(FusedMoEMethodBase):
|
||||
"""FusedMoE method for Ascend quantization.
|
||||
|
||||
This class calls AscendQuantizer to search a specific quantization
|
||||
implementations supported on ascend hardware for kvcache methods.
|
||||
|
||||
Args:
|
||||
quant_config: The Ascend quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
|
||||
packed_modules_mapping: Dict[str, Any]):
|
||||
self.quantizer = AscendQuantizer.get_quantizer(
|
||||
quant_config.quant_description, prefix, packed_modules_mapping)
|
||||
self.quant_method = self.quantizer.build_moe_method()
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
weight_param = self.quant_method.get_weight(
|
||||
num_experts, intermediate_size_per_partition, hidden_size,
|
||||
params_dtype)
|
||||
for param_key, param_value in weight_param.items():
|
||||
param = torch.nn.Parameter(param_value, requires_grad=False)
|
||||
layer.register_parameter(param_key, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
|
||||
per_group_param = [
|
||||
"weight_scale_second", "weight_offset_second", "scale_bias"
|
||||
]
|
||||
dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
|
||||
num_experts, intermediate_size_per_partition, hidden_size,
|
||||
params_dtype)
|
||||
for param_key, param_value in dynamic_quant_param.items():
|
||||
param = torch.nn.Parameter(param_value, requires_grad=False)
|
||||
layer.register_parameter(param_key, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
if any(fields in param_key for fields in per_group_param):
|
||||
setattr(param, "quant_method",
|
||||
FusedMoeWeightScaleSupported.GROUP.value)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = False,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num=0,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.quant_method.apply(
|
||||
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
|
||||
global_num_experts, expert_map, topk_group, num_expert_group,
|
||||
custom_routing_function, scoring_func, e_score_correction_bias,
|
||||
is_prefill, enable_force_load_balance, log2phy,
|
||||
global_redundant_expert_num, **kwargs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
|
||||
|
||||
class AscendEmbeddingMethod(AscendLinearMethod):
|
||||
"""Embedding method for Ascend quantization.
|
||||
This class calls AscendQuantizer to search a specific quantization
|
||||
implementations supported on ascend hardware for Embedding methods.
|
||||
Args:
|
||||
quant_config: The Ascend quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
|
||||
packed_modules_mapping: Dict[str, Any]) -> None:
|
||||
self.quantizer = AscendQuantizer.get_quantizer(
|
||||
quant_config.quant_description, prefix, packed_modules_mapping)
|
||||
self.quant_method = self.quantizer.build_linear_method()
|
||||
311
vllm_ascend/quantization/quantizer.py
Normal file
311
vllm_ascend/quantization/quantizer.py
Normal file
@@ -0,0 +1,311 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from vllm.logger import logger
|
||||
|
||||
from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init,
|
||||
wrapper_vocab_parallel_embedding_init)
|
||||
from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod,
|
||||
AscendW4A8DynamicLinearMethod)
|
||||
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
|
||||
AscendW8A8LinearMethod)
|
||||
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
|
||||
AscendW8A8DynamicLinearMethod)
|
||||
|
||||
CUSTOMIZED_QUANTIZER_TYPE: List[str] = []
|
||||
|
||||
|
||||
class AscendQuantizer:
|
||||
"""An interface to different quantization implementations for ascend hardwares."""
|
||||
|
||||
@classmethod
|
||||
def get_quantizer(cls,
|
||||
quant_config: Dict[str, Any],
|
||||
prefix: str,
|
||||
packed_modules_mapping: Optional[Dict[str,
|
||||
Any]] = dict()):
|
||||
# TODO: Need a param to choose quantization algorithms.
|
||||
quantization_algorithm = ''
|
||||
|
||||
if quantization_algorithm in CUSTOMIZED_QUANTIZER_TYPE:
|
||||
return
|
||||
|
||||
return VLLMAscendQuantizer.get_quantizer(quant_config, prefix,
|
||||
packed_modules_mapping)
|
||||
|
||||
def build_linear_method(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def build_moe_method(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def build_attention_method(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class VLLMAscendQuantizer:
|
||||
_instance: Optional[object] = None
|
||||
patched = False
|
||||
|
||||
def __init__(self, quant_description):
|
||||
if VLLMAscendQuantizer.patched:
|
||||
return
|
||||
for name in quant_description.keys():
|
||||
if "norm.bias" in name:
|
||||
VLLMAscendQuantizer.apply_patch(
|
||||
"vllm.model_executor.layers.layernorm.RMSNorm", "__init__",
|
||||
[wrapper_rmsnorm_init])
|
||||
VLLMAscendQuantizer.apply_patch(
|
||||
"vllm_ascend.ops.layernorm.AscendRMSNorm", "forward_oot",
|
||||
[wrapper_rmsnorm_forward_oot])
|
||||
VLLMAscendQuantizer.apply_patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding",
|
||||
"__init__", [wrapper_vocab_parallel_embedding_init])
|
||||
break
|
||||
VLLMAscendQuantizer.patched = True
|
||||
logger.info("Using the vLLM Ascend Quantizer version now!")
|
||||
|
||||
@staticmethod
|
||||
def apply_patch(target_module, target_function, wrappers):
|
||||
|
||||
original_module, original_function = VLLMAscendQuantizer.parse_path(
|
||||
target_module, target_function, False)
|
||||
|
||||
original_function_id = id(original_function)
|
||||
|
||||
candidate = original_function
|
||||
for wrapper in wrappers:
|
||||
candidate = wrapper(candidate)
|
||||
if target_function is not None:
|
||||
setattr(original_module, target_function, candidate)
|
||||
|
||||
for _, value in sys.modules.copy().items():
|
||||
if target_function is None:
|
||||
continue
|
||||
try:
|
||||
attr = getattr(value, target_function, None)
|
||||
if attr is not None and id(attr) == original_function_id:
|
||||
setattr(value, target_function, candidate)
|
||||
except ImportError:
|
||||
continue
|
||||
|
||||
@staticmethod
|
||||
def parse_path(module_path, function_name, create_dummy):
|
||||
"""
|
||||
Parse module path and resolve/create modules as needed.
|
||||
|
||||
Args:
|
||||
module_path: Dot-separated module path
|
||||
function_name: Target function name (None for module only)
|
||||
create_dummy: Create dummy modules/functions when missing
|
||||
|
||||
Returns:
|
||||
Tuple of (resolved module, target function/none)
|
||||
|
||||
Raises:
|
||||
ModuleNotFoundError: If module path is invalid and create_dummy=False
|
||||
AttributeError: If function is missing and create_dummy=False
|
||||
"""
|
||||
from importlib.machinery import ModuleSpec
|
||||
|
||||
def create_dummy_module(full_path, parent=None):
|
||||
"""Create and register a placeholder module"""
|
||||
dummy = types.ModuleType(full_path)
|
||||
dummy.__file__ = "vllm_ascend.dummy_module.py"
|
||||
dummy.__spec__ = ModuleSpec(full_path, None)
|
||||
sys.modules[full_path] = dummy
|
||||
if parent:
|
||||
setattr(parent, full_path.split(".")[-1], dummy)
|
||||
return dummy
|
||||
|
||||
def create_placeholder_function(func_name):
|
||||
"""Create dummy function that raises when called"""
|
||||
|
||||
def placeholder(*args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
f"Function {func_name} is a placeholder")
|
||||
|
||||
placeholder.__name__ = func_name
|
||||
return placeholder
|
||||
|
||||
modules = module_path.split(".")
|
||||
current_module = None
|
||||
processed_path = []
|
||||
|
||||
for idx, part in enumerate(modules):
|
||||
current_path = ".".join(modules[:idx + 1])
|
||||
parent_path = ".".join(modules[:idx]) if idx > 0 else None
|
||||
|
||||
try:
|
||||
current_module = importlib.import_module(current_path)
|
||||
except ModuleNotFoundError:
|
||||
# Handle missing module
|
||||
parent = importlib.import_module(
|
||||
parent_path) if parent_path else None
|
||||
if parent and hasattr(parent, part):
|
||||
# Use existing attribute from parent
|
||||
current_module = getattr(parent, part)
|
||||
# Check for early function resolution
|
||||
if function_name and hasattr(current_module,
|
||||
function_name):
|
||||
return current_module, getattr(current_module,
|
||||
function_name)
|
||||
if function_name and create_dummy:
|
||||
ph_func = create_placeholder_function(function_name)
|
||||
setattr(current_module, function_name, ph_func)
|
||||
return current_module, ph_func
|
||||
if function_name:
|
||||
raise AttributeError(
|
||||
f"Function {function_name} missing in {current_path}"
|
||||
)
|
||||
else:
|
||||
if not create_dummy:
|
||||
raise
|
||||
# Create and register dummy module
|
||||
current_module = create_dummy_module(
|
||||
current_path,
|
||||
parent=importlib.import_module(parent_path)
|
||||
if parent_path else None)
|
||||
|
||||
processed_path.append(part)
|
||||
|
||||
# Final function handling
|
||||
final_module = sys.modules[module_path]
|
||||
if function_name is not None:
|
||||
if not hasattr(final_module, function_name):
|
||||
if create_dummy:
|
||||
ph_func = create_placeholder_function(function_name)
|
||||
setattr(final_module, function_name, ph_func)
|
||||
else:
|
||||
setattr(final_module, function_name, None)
|
||||
return final_module, getattr(final_module, function_name)
|
||||
|
||||
return final_module, None
|
||||
|
||||
@staticmethod
|
||||
def build_linear_method():
|
||||
raise NotImplementedError(
|
||||
"Linear method is not implemented for the current quant type.")
|
||||
|
||||
@staticmethod
|
||||
def build_moe_method():
|
||||
raise NotImplementedError(
|
||||
"MoE method is not implemented for the current quant type.")
|
||||
|
||||
@staticmethod
|
||||
def build_attention_method():
|
||||
raise NotImplementedError(
|
||||
"Attention method is not implemented for the current quant type.")
|
||||
|
||||
@staticmethod
|
||||
def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str,
|
||||
packed_modules_mapping: Dict[str, Any]):
|
||||
proj_name = prefix.split(".")[-1]
|
||||
if proj_name in packed_modules_mapping:
|
||||
quant_type = None
|
||||
shard_prefixes = [
|
||||
prefix.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in packed_modules_mapping[proj_name]
|
||||
]
|
||||
for shard_prefix in shard_prefixes:
|
||||
shard_quant_type = quant_description[shard_prefix + '.weight']
|
||||
|
||||
if quant_type is None:
|
||||
quant_type = shard_quant_type
|
||||
elif shard_quant_type != quant_type:
|
||||
raise ValueError(
|
||||
f"Not all shards of {prefix} are quantized with same quant type."
|
||||
f"Shard {proj_name} uses {shard_quant_type}, but another shard"
|
||||
f"use {quant_type}. Please check quantization config.")
|
||||
else:
|
||||
quant_type = quant_description[prefix + '.weight']
|
||||
return quant_type
|
||||
|
||||
@classmethod
|
||||
def get_quantizer(cls,
|
||||
quant_description: Dict[str, Any],
|
||||
prefix: str,
|
||||
packed_modules_mapping: Optional[Dict[str, Any]] = None):
|
||||
if packed_modules_mapping is None:
|
||||
packed_modules_mapping = dict()
|
||||
# Attention
|
||||
if '.attn' in prefix and 'fa_quant_type' in quant_description.keys():
|
||||
quant_type = quant_description['fa_quant_type']
|
||||
# Use KVCache int8
|
||||
elif '.attn' in prefix and 'kv_quant_type' in quant_description.keys():
|
||||
quant_type = quant_description['kv_quant_type']
|
||||
# Linear
|
||||
else:
|
||||
quant_type = cls.get_linear_quant_type(quant_description, prefix,
|
||||
packed_modules_mapping)
|
||||
if quant_type in SUPPORT_ASCEND_QUANTIZER_TYPE.keys():
|
||||
cls = SUPPORT_ASCEND_QUANTIZER_TYPE[quant_type]
|
||||
if not cls._instance:
|
||||
cls._instance = cls(quant_description)
|
||||
return cls._instance
|
||||
raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \
|
||||
f"{list(SUPPORT_ASCEND_QUANTIZER_TYPE.keys())}")
|
||||
|
||||
|
||||
class W4A8DYNAMICQuantizer(VLLMAscendQuantizer):
|
||||
|
||||
@staticmethod
|
||||
def build_linear_method():
|
||||
return AscendW4A8DynamicLinearMethod()
|
||||
|
||||
@staticmethod
|
||||
def build_moe_method():
|
||||
return AscendW4A8DynamicFusedMoEMethod()
|
||||
|
||||
|
||||
class W8A8Quantizer(VLLMAscendQuantizer):
|
||||
|
||||
@staticmethod
|
||||
def build_linear_method():
|
||||
return AscendW8A8LinearMethod()
|
||||
|
||||
@staticmethod
|
||||
def build_moe_method():
|
||||
return AscendW8A8FusedMoEMethod()
|
||||
|
||||
@staticmethod
|
||||
def build_attention_method():
|
||||
return AscendC8KVCacheMethod()
|
||||
|
||||
|
||||
class W8A8DYNAMICQuantizer(VLLMAscendQuantizer):
|
||||
|
||||
@staticmethod
|
||||
def build_linear_method():
|
||||
return AscendW8A8DynamicLinearMethod()
|
||||
|
||||
@staticmethod
|
||||
def build_moe_method():
|
||||
return AscendW8A8DynamicFusedMoEMethod()
|
||||
|
||||
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE = {
|
||||
"W4A8_DYNAMIC": W4A8DYNAMICQuantizer,
|
||||
"W8A8": W8A8Quantizer,
|
||||
"W8A8_DYNAMIC": W8A8DYNAMICQuantizer,
|
||||
"C8": W8A8Quantizer,
|
||||
}
|
||||
394
vllm_ascend/quantization/w4a8_dynamic.py
Normal file
394
vllm_ascend/quantization/w4a8_dynamic.py
Normal file
@@ -0,0 +1,394 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.fused_moe import unified_fused_experts_eager
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
|
||||
|
||||
class AscendW4A8DynamicLinearMethod:
|
||||
"""Linear method for Ascend W4A8_DYNAMIC
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.transpose_weight = True
|
||||
try:
|
||||
self.group_size = get_current_vllm_config(
|
||||
).quant_config.quant_description.get("group_size", 256)
|
||||
except AttributeError:
|
||||
self.group_size = 256
|
||||
|
||||
@staticmethod
|
||||
def get_weight(input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {
|
||||
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
|
||||
}
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_perchannel_param(output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def get_pergroup_param(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {}
|
||||
params_dict["weight_scale"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_offset"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_scale_second"] = torch.empty(output_size,
|
||||
input_size //
|
||||
self.group_size,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_offset_second"] = torch.empty(output_size,
|
||||
input_size //
|
||||
self.group_size,
|
||||
dtype=params_dtype)
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def process_scale_second(weight: torch.Tensor, scale: torch.Tensor,
|
||||
per_group_scale: torch.Tensor):
|
||||
k, n = weight.shape
|
||||
group_num, n = per_group_scale.shape
|
||||
weight_high = weight.to(torch.float32).reshape(
|
||||
group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
|
||||
weight_high = weight_high.reshape(k, n)
|
||||
bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
|
||||
antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
|
||||
return antiquant_scale.npu(), bias
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tp_rank: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch_npu.npu_weight_quant_batchmatmul(
|
||||
x,
|
||||
layer.weight,
|
||||
antiquant_scale=layer.weight_scale_second.to(x.dtype),
|
||||
antiquant_group_size=self.group_size,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
if self.transpose_weight:
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
layer.weight_scale.data = layer.weight_scale.data.flatten().to(
|
||||
torch.float32)
|
||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||
layer.weight_scale_second.data, scale_bias = self.process_scale_second(
|
||||
layer.weight.data,
|
||||
layer.weight_scale.data,
|
||||
layer.weight_scale_second.data.transpose(0, 1).contiguous(),
|
||||
)
|
||||
param = torch.nn.Parameter(scale_bias, requires_grad=False)
|
||||
layer.register_parameter("weight_scale_bias", param)
|
||||
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
|
||||
layer.weight.data.to(torch.int32))
|
||||
|
||||
|
||||
class AscendW4A8DynamicFusedMoEMethod:
|
||||
"""FusedMoe method for Ascend W4A8_DYNAMIC.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.transpose_weight = True
|
||||
|
||||
self.ep_group = get_ep_group()
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.group_size = vllm_config.quant_config.quant_description.get(
|
||||
"group_size", 256)
|
||||
quant_version = vllm_config.quant_config.quant_description.get(
|
||||
"version", "0")
|
||||
# NOTE: new quantize weights: 2 int4 pack into int8
|
||||
self.new_quant_version = quant_version == "1.0.0"
|
||||
self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size
|
||||
if self.new_quant_version and self.tp_size > 16:
|
||||
raise ValueError(
|
||||
"The current weight does not support moe part tp>16.")
|
||||
|
||||
try:
|
||||
device_group = get_mc2_group().device_group
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=device_group)
|
||||
backend = device_group._get_backend(torch.device("npu"))
|
||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
||||
local_rank)
|
||||
except AttributeError:
|
||||
self.moe_all_to_all_group_name = ""
|
||||
|
||||
def get_weight(self, num_experts: int,
|
||||
intermediate_size_per_partition: int, hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = {}
|
||||
if self.new_quant_version:
|
||||
w13_output_size = intermediate_size_per_partition
|
||||
w2_output_size = hidden_sizes // 2
|
||||
else:
|
||||
w13_output_size = 2 * intermediate_size_per_partition
|
||||
w2_output_size = hidden_sizes
|
||||
|
||||
param_dict["w13_weight"] = torch.empty(num_experts,
|
||||
w13_output_size,
|
||||
hidden_sizes,
|
||||
dtype=torch.int8)
|
||||
param_dict["w2_weight"] = torch.empty(num_experts,
|
||||
w2_output_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int8)
|
||||
return param_dict
|
||||
|
||||
def get_dynamic_quant_param(self, num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = {}
|
||||
param_dict["w13_weight_scale"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
|
||||
param_dict["w13_weight_offset"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
|
||||
param_dict["w13_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=params_dtype)
|
||||
|
||||
param_dict["w13_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=params_dtype)
|
||||
|
||||
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
param_dict["w2_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=params_dtype)
|
||||
param_dict["w2_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=params_dtype)
|
||||
|
||||
if self.new_quant_version:
|
||||
param_dict["w13_scale_bias"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_scale_bias"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
16 // self.tp_size,
|
||||
dtype=torch.float32)
|
||||
|
||||
return param_dict
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = True,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
topk_weights, topk_ids, row_idx = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
fused_moe_state = get_forward_context().fused_moe_state
|
||||
shared_gate_up, shared_dequant_scale = None, None
|
||||
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
|
||||
share_up_out, _ = shared_experts.gate_up_proj(
|
||||
(quantized_x_for_share, dynamic_scale_for_share))
|
||||
shared_gate_up, shared_dequant_scale = share_up_out[
|
||||
0], share_up_out[1]
|
||||
|
||||
# this is a naive implementation for experts load balance so as
|
||||
# to avoid accumulating too much tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
if enable_force_load_balance:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
|
||||
return unified_fused_experts_eager(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale_second,
|
||||
w2_scale=layer.w2_weight_scale_second,
|
||||
w1_scale_bias=layer.w13_scale_bias,
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
shared_experts=shared_experts,
|
||||
shared_gate_up=shared_gate_up,
|
||||
shared_dequant_scale=shared_dequant_scale,
|
||||
mc2_mask=kwargs.get("mc2_mask", None),
|
||||
with_quant=True)
|
||||
|
||||
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
|
||||
group_num, k, n = weight.shape
|
||||
# the weight of the new version is reduced by half by pack n, so it needs to be restored
|
||||
if self.new_quant_version:
|
||||
n = n * 2
|
||||
per_group_scale = per_group_scale.reshape(group_num, -1, n)
|
||||
group_num, quantgroup_num, n = per_group_scale.shape
|
||||
bias = None
|
||||
if not self.new_quant_version:
|
||||
weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
|
||||
per_group_scale.reshape([group_num, quantgroup_num, 1, n])
|
||||
weight_high = weight_high.reshape([group_num, k, n])
|
||||
bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1)
|
||||
scale_fp32 = (scale * per_group_scale).to(torch.float16).to(
|
||||
torch.float32)
|
||||
scale_fp32_np = scale_fp32.cpu().numpy()
|
||||
scale_fp32_np.dtype = np.uint32
|
||||
sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2),
|
||||
dtype=np.uint32)
|
||||
|
||||
sscale_uint64[..., ::2] = scale_fp32_np
|
||||
|
||||
sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(),
|
||||
dtype=np.int64).copy()
|
||||
sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(
|
||||
group_num, quantgroup_num, n)
|
||||
sscale_uint64_tensor = sscale_uint64_tensor.npu()
|
||||
return sscale_uint64_tensor, bias
|
||||
|
||||
def update_bias(self, layer, w13_bias, w2_bias):
|
||||
if self.new_quant_version:
|
||||
layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose(
|
||||
1, 2).contiguous().sum(axis=1)
|
||||
layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(
|
||||
1, 2).contiguous().sum(axis=1)
|
||||
else:
|
||||
w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
|
||||
layer.register_parameter("w13_scale_bias", w13_scale_bias)
|
||||
w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
|
||||
layer.register_parameter("w2_scale_bias", w2_scale_bias)
|
||||
|
||||
def pack_to_int32(self, weight: torch.Tensor):
|
||||
if self.new_quant_version:
|
||||
group_num, k, n = weight.shape
|
||||
assert n % 4 == 0, "the last dim of weight needs to be divided by 4"
|
||||
packed_n = n // 4
|
||||
# pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4
|
||||
packed_weight = torch.from_numpy(
|
||||
np.frombuffer(weight.cpu().numpy().tobytes(), dtype=np.int32))
|
||||
return packed_weight.reshape(group_num, k, packed_n).npu()
|
||||
else:
|
||||
return torch_npu.npu_quantize(weight.to(torch.float32),
|
||||
torch.tensor([1.]).npu(), None,
|
||||
torch.quint4x2, -1, False)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.transpose_weight:
|
||||
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose(
|
||||
1, 2).contiguous()
|
||||
|
||||
layer.w13_weight_scale_second.data, w13_bias = self.process_scale(
|
||||
layer.w13_weight, layer.w13_weight_scale.data,
|
||||
layer.w13_weight_scale_second.data)
|
||||
layer.w2_weight_scale_second.data, w2_bias = self.process_scale(
|
||||
layer.w2_weight, layer.w2_weight_scale.data,
|
||||
layer.w2_weight_scale_second.data)
|
||||
|
||||
self.update_bias(layer, w13_bias, w2_bias)
|
||||
|
||||
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
|
||||
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
|
||||
647
vllm_ascend/quantization/w8a8.py
Normal file
647
vllm_ascend/quantization/w8a8.py
Normal file
@@ -0,0 +1,647 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
|
||||
|
||||
|
||||
def quant_per_tensor(in_tensor: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
input_offset: torch.Tensor,
|
||||
function=False):
|
||||
return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
|
||||
torch.qint8, -1, function)
|
||||
|
||||
|
||||
class AscendW8A8LinearMethod:
|
||||
"""Linear method for Ascend W8A8.
|
||||
|
||||
Args:
|
||||
w_sym: whether the linear weight is symmetrically quantized.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# aclnn quant matmul requires to transpose matrix B, set to true by default.
|
||||
self.transpose_weight = not is_310p()
|
||||
|
||||
@staticmethod
|
||||
def get_weight(
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype = torch.bfloat16,
|
||||
) -> Dict[str, Any]:
|
||||
params_dict = {
|
||||
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
|
||||
}
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {}
|
||||
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
|
||||
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_perchannel_param(
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
params_dict = {}
|
||||
params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
|
||||
if params_dtype == torch.bfloat16:
|
||||
params_dict["deq_scale"] = torch.empty(output_size,
|
||||
dtype=torch.float32)
|
||||
elif params_dtype == torch.float16:
|
||||
params_dict["deq_scale"] = torch.empty(output_size,
|
||||
dtype=torch.int64)
|
||||
params_dict["weight_scale"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_offset"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
return params_dict
|
||||
|
||||
def get_pergroup_param(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def apply(
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tp_rank: Optional[int] = 0,
|
||||
) -> torch.Tensor:
|
||||
if x.dtype != torch.int8:
|
||||
x = quant_per_tensor(
|
||||
x,
|
||||
layer.aclnn_input_scale_reciprocal,
|
||||
layer.aclnn_input_offset,
|
||||
)
|
||||
quant_bias = layer.quant_bias if tp_rank == 0 else None
|
||||
if is_310p():
|
||||
# On 300I Duo platform, we need transpose again if
|
||||
# using nz. This transpose can be skipped in torchair.
|
||||
output = torch_npu.npu_quant_matmul(
|
||||
x,
|
||||
layer.weight.data.transpose(1, 0),
|
||||
layer.deq_scale,
|
||||
bias=quant_bias,
|
||||
output_dtype=layer.params_dtype,
|
||||
)
|
||||
else:
|
||||
output = torch_npu.npu_quant_matmul(
|
||||
x,
|
||||
layer.weight,
|
||||
layer.deq_scale,
|
||||
bias=quant_bias,
|
||||
output_dtype=layer.params_dtype,
|
||||
)
|
||||
return output
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
expanding_factor = layer.weight.data.shape[1]
|
||||
layer.aclnn_input_scale = torch.nn.Parameter(
|
||||
layer.input_scale.data.repeat(expanding_factor),
|
||||
requires_grad=False)
|
||||
layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
|
||||
layer.input_scale.data.repeat(expanding_factor),
|
||||
requires_grad=False)
|
||||
layer.aclnn_input_offset = torch.nn.Parameter(
|
||||
layer.input_offset.data.repeat(expanding_factor),
|
||||
requires_grad=False).to(layer.aclnn_input_scale.dtype)
|
||||
if self.transpose_weight:
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
||||
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
||||
|
||||
|
||||
class AscendW8A8FusedMoEMethod:
|
||||
"""FusedMoe method for Ascend W8A8.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.transpose_weight = True
|
||||
|
||||
@staticmethod
|
||||
def get_weight(num_experts: int, intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = {}
|
||||
param_dict["w13_weight"] = torch.empty(num_experts,
|
||||
2 *
|
||||
intermediate_size_per_partition,
|
||||
hidden_sizes,
|
||||
dtype=torch.int8,
|
||||
requires_grad=False)
|
||||
param_dict["w2_weight"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int8,
|
||||
requires_grad=False)
|
||||
return param_dict
|
||||
|
||||
@staticmethod
|
||||
def get_dynamic_quant_param(num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = {}
|
||||
param_dict["w13_weight_scale"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
param_dict["w13_weight_offset"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float16)
|
||||
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=torch.float16)
|
||||
param_dict["w2_deq_scale"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
dtype=torch.float32)
|
||||
param_dict["w13_deq_scale"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_input_scale"] = torch.empty(num_experts,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
param_dict["w13_input_scale"] = torch.empty(num_experts,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_input_offset"] = torch.empty(num_experts,
|
||||
1,
|
||||
dtype=torch.int8)
|
||||
param_dict["w13_input_offset"] = torch.empty(num_experts,
|
||||
1,
|
||||
dtype=torch.int8)
|
||||
param_dict["quant_bias"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
dtype=torch.int32)
|
||||
|
||||
return param_dict
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = False,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[Any] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
if is_310p():
|
||||
return fused_experts_310p(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w1_input_scale=layer.w13_input_scale,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w2_input_scale=layer.w2_input_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w1_input_scale=layer.w13_input_scale,
|
||||
w1_input_offset=layer.w13_input_offset,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w2_input_scale=layer.w2_input_scale,
|
||||
w2_input_offset=layer.w2_input_offset,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if not is_310p():
|
||||
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
||||
layer.w13_weight_scale.data.shape[0], -1)
|
||||
|
||||
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
|
||||
layer.w13_weight_offset.data.shape[0], -1)
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
|
||||
layer.w2_weight_scale.data.shape[0], -1)
|
||||
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
||||
layer.w2_weight_offset.data.shape[0], -1)
|
||||
expanding_factor_w13 = layer.w13_weight.data.shape[1]
|
||||
expanding_factor_w2 = layer.w2_weight.data.shape[1]
|
||||
|
||||
if is_310p():
|
||||
layer.w13_input_scale.data = torch.nn.Parameter(
|
||||
layer.w13_input_scale.data.max())
|
||||
layer.w2_input_scale.data = torch.nn.Parameter(
|
||||
layer.w2_input_scale.data.max())
|
||||
else:
|
||||
layer.w13_input_scale.data = torch.nn.Parameter(
|
||||
layer.w13_input_scale.data.repeat(1,
|
||||
expanding_factor_w13)[0:1])
|
||||
layer.w2_input_scale.data = torch.nn.Parameter(
|
||||
layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
|
||||
|
||||
layer.w13_input_offset.data = torch.nn.Parameter(
|
||||
layer.w13_input_scale.data.repeat(1, expanding_factor_w13)[0:1])
|
||||
layer.w2_input_offset.data = torch.nn.Parameter(
|
||||
layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
|
||||
|
||||
# converting ACL_FORMAT_FRACTAL_NZ.
|
||||
# npu_quant_grouped_matmul_dequant in eager mode does not accept
|
||||
# ACL_FORMAT_FRACTAL_NZ.
|
||||
if not is_310p():
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
|
||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
|
||||
|
||||
|
||||
class AscendC8KVCacheMethod:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.antiquant_scale_comb = None
|
||||
|
||||
@staticmethod
|
||||
def create_weights(layer) -> None:
|
||||
param_dict = {} # num_kv_heads * head_size
|
||||
param_dict["key_antiquant_scale"] = torch.empty(layer.num_kv_heads *
|
||||
layer.head_size,
|
||||
dtype=torch.float16,
|
||||
requires_grad=False)
|
||||
param_dict["value_antiquant_scale"] = torch.empty(layer.num_kv_heads *
|
||||
layer.head_size,
|
||||
dtype=torch.float16,
|
||||
requires_grad=False)
|
||||
for weight_name, weight_param in param_dict.items():
|
||||
param = torch.nn.Parameter(weight_param, requires_grad=False)
|
||||
layer.register_parameter(weight_name, param)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
self.antiquant_scale_comb = torch.cat(
|
||||
(layer.key_antiquant_scale.data.unsqueeze(0),
|
||||
layer.value_antiquant_scale.data.unsqueeze(0)),
|
||||
dim=0).to(torch.float16).contiguous()
|
||||
|
||||
def apply(self, layer, query, key, value, kv_cache, attn_metadata,
|
||||
attn_type, scale, output) -> torch.Tensor:
|
||||
num_tokens = query.shape[0]
|
||||
if attn_metadata is None:
|
||||
return output.view(num_tokens, layer.num_heads * layer.head_size)
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl")
|
||||
|
||||
# C8
|
||||
quant_key = quant_per_tensor(
|
||||
key.view(-1, layer.num_kv_heads * layer.head_size),
|
||||
layer.key_antiquant_scale.data.view(-1), None, True)
|
||||
quant_value = quant_per_tensor(
|
||||
value.view(-1, layer.num_kv_heads * layer.head_size),
|
||||
layer.value_antiquant_scale.data.view(-1), None, True)
|
||||
|
||||
# View q k v to BSH.
|
||||
query = query.view(-1, layer.num_heads, layer.head_size)
|
||||
key = key.view(-1, layer.num_kv_heads, layer.head_size)
|
||||
value = value.view(-1, layer.num_kv_heads, layer.head_size)
|
||||
# TODO: Remove this contiguous in the future.
|
||||
value = value.contiguous()
|
||||
|
||||
if kv_cache[0].numel() > 0:
|
||||
# if key_cache is None:
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
slots = attn_metadata.slot_mapping
|
||||
|
||||
block_size = key_cache.shape[1]
|
||||
slots_indices = slots.reshape(-1, 1)
|
||||
block_indices = slots_indices // block_size
|
||||
slots_indices = slots_indices % block_size
|
||||
indices = torch.cat((block_indices, slots_indices), dim=1)
|
||||
|
||||
# C8
|
||||
torch_npu.npu_scatter_nd_update_(key_cache, indices, quant_key)
|
||||
torch_npu.npu_scatter_nd_update_(value_cache, indices, quant_value)
|
||||
|
||||
# V0-Style scheduler situation.
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
mask = attn_metadata.attn_mask
|
||||
torch_npu._npu_flash_attention(query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seq_len=attn_metadata.seq_lens,
|
||||
scale_value=scale,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
out=output.reshape(query.shape))
|
||||
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
raise NotImplementedError("kv cache int8 are not "
|
||||
"implemented for "
|
||||
"PrefillCacheHit")
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: # changed attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
||||
if hasattr(attn_metadata, "decode"):
|
||||
# torch_air
|
||||
decode_meta = attn_metadata.decode
|
||||
seq_lens = decode_meta.seq_lens_list
|
||||
else:
|
||||
seq_lens = attn_metadata.seq_lens
|
||||
block_size = key_cache.shape[1]
|
||||
query = query.view(num_tokens, 1, layer.num_heads *
|
||||
layer.head_size).contiguous() # changed
|
||||
|
||||
# [num_blocks, block_size, N, D] --> [num_blocks, N, block_size, D]
|
||||
key = key_cache
|
||||
value = value_cache
|
||||
|
||||
output = torch_npu.npu_incre_flash_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
num_key_value_heads=layer.num_kv_heads,
|
||||
num_heads=layer.num_heads,
|
||||
actual_seq_lengths=seq_lens,
|
||||
scale_value=scale,
|
||||
input_layout='BSH',
|
||||
block_size=block_size,
|
||||
block_table=attn_metadata.block_tables,
|
||||
antiquant_scale=self.antiquant_scale_comb,
|
||||
)
|
||||
|
||||
# Normal V1 situation.
|
||||
else:
|
||||
raise NotImplementedError("kv cache int8 are not "
|
||||
"implemented for "
|
||||
"other case")
|
||||
return output
|
||||
|
||||
|
||||
def fused_experts_310p(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w1_input_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w2_input_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
ep_size = get_ep_group().world_size
|
||||
local_num_experts = global_num_experts // ep_size
|
||||
local_num_group = top_k // ep_size
|
||||
|
||||
bsz, _ = hidden_states.shape
|
||||
flatten_topk_ids = topk_ids.view(-1)
|
||||
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
||||
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
|
||||
sorted_hidden_states = hidden_states.index_select(
|
||||
0, sorted_topk_ids // local_num_group)
|
||||
|
||||
experts_id = torch.arange(0,
|
||||
local_num_experts,
|
||||
dtype=topk_ids.dtype,
|
||||
device=topk_ids.device)
|
||||
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
||||
torch.float32).sum(0)
|
||||
topk_scales = topk_weights.view(-1).index_select(
|
||||
0, sorted_topk_ids).unsqueeze(-1)
|
||||
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
||||
|
||||
gate_up_out = torch_npu.npu_quant_grouped_matmul_dequant(
|
||||
x=sorted_hidden_states,
|
||||
quantized_weight=w1,
|
||||
weight_scale=w1_scale,
|
||||
group_list=group_list,
|
||||
x_scale=w1_input_scale,
|
||||
quant_mode="pertensor")
|
||||
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
||||
torch.float16)
|
||||
gate_up_out *= topk_scales
|
||||
|
||||
down_out = torch_npu.npu_quant_grouped_matmul_dequant(
|
||||
x=gate_up_out,
|
||||
quantized_weight=w2,
|
||||
weight_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
x_scale=w2_input_scale,
|
||||
quant_mode="pertensor")
|
||||
|
||||
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
|
||||
unsorted_hidden_states = down_out.index_select(0, unsorted_topk_ids)
|
||||
final_hidden_states = unsorted_hidden_states.reshape(
|
||||
bsz, top_k // ep_size, -1).sum(1)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w1_input_scale: torch.Tensor,
|
||||
w1_input_offset: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w2_input_scale: torch.Tensor,
|
||||
w2_input_offset: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Fused experts with top-k routing.
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
||||
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
||||
topk_weights: Routing weights of shape (num_tokens, top_k).
|
||||
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
||||
top_k: Number of experts to select.
|
||||
expert_map: Expert mapping of shape (num_experts,).
|
||||
|
||||
Returns:
|
||||
hidden_states: Hidden states after routing.
|
||||
"""
|
||||
"""
|
||||
# Check constraints.
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
||||
"""
|
||||
|
||||
original_dtype = hidden_states.dtype
|
||||
ep_size = get_ep_group().world_size
|
||||
local_num_experts = global_num_experts // ep_size
|
||||
w1_input_scale, _ = w1_input_scale.max(0)
|
||||
quant_sorted_hidden_states = quant_per_tensor(
|
||||
hidden_states,
|
||||
w1_input_scale,
|
||||
None,
|
||||
True,
|
||||
)
|
||||
if expert_map is not None:
|
||||
expanded_x, expanded_row_idx, expert_token_count, expanded_scale = torch_npu.npu_moe_init_routing_v2(
|
||||
quant_sorted_hidden_states,
|
||||
topk_ids,
|
||||
scale=None,
|
||||
active_num=topk_ids.numel(),
|
||||
expert_capacity=-1,
|
||||
expert_num=local_num_experts,
|
||||
drop_pad_mode=0,
|
||||
expert_tokens_num_type=1,
|
||||
expert_tokens_num_flag=True,
|
||||
quant_mode=-1,
|
||||
active_expert_range=[0, local_num_experts],
|
||||
row_idx_type=0,
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"The quantified version of MOE class models "
|
||||
"currently does not support tensor parallelism")
|
||||
if expanded_x.dtype != w1.dtype:
|
||||
w1_input_scale, _ = w1_input_scale.max(0)
|
||||
quant_sorted_hidden_states = quant_per_tensor(
|
||||
expanded_x,
|
||||
w1_input_scale,
|
||||
None,
|
||||
True,
|
||||
)
|
||||
else:
|
||||
quant_sorted_hidden_states = expanded_x
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||
x=[quant_sorted_hidden_states],
|
||||
weight=[w1],
|
||||
scale=[w1_scale * w1_input_scale[0]],
|
||||
split_item=2,
|
||||
group_list_type=1,
|
||||
group_type=0,
|
||||
group_list=expert_token_count,
|
||||
output_dtype=original_dtype,
|
||||
)[0]
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
|
||||
if gate_up_out.dtype != w2.dtype:
|
||||
w2_input_scale, _ = w2_input_scale.max(0)
|
||||
quant_gate_up_out = quant_per_tensor(
|
||||
gate_up_out,
|
||||
w2_input_scale,
|
||||
None,
|
||||
True,
|
||||
)
|
||||
else:
|
||||
quant_gate_up_out = gate_up_out
|
||||
|
||||
down_out = torch_npu.npu_grouped_matmul(
|
||||
x=[quant_gate_up_out],
|
||||
weight=[w2],
|
||||
scale=[w2_scale * w2_input_scale[0]],
|
||||
split_item=2,
|
||||
group_list_type=1,
|
||||
group_type=0,
|
||||
group_list=expert_token_count,
|
||||
output_dtype=original_dtype,
|
||||
)[0]
|
||||
|
||||
if expert_map is not None:
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
down_out,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=topk_weights.to(down_out.dtype),
|
||||
expanded_src_to_dst_row=expanded_row_idx,
|
||||
export_for_source_row=topk_ids,
|
||||
drop_pad_mode=2,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"The quantified version of MOE class models "
|
||||
"currently does not support tensor parallelism")
|
||||
|
||||
return final_hidden_states
|
||||
453
vllm_ascend/quantization/w8a8_dynamic.py
Normal file
453
vllm_ascend/quantization/w8a8_dynamic.py
Normal file
@@ -0,0 +1,453 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import CompilationLevel, get_current_vllm_config
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.common_fused_moe import \
|
||||
fused_experts as unified_fused_experts
|
||||
from vllm_ascend.ops.fused_moe import unified_fused_experts_eager
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, dispose_tensor
|
||||
|
||||
|
||||
def apply_mlp_decode(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1) -> torch.Tensor:
|
||||
"""
|
||||
apply MLP: gate_up_proj -> swiglu -> down_proj
|
||||
Args:
|
||||
hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
|
||||
w1: expert weights1 with shape
|
||||
(num_experts, hidden_size, intermediate_size * 2)
|
||||
w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
|
||||
w2: expert weights2 with shape
|
||||
(num_experts, intermediate_size, hidden_size)
|
||||
w2_scale: weights2 scale with shape (num_experts, hidden_size)
|
||||
group_list: number of tokens for each expert, follow cumsum mode, and
|
||||
with shape (num_experts).
|
||||
transpose_weight:
|
||||
w1: (num_experts, intermediate_size * 2, hidden_size) ->
|
||||
(num_experts, hidden_size, intermediate_size * 2)
|
||||
w2: (num_experts, hidden_size, intermediate_size) ->
|
||||
(num_experts, intermediate_size, hidden_size)
|
||||
Returns:
|
||||
hidden_states: output hidden states after MLP.
|
||||
"""
|
||||
|
||||
if dynamic_scale is None:
|
||||
unquantized_hidden_states = hidden_states
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
# Dispose the original unquantized hidden states
|
||||
# to save npu memory because they're no longer used.
|
||||
dispose_tensor(unquantized_hidden_states)
|
||||
else:
|
||||
pertoken_scale = dynamic_scale
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=3,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=w1_scale,
|
||||
activation_scale=pertoken_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=group_list,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=w2_scale.dtype)[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
|
||||
"""
|
||||
apply MLP: gate_up_proj -> swiglu -> down_proj
|
||||
|
||||
Args:
|
||||
hidden_states: input hidden states with shape (num_tokens, hidden_size).
|
||||
w1: expert weights1 with shape
|
||||
(num_experts, hidden_size, intermediate_size * 2)
|
||||
w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
|
||||
w2: expert weights2 with shape
|
||||
(num_experts, intermediate_size, hidden_size)
|
||||
w2_scale: weights2 scale with shape (num_experts, hidden_size)
|
||||
group_list: number of tokens for each expert, follow cumsum mode, and
|
||||
with shape (num_experts).
|
||||
transpose_weight:
|
||||
w1: (num_experts, intermediate_size * 2, hidden_size) ->
|
||||
(num_experts, hidden_size, intermediate_size * 2)
|
||||
w2: (num_experts, hidden_size, intermediate_size) ->
|
||||
(num_experts, intermediate_size, hidden_size)
|
||||
|
||||
Returns:
|
||||
hidden_states: output hidden states after MLP.
|
||||
"""
|
||||
|
||||
if dynamic_scale is None:
|
||||
unquantized_hidden_states = hidden_states
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
# Dispose the original unquantized hidden states
|
||||
# to save npu memory because they're no longer used.
|
||||
dispose_tensor(unquantized_hidden_states)
|
||||
else:
|
||||
pertoken_scale = dynamic_scale
|
||||
|
||||
bias1, bias2 = None, None
|
||||
_output_dtype = w2_scale.dtype
|
||||
|
||||
if w1_scale_bias is not None:
|
||||
if group_list_type == 0:
|
||||
group_list = torch.cat(
|
||||
[group_list[:1], torch.diff(group_list, dim=0)])
|
||||
group_list_type = 1
|
||||
bias1 = [w1_scale_bias]
|
||||
bias2 = [w2_scale_bias]
|
||||
# TODO w4a8 scene: dynamic acquisition of dtype in the future
|
||||
_output_dtype = torch.bfloat16
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
scale=[w1_scale],
|
||||
bias=bias1,
|
||||
per_token_scale=[pertoken_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
bias=bias2,
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AscendW8A8DynamicLinearMethod:
|
||||
"""Linear method for Ascend W8A8_DYNAMIC.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.transpose_weight = True
|
||||
|
||||
@staticmethod
|
||||
def get_weight(input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {
|
||||
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
|
||||
}
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_perchannel_param(
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
params_dict = {}
|
||||
params_dict["weight_scale"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_offset"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
return params_dict
|
||||
|
||||
def get_pergroup_param(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def apply(
|
||||
layer: torch.nn.Module,
|
||||
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tp_rank: Optional[int] = 0,
|
||||
) -> torch.Tensor:
|
||||
config = getattr(layer, "_ascend_quant_config", {})
|
||||
if not isinstance(x, tuple):
|
||||
output_dtype = config.get("output_dtype", x.dtype)
|
||||
quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
|
||||
else:
|
||||
assert "output_dtype" in config.keys(), (
|
||||
f"DynamicLinearMethod needs explicitly specified `output_dtype`"
|
||||
f"for pre-quantized input, got config [{config}]")
|
||||
output_dtype = config["output_dtype"]
|
||||
quantized_x, dynamic_scale = x
|
||||
pertoken_scale = (dynamic_scale
|
||||
if config.get("pertoken_scale", True) else None)
|
||||
|
||||
output = torch_npu.npu_quant_matmul(
|
||||
quantized_x,
|
||||
layer.weight,
|
||||
layer.weight_scale,
|
||||
pertoken_scale=pertoken_scale,
|
||||
bias=bias,
|
||||
output_dtype=output_dtype,
|
||||
)
|
||||
return ((output, dynamic_scale)
|
||||
if config.get("return_scale", False) else output)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.transpose_weight:
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
# cast quantized weight tensors in NZ format (29) for higher inference speed
|
||||
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
||||
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
||||
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||
|
||||
|
||||
class AscendW8A8DynamicFusedMoEMethod:
|
||||
"""FusedMoe method for Ascend W8A8_DYNAMIC.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.transpose_weight = True
|
||||
|
||||
self.ep_group = get_ep_group()
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
ascend_config = get_ascend_config()
|
||||
self.use_aclgraph = (
|
||||
vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
|
||||
and not vllm_config.model_config.enforce_eager
|
||||
and not ascend_config.torchair_graph_config.enabled)
|
||||
|
||||
try:
|
||||
device_group = get_mc2_group().device_group
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=device_group)
|
||||
backend = device_group._get_backend(torch.device("npu"))
|
||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
||||
local_rank)
|
||||
except AttributeError:
|
||||
self.moe_all_to_all_group_name = ""
|
||||
|
||||
@staticmethod
|
||||
def get_weight(num_experts: int, intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = {}
|
||||
param_dict["w13_weight"] = torch.empty(num_experts,
|
||||
2 *
|
||||
intermediate_size_per_partition,
|
||||
hidden_sizes,
|
||||
dtype=torch.int8)
|
||||
param_dict["w2_weight"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int8)
|
||||
return param_dict
|
||||
|
||||
@staticmethod
|
||||
def get_dynamic_quant_param(num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = {}
|
||||
param_dict["w13_weight_scale"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
param_dict["w13_weight_offset"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
return param_dict
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = True,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
|
||||
topk_weights, topk_ids, row_idx = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
if self.use_aclgraph:
|
||||
return unified_fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
use_int8_w8a8=True,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
|
||||
fused_moe_state = get_forward_context().fused_moe_state
|
||||
shared_gate_up, shared_dequant_scale = None, None
|
||||
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
|
||||
share_up_out, _ = shared_experts.gate_up_proj(
|
||||
(quantized_x_for_share, dynamic_scale_for_share))
|
||||
shared_gate_up, shared_dequant_scale = share_up_out[
|
||||
0], share_up_out[1]
|
||||
|
||||
# this is a naive implementation for experts load balance so as
|
||||
# to avoid accumulating too much tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
if enable_force_load_balance:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
|
||||
return unified_fused_experts_eager(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
shared_experts=shared_experts,
|
||||
shared_gate_up=shared_gate_up,
|
||||
shared_dequant_scale=shared_dequant_scale,
|
||||
mc2_mask=kwargs.get("mc2_mask", None),
|
||||
with_quant=True)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.transpose_weight:
|
||||
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
if envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP:
|
||||
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
||||
layer.w13_weight_scale.data.shape[0], -1)
|
||||
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
|
||||
torch.float32)
|
||||
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
|
||||
layer.w13_weight_offset.data.shape[0], -1)
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
|
||||
layer.w2_weight_scale.data.shape[0], -1)
|
||||
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
||||
layer.w2_weight_offset.data.shape[0], -1)
|
||||
0
vllm_ascend/sample/__init__.py
Normal file
0
vllm_ascend/sample/__init__.py
Normal file
504
vllm_ascend/sample/rejection_sampler.py
Normal file
504
vllm_ascend/sample/rejection_sampler.py
Normal file
@@ -0,0 +1,504 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import vllm.v1.sample.rejection_sampler as rs
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import (RejectionSampler, compute_probs,
|
||||
generate_uniform_probs)
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
PLACEHOLDER_TOKEN_ID = -1
|
||||
GREEDY_TEMPERATURE = -1
|
||||
# Maximum number of speculative draft tokens allowed per request in a single
|
||||
# step. This value is chosen to be large enough to handle typical use cases.
|
||||
MAX_SPEC_LEN = 32
|
||||
|
||||
|
||||
class AscendRejectionSampler(RejectionSampler, nn.Module):
|
||||
"""
|
||||
The implementation strictly follows the algorithm described in
|
||||
https://arxiv.org/abs/2211.17192.
|
||||
However, we want to clarify the terminology used in the implementation:
|
||||
accepted tokens: tokens that are accepted based on the relationship
|
||||
between the "raw" draft and target probabilities.
|
||||
recovered tokens: tokens that are sampled based on the adjusted probability
|
||||
distribution, which is derived from both the draft and target
|
||||
probabilities.
|
||||
bonus tokens:
|
||||
If all proposed tokens are accepted, the bonus token is added to the
|
||||
end of the sequence. The bonus token is only sampled from the target
|
||||
probabilities. We pass in the bonus tokens instead of sampling them
|
||||
in the rejection sampler to allow for more flexibility in the
|
||||
sampling process. For example, we can use top_p, top_k sampling for
|
||||
bonus tokens, while spec decode does not support these sampling
|
||||
strategies.
|
||||
output tokens:
|
||||
Tokens are finally generated with the rejection sampler.
|
||||
output tokens = accepted tokens + recovered tokens + bonus tokens
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
metadata: SpecDecodeMetadata,
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: Optional[torch.Tensor],
|
||||
# [num_tokens, vocab_size]
|
||||
target_logits: torch.Tensor,
|
||||
# [batch_size, 1]
|
||||
bonus_token_ids: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Args:
|
||||
metadata:
|
||||
Metadata for spec decoding.
|
||||
draft_probs (Optional[torch.Tensor]):
|
||||
Probability distribution for the draft tokens. Shape is
|
||||
[num_tokens, vocab_size]. Can be None if probabilities are
|
||||
not provided, which is the case for ngram spec decode.
|
||||
target_logits (torch.Tensor):
|
||||
Target model's logits probability distribution.
|
||||
Shape is [num_tokens, vocab_size]. Here, probabilities from
|
||||
different requests are flattened into a single tensor because
|
||||
this is the shape of the output logits.
|
||||
NOTE: `target_logits` can be updated in place to save memory.
|
||||
bonus_token_ids_tensor (torch.Tensor):
|
||||
A tensor containing bonus tokens. Shape is [batch_size, 1].
|
||||
Bonus tokens are added to the end of the sequence if all
|
||||
proposed tokens are accepted. We generate the bonus tokens
|
||||
outside of the rejection sampler with the default sampling
|
||||
strategy. It allows for more flexibility in the sampling
|
||||
process such as top_p, top_k sampling.
|
||||
sampling_metadata (SamplingMetadata):
|
||||
Additional metadata needed for sampling, such as temperature,
|
||||
top-k/top-p parameters, or other relevant information.
|
||||
Returns:
|
||||
output_token_ids (torch.Tensor):
|
||||
A tensor containing the final output token IDs.
|
||||
'''
|
||||
assert metadata.max_spec_len <= MAX_SPEC_LEN
|
||||
# [num_tokens, vocab_size]
|
||||
# NOTE(woosuk): `target_logits` can be updated in place inside the
|
||||
# `compute_probs` function.
|
||||
target_probs = compute_probs(
|
||||
target_logits,
|
||||
metadata.cu_num_draft_tokens,
|
||||
sampling_metadata,
|
||||
)
|
||||
|
||||
output_token_ids = rejection_sample(
|
||||
metadata.draft_token_ids,
|
||||
metadata.num_draft_tokens,
|
||||
metadata.max_spec_len,
|
||||
metadata.cu_num_draft_tokens,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
sampling_metadata,
|
||||
)
|
||||
return output_token_ids
|
||||
|
||||
|
||||
def rejection_sample(
|
||||
# [num_tokens]
|
||||
draft_token_ids: torch.Tensor,
|
||||
# [batch_size]
|
||||
num_draft_tokens: list[int],
|
||||
max_spec_len: int,
|
||||
# [batch_size]
|
||||
cu_num_draft_tokens: torch.Tensor,
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: Optional[torch.Tensor],
|
||||
# [num_tokens, vocab_size]
|
||||
target_probs: torch.Tensor,
|
||||
# [batch_size, 1]
|
||||
bonus_token_ids: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert draft_token_ids.ndim == 1
|
||||
assert draft_probs is None or draft_probs.ndim == 2
|
||||
assert cu_num_draft_tokens.ndim == 1
|
||||
assert target_probs.ndim == 2
|
||||
|
||||
batch_size = len(num_draft_tokens)
|
||||
num_tokens = draft_token_ids.shape[0]
|
||||
vocab_size = target_probs.shape[-1]
|
||||
device = target_probs.device
|
||||
assert draft_token_ids.is_contiguous()
|
||||
assert draft_probs is None or draft_probs.is_contiguous()
|
||||
assert target_probs.is_contiguous()
|
||||
assert bonus_token_ids.is_contiguous()
|
||||
assert target_probs.shape == (num_tokens, vocab_size)
|
||||
|
||||
# Create output buffer.
|
||||
output_token_ids = torch.empty(
|
||||
(batch_size, max_spec_len + 1),
|
||||
dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids.
|
||||
device=device,
|
||||
)
|
||||
output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
|
||||
|
||||
if sampling_metadata.all_greedy:
|
||||
is_greedy = None
|
||||
else:
|
||||
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
|
||||
if not sampling_metadata.all_random:
|
||||
# Rejection sampling for greedy sampling requests.
|
||||
target_argmax = target_probs.argmax(dim=-1)
|
||||
if min(num_draft_tokens) == 1 and max(
|
||||
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
|
||||
rejection_greedy_sample_spec_len_1_pytorch(
|
||||
output_token_ids,
|
||||
draft_token_ids,
|
||||
target_argmax,
|
||||
bonus_token_ids,
|
||||
)
|
||||
else:
|
||||
rejection_greedy_sample_pytorch(
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
target_argmax,
|
||||
bonus_token_ids,
|
||||
num_draft_tokens,
|
||||
max_spec_len,
|
||||
is_greedy,
|
||||
)
|
||||
if sampling_metadata.all_greedy:
|
||||
return output_token_ids
|
||||
|
||||
# Generate uniform probabilities for rejection sampling.
|
||||
# [num_tokens]
|
||||
uniform_probs = generate_uniform_probs(
|
||||
num_tokens,
|
||||
num_draft_tokens,
|
||||
sampling_metadata.generators,
|
||||
device,
|
||||
)
|
||||
|
||||
# Sample recovered tokens for each position.
|
||||
# [num_tokens]
|
||||
recovered_token_ids = sample_recovered_tokens(
|
||||
max_spec_len,
|
||||
num_draft_tokens,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
sampling_metadata,
|
||||
device,
|
||||
)
|
||||
|
||||
# Rejection sampling for random sampling requests.
|
||||
rejection_random_sample_pytorch(
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
recovered_token_ids,
|
||||
uniform_probs,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
IS_NGRAM=draft_probs is None,
|
||||
# num_warps=1,
|
||||
)
|
||||
return output_token_ids
|
||||
|
||||
|
||||
def expand_batch_to_tokens(
|
||||
x: torch.Tensor, # [batch_size]
|
||||
cu_num_tokens: torch.Tensor, # [batch_size]
|
||||
num_tokens: int,
|
||||
replace_from: int = 0,
|
||||
replace_to: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""Expand [batch_size] tensor to [num_tokens] tensor based on the number of
|
||||
tokens per batch in cu_num_tokens.
|
||||
|
||||
For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then
|
||||
num_tokens = 6, and expanded_x = [a, a, b, b, b, c].
|
||||
|
||||
Args:
|
||||
x: [batch_size] tensor to expand.
|
||||
cu_num_tokens: [batch_size] tensor containing the cumulative number of
|
||||
tokens per batch. Each element represents the total number of
|
||||
tokens up to and including that batch.
|
||||
num_tokens: Total number of tokens.
|
||||
replace_from: int = 0
|
||||
Value to be replaced if it is found in x.
|
||||
replace_to: int = 0
|
||||
Value to replace with when replace_from is found.
|
||||
Returns:
|
||||
expanded_x: [num_tokens] tensor.
|
||||
"""
|
||||
batch_size = x.shape[0]
|
||||
assert cu_num_tokens.shape[0] == batch_size
|
||||
expanded_x = x.new_empty(num_tokens)
|
||||
expand_pytorch(
|
||||
expanded_x,
|
||||
x,
|
||||
cu_num_tokens,
|
||||
replace_from,
|
||||
replace_to,
|
||||
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
|
||||
)
|
||||
return expanded_x
|
||||
|
||||
|
||||
def sample_recovered_tokens(
|
||||
max_spec_len: int,
|
||||
num_draft_tokens: list[int],
|
||||
# [batch_size]
|
||||
cu_num_draft_tokens: torch.Tensor,
|
||||
# [num_tokens]
|
||||
draft_token_ids: torch.Tensor,
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: Optional[torch.Tensor],
|
||||
# [num_tokens, vocab_size]
|
||||
target_probs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
# NOTE(woosuk): Create only one distribution for each request.
|
||||
batch_size = len(num_draft_tokens)
|
||||
vocab_size = target_probs.shape[-1]
|
||||
q = torch.empty(
|
||||
(batch_size, vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
q.exponential_()
|
||||
for i, generator in sampling_metadata.generators.items():
|
||||
# Do not generate random numbers for requests with no draft tokens.
|
||||
# This can be important for reproducibility.
|
||||
if num_draft_tokens[i] > 0:
|
||||
q[i].exponential_(generator=generator)
|
||||
|
||||
recovered_token_ids = torch.empty_like(draft_token_ids)
|
||||
sample_recovered_tokens_pytorch(
|
||||
recovered_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
q,
|
||||
vocab_size,
|
||||
IS_NGRAM=draft_probs is None,
|
||||
)
|
||||
return recovered_token_ids
|
||||
|
||||
|
||||
def rejection_greedy_sample_spec_len_1_pytorch(
|
||||
output_token_ids, # [batch_size, 2]
|
||||
draft_token_ids, # [num_tokens]
|
||||
target_argmax, # [num_tokens]
|
||||
bonus_token_ids, # [batch_size]
|
||||
):
|
||||
batch_size = output_token_ids.size(0)
|
||||
num_tokens = draft_token_ids.size(0)
|
||||
assert batch_size == num_tokens
|
||||
accept_req_mask = draft_token_ids == target_argmax
|
||||
output_token_ids[:, 0] = target_argmax
|
||||
bonus_token_ids = bonus_token_ids.squeeze(1)
|
||||
output_token_ids[accept_req_mask, 1] = bonus_token_ids[accept_req_mask]
|
||||
|
||||
|
||||
def rejection_greedy_sample_pytorch(
|
||||
output_token_ids, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens, # [batch_size]
|
||||
draft_token_ids, # [num_tokens]
|
||||
target_argmax, # [num_tokens]
|
||||
bonus_token_ids, # [batch_size]
|
||||
draft_tokens_per_req, # [batch_size], list
|
||||
max_spec_len,
|
||||
is_greedy=None, # [batch_size] or None
|
||||
):
|
||||
batch_size = output_token_ids.size(0)
|
||||
num_tokens = draft_token_ids.size(0)
|
||||
device = output_token_ids.device
|
||||
draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
|
||||
device, non_blocking=True)
|
||||
if is_greedy is None:
|
||||
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
|
||||
|
||||
start_indices = cu_num_draft_tokens - draft_tokens_per_req
|
||||
req_ids = torch.arange(batch_size, device=device)
|
||||
token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
|
||||
token_positions = torch.arange(
|
||||
num_tokens, device=device) - start_indices[token_req_ids]
|
||||
|
||||
# Find the first mismatch position of each request.
|
||||
mismatch_global = (draft_token_ids != target_argmax)
|
||||
if max_spec_len == 0:
|
||||
first_mismatch_pos_per_req = torch.zeros(batch_size,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
else:
|
||||
# [bs, max_spec_len]
|
||||
pos_matrix = torch.full((batch_size, max_spec_len),
|
||||
-1,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
pos_matrix[token_req_ids, token_positions] = token_positions
|
||||
mismatch_matrix = torch.full((batch_size, max_spec_len),
|
||||
False,
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
mismatch_matrix[token_req_ids, token_positions] = mismatch_global
|
||||
mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
|
||||
max_spec_len * 2)
|
||||
first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
|
||||
no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
|
||||
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
|
||||
no_mismatch_mask]
|
||||
|
||||
# Copy matched target tokens into output.
|
||||
copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
|
||||
draft_tokens_per_req)
|
||||
copy_indices = torch.arange(max_spec_len + 1,
|
||||
device=device).expand(batch_size, -1)
|
||||
copy_mask = copy_indices < copy_len.unsqueeze(1)
|
||||
greedy_mask = is_greedy.unsqueeze(1)
|
||||
final_copy_mask = copy_mask & greedy_mask
|
||||
global_idx = start_indices.unsqueeze(1) + copy_indices
|
||||
output_token_ids[final_copy_mask] = target_argmax[
|
||||
global_idx[final_copy_mask]].to(output_token_ids.dtype)
|
||||
# Fill bonus token.
|
||||
needs_bonus = is_greedy & (first_mismatch_pos_per_req
|
||||
>= draft_tokens_per_req)
|
||||
if torch.any(needs_bonus):
|
||||
bonus_rows = torch.where(needs_bonus)[0]
|
||||
bonus_cols = draft_tokens_per_req[bonus_rows]
|
||||
bonus_token_ids = bonus_token_ids.squeeze(1)
|
||||
output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows]
|
||||
|
||||
|
||||
def rejection_random_sample_pytorch(
|
||||
output_token_ids, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens, # [batch_size]
|
||||
draft_token_ids, # [num_tokens]
|
||||
draft_probs, # [num_tokens, vocab_size] or None
|
||||
target_probs, # [num_tokens, vocab_size]
|
||||
bonus_token_ids, # [batch_size]
|
||||
recovered_token_ids, # [num_tokens]
|
||||
uniform_probs, # [num_tokens]
|
||||
is_greedy, # [batch_size]
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
IS_NGRAM=False,
|
||||
):
|
||||
batch_size = output_token_ids.shape[0]
|
||||
|
||||
for req_idx in range(batch_size):
|
||||
if is_greedy[req_idx]:
|
||||
continue
|
||||
|
||||
if req_idx == 0:
|
||||
start_idx = 0
|
||||
else:
|
||||
start_idx = cu_num_draft_tokens[req_idx - 1].item()
|
||||
end_idx = cu_num_draft_tokens[req_idx].item()
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
rejected = False
|
||||
for pos in range(num_draft_tokens):
|
||||
if not rejected:
|
||||
draft_token_id = draft_token_ids[start_idx + pos].item()
|
||||
|
||||
if IS_NGRAM:
|
||||
draft_prob = 1.0
|
||||
else:
|
||||
draft_prob = draft_probs[start_idx + pos,
|
||||
draft_token_id].item()
|
||||
|
||||
target_prob = target_probs[start_idx + pos,
|
||||
draft_token_id].item()
|
||||
uniform_prob = uniform_probs[start_idx + pos].item()
|
||||
|
||||
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
|
||||
token_id = draft_token_id
|
||||
else:
|
||||
rejected = True
|
||||
token_id = recovered_token_ids[start_idx + pos].item()
|
||||
|
||||
output_token_ids[req_idx, pos] = token_id
|
||||
|
||||
if not rejected:
|
||||
bonus_token_id = bonus_token_ids[req_idx].item()
|
||||
output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
|
||||
|
||||
|
||||
def expand_pytorch(
|
||||
output_ptr, # [num_tokens]
|
||||
input_ptr, # [batch_size]
|
||||
cu_num_tokens_ptr, # [batch_size]
|
||||
replace_from,
|
||||
replace_to,
|
||||
MAX_NUM_TOKENS,
|
||||
):
|
||||
batch_size = len(input_ptr)
|
||||
|
||||
for req_idx in range(batch_size):
|
||||
start_idx = 0 if req_idx == 0 else cu_num_tokens_ptr[req_idx - 1]
|
||||
end_idx = cu_num_tokens_ptr[req_idx]
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
src_val = input_ptr[req_idx]
|
||||
src_val = replace_to if src_val == replace_from else src_val
|
||||
|
||||
offset = torch.arange(MAX_NUM_TOKENS, device=num_tokens.device)
|
||||
mask = offset < num_tokens
|
||||
|
||||
output_slice = start_idx + offset[mask]
|
||||
output_ptr[output_slice] = src_val
|
||||
|
||||
|
||||
def sample_recovered_tokens_pytorch(
|
||||
output_token_ids, # [num_tokens]
|
||||
cu_num_draft_tokens, # [batch_size]
|
||||
draft_token_ids, # [num_tokens]
|
||||
draft_probs, # [num_tokens, vocab_size] or None
|
||||
target_probs, # [num_tokens, vocab_size]
|
||||
q, # [batch_size, vocab_size]
|
||||
vocab_size,
|
||||
IS_NGRAM=False,
|
||||
):
|
||||
batch_size = len(cu_num_draft_tokens)
|
||||
|
||||
for req_idx in range(batch_size):
|
||||
start_idx = 0 if req_idx == 0 else cu_num_draft_tokens[req_idx - 1]
|
||||
end_idx = cu_num_draft_tokens[req_idx]
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
for pos in range(num_draft_tokens):
|
||||
token_idx = start_idx + pos
|
||||
|
||||
if IS_NGRAM:
|
||||
draft_token_id = draft_token_ids[token_idx]
|
||||
orig_prob = target_probs[token_idx, draft_token_id].item()
|
||||
target_probs[token_idx, draft_token_id] = 0
|
||||
prob = target_probs[token_idx].clone()
|
||||
else:
|
||||
draft_p = draft_probs[token_idx].clone()
|
||||
target_p = target_probs[token_idx].clone()
|
||||
prob = torch.maximum(target_p - draft_p,
|
||||
torch.tensor(0.0, device=target_p.device))
|
||||
|
||||
q_values = torch.full((vocab_size, ),
|
||||
float('-inf'),
|
||||
device=q.device)
|
||||
q_values[:vocab_size] = q[req_idx, :vocab_size]
|
||||
|
||||
recovered_id = torch.argmax(prob / q_values).item()
|
||||
output_token_ids[token_idx] = recovered_id
|
||||
|
||||
if IS_NGRAM:
|
||||
target_probs[token_idx, draft_token_id] = orig_prob
|
||||
|
||||
|
||||
rs.expand_batch_to_tokens = expand_batch_to_tokens
|
||||
86
vllm_ascend/sample/sampler.py
Normal file
86
vllm_ascend/sample/sampler.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
|
||||
from vllm_ascend.utils import is_310p, vllm_version_is
|
||||
|
||||
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
|
||||
from vllm.config import LogprobsMode
|
||||
DEFAULT_LOGPROBS_MODE = LogprobsMode.RAW_LOGPROBS
|
||||
else:
|
||||
LogprobsMode = None
|
||||
DEFAULT_LOGPROBS_MODE = "raw_logprobs"
|
||||
|
||||
|
||||
class AscendSampler(Sampler):
|
||||
|
||||
def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE):
|
||||
# TODO: support logprobs_mode in vllm-ascend
|
||||
super().__init__(logprobs_mode=logprobs_mode)
|
||||
self.topk_topp_sampler = AscendTopKTopPSampler()
|
||||
|
||||
|
||||
class AscendTopKTopPSampler(TopKTopPSampler):
|
||||
|
||||
def _apply_top_k_top_p(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
p: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# npu_top_k_top_p uses the operator aclnnApplyTopKTopP, but aclnnApplyTopKTopP currently does not support 310P
|
||||
if not is_310p() and p is not None and k is not None:
|
||||
# npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
|
||||
return torch_npu.npu_top_k_top_p(logits, p, k)
|
||||
|
||||
if p is None and k is None:
|
||||
return logits
|
||||
|
||||
probs = logits.softmax(dim=-1)
|
||||
probs_sort, _ = probs.sort(dim=-1, descending=False)
|
||||
|
||||
if k is not None:
|
||||
top_k_count = probs_sort.size(1) - k.to(
|
||||
torch.long) # shape: (batch, )
|
||||
top_k_count = top_k_count.unsqueeze(dim=1)
|
||||
top_k_cutoff = probs_sort.gather(-1, top_k_count)
|
||||
|
||||
# Make sure the no top-k rows are no-op.
|
||||
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
|
||||
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
|
||||
|
||||
elements_to_discard = probs < top_k_cutoff
|
||||
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||
|
||||
if p is not None:
|
||||
cumprob = torch.cumsum(probs_sort, dim=-1)
|
||||
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
|
||||
top_p_mask[:, -1] = False # at least one
|
||||
|
||||
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
|
||||
top_p_cutoff = probs_sort.gather(-1, top_p_count)
|
||||
elements_to_discard = probs < top_p_cutoff
|
||||
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||
|
||||
return logits
|
||||
|
||||
def forward_native(self, logits, generators, k, p):
|
||||
"""Override pytorch native implementation to torch_npu"""
|
||||
logits = self._apply_top_k_top_p(logits, k, p)
|
||||
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
|
||||
|
||||
logits_to_return = None
|
||||
if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS:
|
||||
logits_to_return = logits
|
||||
elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS:
|
||||
logits_to_return = logits.log_softmax(dim=-1,
|
||||
dtype=torch.float32)
|
||||
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
output = None
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
output = random_sample(probs, generators)
|
||||
else:
|
||||
output = (random_sample(probs, generators), logits_to_return)
|
||||
return output
|
||||
0
vllm_ascend/torchair/__init__.py
Normal file
0
vllm_ascend/torchair/__init__.py
Normal file
0
vllm_ascend/torchair/models/__init__.py
Normal file
0
vllm_ascend/torchair/models/__init__.py
Normal file
364
vllm_ascend/torchair/models/qwen2.py
Normal file
364
vllm_ascend/torchair/models/qwen2.py
Normal file
@@ -0,0 +1,364 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import vllm
|
||||
import vllm.envs as envs
|
||||
from torch import nn
|
||||
from transformers import Qwen2Config
|
||||
from vllm.attention import AttentionMetadata, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_reduce_scatter)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
||||
from vllm.model_executor.models.qwen2 import Qwen2Attention # noqa: F401
|
||||
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM # noqa: F401
|
||||
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Model
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader,
|
||||
PPMissingLayer, maybe_prefix)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
|
||||
|
||||
def all_gather_and_maybe_unpad(
|
||||
hidden_states: torch.Tensor,
|
||||
pad_size: int,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
|
||||
if pad_size > 0:
|
||||
return hidden_states[:-pad_size, :]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def maybe_pad_and_reduce_scatter(
|
||||
hidden_states: torch.Tensor,
|
||||
pad_size: int,
|
||||
) -> torch.Tensor:
|
||||
if pad_size > 0:
|
||||
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))
|
||||
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomQwen2Attention(Qwen2Attention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
rope_scaling: Optional[tuple] = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
hidden_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
max_position=max_position,
|
||||
rope_theta=rope_theta,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
rope_scaling=rope_scaling,
|
||||
prefix=prefix,
|
||||
attn_type=attn_type,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config)
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
if self.torchair_graph_enabled and attn_metadata is not None and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
q, k = self.rotary_emb(positions,
|
||||
q,
|
||||
k,
|
||||
is_prefill=False,
|
||||
is_qwen_torchair=True)
|
||||
forward_kwargs = {}
|
||||
if envs.VLLM_USE_V1:
|
||||
output_shape = q.shape
|
||||
output = torch.empty(output_shape,
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
forward_kwargs['output'] = output
|
||||
|
||||
attn_output = self.attn.impl.forward(self.attn,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
trace_flag=False,
|
||||
**forward_kwargs)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
else:
|
||||
if type(self.rotary_emb) is RotaryEmbedding:
|
||||
q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True)
|
||||
else:
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class CustomQwen2DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 1000000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
dual_chunk_attention_config = getattr(config,
|
||||
"dual_chunk_attention_config",
|
||||
None)
|
||||
|
||||
# By default, Qwen2 uses causal attention as it is a decoder-only model.
|
||||
# You can override the HF config with `is_causal=False` to enable
|
||||
# bidirectional attention, which is used in some embedding models
|
||||
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
|
||||
if getattr(config, "is_causal", True):
|
||||
attn_type = AttentionType.DECODER
|
||||
else:
|
||||
attn_type = AttentionType.ENCODER_ONLY
|
||||
|
||||
self.self_attn = CustomQwen2Attention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
max_position=config.max_position_embeddings,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
rope_scaling=rope_scaling,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attn_type=attn_type,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
self.mlp = Qwen2MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
hidden_states = self.self_attn(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"input_ids": 0,
|
||||
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
|
||||
# otherwise (seq_len, ).
|
||||
"positions": -1,
|
||||
"intermediate_tensors": 0,
|
||||
"inputs_embeds": 0,
|
||||
})
|
||||
class CustomQwen2Model(Qwen2Model):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer):
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
decoder_layer_type=decoder_layer_type)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
kv_cache = kv_caches[i - self.start_layer] \
|
||||
if kv_caches is not None else None
|
||||
hidden_states, residual = layer(positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
# add `CustomQwen2Model` to init self.model
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.model = CustomQwen2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
vllm.model_executor.models.qwen2.Qwen2ForCausalLM = CustomQwen2ForCausalLM
|
||||
537
vllm_ascend/torchair/models/qwen3_moe.py
Normal file
537
vllm_ascend/torchair/models/qwen3_moe.py
Normal file
@@ -0,0 +1,537 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Adapted from vllm/model_executor/models/qwen3_moe.py
|
||||
# This file is a part of the vllm-ascend project.
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import vllm.envs as envs
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
||||
get_tp_group)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.interfaces import (MixtureOfExperts,
|
||||
SupportsLoRA, SupportsPP)
|
||||
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
|
||||
Qwen3MoeDecoderLayer,
|
||||
Qwen3MoeForCausalLM,
|
||||
Qwen3MoeMLP, Qwen3MoeModel,
|
||||
Qwen3MoeSparseMoeBlock)
|
||||
from vllm.model_executor.models.utils import (
|
||||
PPMissingLayer, extract_layer_index,
|
||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
|
||||
init_metadata_for_sp)
|
||||
|
||||
|
||||
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
if self.tp_size > config.num_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {config.num_experts}.")
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.num_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
self.experts = AscendFusedMoE(
|
||||
num_experts=config.num_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
)
|
||||
|
||||
self.top_k = config.num_experts_per_tok
|
||||
|
||||
self.dp_size = get_dp_group().world_size
|
||||
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
self.ep_group = get_ep_group()
|
||||
|
||||
self.params_dtype = torch.get_default_dtype()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attn_metadata=None,
|
||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
||||
):
|
||||
if attn_metadata is None:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
# when profile runs, force experts to load balanced tokens
|
||||
# to avoid high memory consumption on a single rank.
|
||||
enable_force_load_balance = get_forward_context().in_profile_run
|
||||
is_prefill = get_forward_context().with_prefill
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
is_prefill=is_prefill,
|
||||
top_k=self.top_k,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
shared_experts=None,
|
||||
_metadata_for_padding=_metadata_for_padding,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomQwen3MoeAttention(Qwen3MoeAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
head_dim: Optional[int] = None,
|
||||
rms_norm_eps: float = 1e-06,
|
||||
qkv_bias: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = head_dim or (hidden_size // self.total_num_heads)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj")
|
||||
|
||||
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj")
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
@staticmethod
|
||||
def normalize_qkv(qkv: torch.Tensor, q_size: int, kv_size: int,
|
||||
head_dim: int, q_norm, k_norm):
|
||||
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
|
||||
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
|
||||
q_by_head = q_norm(q_by_head)
|
||||
q = q_by_head.view(q.shape)
|
||||
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
|
||||
k_by_head = k_norm(k_by_head)
|
||||
k = k_by_head.view(k.shape)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = self.normalize_qkv(qkv, self.q_size, self.kv_size,
|
||||
self.head_dim, self.q_norm, self.k_norm)
|
||||
|
||||
if (self.torchair_graph_enabled and attn_metadata is not None and
|
||||
attn_metadata.attn_state == AscendAttentionState.DecodeOnly):
|
||||
q, k = self.rotary_emb(positions,
|
||||
q,
|
||||
k,
|
||||
is_prefill=False,
|
||||
is_qwen_torchair=True)
|
||||
forward_kwargs = {}
|
||||
if envs.VLLM_USE_V1:
|
||||
output_shape = q.shape
|
||||
output = torch.empty(output_shape,
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
forward_kwargs['output'] = output
|
||||
|
||||
attn_output = self.attn.impl.forward(self.attn,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
trace_flag=False,
|
||||
**forward_kwargs)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
else:
|
||||
q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
vllm_config: Optional[VllmConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = CustomQwen3MoeAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
qkv_bias=getattr(config, 'attention_bias', False),
|
||||
head_dim=getattr(config, 'head_dim', None),
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
|
||||
# `mlp_only_layers` in the config.
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
|
||||
config.mlp_only_layers)
|
||||
self.use_aclgraph = (vllm_config is not None
|
||||
and vllm_config.compilation_config.level
|
||||
== CompilationLevel.PIECEWISE
|
||||
and not vllm_config.model_config.enforce_eager)
|
||||
if (layer_idx not in mlp_only_layers) and (
|
||||
config.num_experts > 0 and
|
||||
(layer_idx + 1) % config.decoder_sparse_step == 0):
|
||||
if not self.use_aclgraph:
|
||||
# FIXME: custom sparse moe block doesn't work with aclgraph.
|
||||
self.mlp = CustomSparseMoeBlock(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = Qwen3MoeSparseMoeBlock(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
self.enable_sequence_parallelism = (
|
||||
vllm_config.compilation_config.pass_config.
|
||||
enable_sequence_parallelism if vllm_config is not None else False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# To prevent precision issues during the decoder phase when only prefilling enables SP
|
||||
if not self.enable_sequence_parallelism:
|
||||
self.self_attn.o_proj.reduce_results = True
|
||||
else:
|
||||
self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
|
||||
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
residual = _metadata_for_padding.padding_slice(residual)
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
|
||||
hidden_states)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
|
||||
hidden_states)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
if not self.use_aclgraph:
|
||||
hidden_states = self.mlp(
|
||||
hidden_states, _metadata_for_padding=_metadata_for_padding)
|
||||
else:
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class CustomQwen3MoeModel(Qwen3MoeModel):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.num_redundant_experts = parallel_config.num_redundant_experts
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.config = config
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=f"{prefix}.embed_tokens")
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: CustomQwen3MoeDecoderLayer(
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
kv_caches[i -
|
||||
self.start_layer] if kv_caches is not None else None,
|
||||
attn_metadata,
|
||||
_metadata_for_padding=_metadata_for_padding)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
|
||||
hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
SupportsPP.__init__(self)
|
||||
SupportsLoRA.__init__(self)
|
||||
MixtureOfExperts.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = CustomQwen3MoeModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
|
||||
# Set MoE hyperparameters
|
||||
self.expert_weights: list[torch.Tensor] = []
|
||||
|
||||
self.moe_layers: list[FusedMoE] = []
|
||||
example_layer = None
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
|
||||
assert isinstance(layer, Qwen3MoeDecoderLayer)
|
||||
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
|
||||
example_layer = layer.mlp
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
if example_layer is None:
|
||||
raise RuntimeError("No Qwen3MoE layer found in the model.layers.")
|
||||
|
||||
self.num_moe_layers = len(self.moe_layers)
|
||||
self.num_expert_groups = 1
|
||||
self.num_shared_experts = 0
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
_metadata_for_padding = init_metadata_for_sp(
|
||||
input_ids, self.enable_sequence_parallelism)
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds, _metadata_for_padding)
|
||||
return hidden_states
|
||||
218
vllm_ascend/torchair/models/torchair_deepseek_mtp.py
Normal file
218
vllm_ascend/torchair/models/torchair_deepseek_mtp.py
Normal file
@@ -0,0 +1,218 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/deepseek_mtp.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.deepseek_mtp import (
|
||||
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
|
||||
SharedHead)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
|
||||
TorchairDeepseekV2DecoderLayer
|
||||
|
||||
|
||||
class TorchairDeepSeekShareHead(SharedHead):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "head"))
|
||||
|
||||
|
||||
class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer
|
||||
):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
model_config: ModelConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
|
||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.shared_head = TorchairDeepSeekShareHead(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix,
|
||||
"shared_head"))
|
||||
self.mtp_block = TorchairDeepseekV2DecoderLayer(
|
||||
config, prefix, model_config, cache_config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_index: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
# masking inputs at position 0, as not needed by MTP
|
||||
inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
|
||||
torch.zeros_like(inputs_embeds),
|
||||
inputs_embeds)
|
||||
inputs_embeds = self.enorm(inputs_embeds)
|
||||
previous_hidden_states = self.hnorm(previous_hidden_states)
|
||||
|
||||
hidden_states = self.eh_proj(
|
||||
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
||||
|
||||
hidden_states, residual = self.mtp_block(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=None)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.mtp_start_layer_idx = config.num_hidden_layers
|
||||
self.num_mtp_layers = config.num_nextn_predict_layers
|
||||
# to map the exact layer index from weights
|
||||
self.layers = torch.nn.ModuleDict({
|
||||
str(idx):
|
||||
TorchairDeepSeekMultiTokenPredictorLayer(
|
||||
config,
|
||||
f"{prefix}.layers.{idx}",
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=vllm_config.quant_config,
|
||||
)
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
})
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
|
||||
# Note: torch._dynamo.exc.Unsupported: builtin: str
|
||||
self.layers_list = [
|
||||
self.layers[str(idx)]
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
]
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
step_kv_cache = kv_caches[
|
||||
current_step_idx] if kv_caches is not None else None
|
||||
return self.layers_list[current_step_idx](
|
||||
input_ids,
|
||||
positions,
|
||||
step_kv_cache,
|
||||
attn_metadata,
|
||||
previous_hidden_states,
|
||||
inputs_embeds,
|
||||
current_step_idx,
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
mtp_layer = self.layers_list[current_step_idx]
|
||||
logits = self.logits_processor(mtp_layer.shared_head.head,
|
||||
mtp_layer.shared_head(hidden_states),
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
||||
class TorchairDeepSeekMTP(DeepSeekMTP):
|
||||
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
|
||||
# NOTE 2.The description file generated by the current msmodelslim tool does not have
|
||||
# MTP layer info. Please manually add it and set the value to FLOAT.
|
||||
packed_modules_mapping = {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.model = TorchairDeepSeekMultiTokenPredictor(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
self.sampler = get_sampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
previous_hidden_states: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, previous_hidden_states,
|
||||
inputs_embeds, spec_step_idx)
|
||||
return hidden_states
|
||||
1049
vllm_ascend/torchair/models/torchair_deepseek_v2.py
Normal file
1049
vllm_ascend/torchair/models/torchair_deepseek_v2.py
Normal file
File diff suppressed because it is too large
Load Diff
28
vllm_ascend/torchair/models/torchair_deepseek_v3.py
Normal file
28
vllm_ascend/torchair/models/torchair_deepseek_v3.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
|
||||
TorchairDeepseekV2ForCausalLM
|
||||
|
||||
|
||||
class TorchairDeepseekV3ForCausalLM(TorchairDeepseekV2ForCausalLM):
|
||||
pass
|
||||
1119
vllm_ascend/torchair/models/torchair_pangu_moe.py
Normal file
1119
vllm_ascend/torchair/models/torchair_pangu_moe.py
Normal file
File diff suppressed because it is too large
Load Diff
0
vllm_ascend/torchair/ops/__init__.py
Normal file
0
vllm_ascend/torchair/ops/__init__.py
Normal file
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user