forked from EngineX-Ascend/enginex-ascend-910-vllm
init v0.11.0rc0
This commit is contained in:
@@ -23,5 +23,7 @@ def register():
|
||||
|
||||
|
||||
def register_model():
|
||||
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
|
||||
|
||||
from .models import register_model
|
||||
register_model()
|
||||
|
||||
@@ -34,6 +34,8 @@ class AscendConfig:
|
||||
|
||||
def __init__(self, vllm_config):
|
||||
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
|
||||
self.is_deepseek_sfa = vllm_config.model_config is not None and vllm_config.model_config.is_deepseek_mla and vllm_config.model_config.hf_text_config.model_type == "deepseek_v32"
|
||||
self.use_sfa = self.is_deepseek_sfa
|
||||
|
||||
torchair_graph_config = additional_config.get("torchair_graph_config",
|
||||
{})
|
||||
@@ -43,13 +45,26 @@ class AscendConfig:
|
||||
"ascend_scheduler_config", {})
|
||||
self.ascend_scheduler_config = AscendSchedulerConfig(
|
||||
ascend_scheduler_config)
|
||||
|
||||
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this config
|
||||
self.expert_map_path = additional_config.get("expert_map_path", None)
|
||||
self.expert_map_record_path = additional_config.get(
|
||||
"expert_map_record_path",
|
||||
None) # Provide path to export expert map
|
||||
self.init_redundancy_expert = additional_config.get(
|
||||
"init_redundancy_expert", 0)
|
||||
self.dynamic_eplb = additional_config.get("dynamic_eplb", False)
|
||||
self.num_iterations_eplb_update = additional_config.get(
|
||||
"num_iterations_eplb_update", 400)
|
||||
self.gate_eplb = additional_config.get("gate_eplb", False)
|
||||
self.num_wait_worker_iterations = additional_config.get(
|
||||
"num_wait_worker_iterations", 30)
|
||||
self.chunked_prefill_for_mla = additional_config.get(
|
||||
"chunked_prefill_for_mla", False)
|
||||
self.enable_shared_expert_dp = additional_config.get(
|
||||
"enable_shared_expert_dp", False
|
||||
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
|
||||
self.multistream_overlap_shared_expert = additional_config.get(
|
||||
"multistream_overlap_shared_expert", False)
|
||||
self.enable_prefetch = additional_config.get("enable_prefetch", False)
|
||||
self.lmhead_tensor_parallel_size = additional_config.get(
|
||||
"lmhead_tensor_parallel_size", None)
|
||||
@@ -61,6 +76,24 @@ class AscendConfig:
|
||||
raise AssertionError(
|
||||
"lmhead_tensor_parallel_size is only supported in the pure DP scenario"
|
||||
)
|
||||
self.oproj_tensor_parallel_size = additional_config.get(
|
||||
"oproj_tensor_parallel_size", None)
|
||||
if self.oproj_tensor_parallel_size is not None:
|
||||
logger.info(
|
||||
f"Enable oproj_tensor_parallel_size={self.oproj_tensor_parallel_size} in pure DP scenario"
|
||||
)
|
||||
if vllm_config.parallel_config.tensor_parallel_size != 1:
|
||||
raise AssertionError(
|
||||
"oproj_tensor_parallel_size is only supported in the pure DP scenario"
|
||||
)
|
||||
if not self.torchair_graph_config.enabled:
|
||||
raise AssertionError(
|
||||
"oproj_tensor_parallel_size is only supported in graph mode"
|
||||
)
|
||||
if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
|
||||
raise AssertionError(
|
||||
"oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
|
||||
)
|
||||
|
||||
|
||||
class TorchairGraphConfig:
|
||||
@@ -81,10 +114,10 @@ class TorchairGraphConfig:
|
||||
"graph_batch_sizes_init", False)
|
||||
self.enable_multistream_mla = torchair_graph_config.get(
|
||||
"enable_multistream_mla", False)
|
||||
self.enable_multistream_moe = torchair_graph_config.get(
|
||||
"enable_multistream_moe", False)
|
||||
self.enable_view_optimize = torchair_graph_config.get(
|
||||
"enable_view_optimize", True)
|
||||
self.enable_frozen_parameter = torchair_graph_config.get(
|
||||
"enable_frozen_parameter", True)
|
||||
self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)
|
||||
|
||||
if not isinstance(self.graph_batch_sizes, list):
|
||||
@@ -117,10 +150,6 @@ class TorchairGraphConfig:
|
||||
raise RuntimeError(
|
||||
"enable_multistream_mla is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.enable_multistream_moe:
|
||||
raise RuntimeError(
|
||||
"enable_multistream_moe is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.enable_kv_nz:
|
||||
raise RuntimeError(
|
||||
"enable_kv_nz is valid only when Torchair graph mode is enabled"
|
||||
@@ -200,14 +229,8 @@ def check_ascend_config(vllm_config, enforce_eager):
|
||||
"it has been disabled automatically.")
|
||||
# aclgraph case
|
||||
else:
|
||||
# aclgraph doesn't work with deepseek model and only qwen model is well tested.
|
||||
if vllm_config.model_config:
|
||||
model_type = vllm_config.model_config.hf_config.model_type
|
||||
if "deepseek" in model_type:
|
||||
raise NotImplementedError(
|
||||
"ACL Graph does not support deepseek. Please "
|
||||
"try torchair graph mode to serve deepseek models on vllm-ascend."
|
||||
" Or set `enforce_eager=True` to use eager mode.")
|
||||
if "qwen" not in model_type:
|
||||
logger.warning(
|
||||
"ACL Graph is currently experimental. Please "
|
||||
|
||||
@@ -11,6 +11,7 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
|
||||
set_forward_context)
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.utils import enable_sp
|
||||
|
||||
|
||||
class FusedMoEState(Enum):
|
||||
@@ -22,6 +23,13 @@ class FusedMoEState(Enum):
|
||||
All2AllSeq = 5
|
||||
|
||||
|
||||
class MoECommType(Enum):
|
||||
ALLGATHER = 0
|
||||
MC2 = 1
|
||||
ALLTOALL = 2
|
||||
NAIVE_MULTICAST = 3
|
||||
|
||||
|
||||
# TODO(zzzzwwjj): add soc_version to choose branch
|
||||
def _get_fused_moe_state(ep_size: int, with_prefill: bool,
|
||||
is_deepseek_v3_r1: bool):
|
||||
@@ -42,18 +50,6 @@ def _get_fused_moe_state(ep_size: int, with_prefill: bool,
|
||||
return FusedMoEState.MC2
|
||||
|
||||
|
||||
def get_dispatcher_name(ep_size: int, with_prefill: bool) -> str:
|
||||
if ep_size == 1:
|
||||
return "TokenDispatcherWithAllGather"
|
||||
|
||||
if ep_size < 16:
|
||||
return "TokenDispatcherWithAll2AllV"
|
||||
|
||||
if with_prefill:
|
||||
return "TokenDispatcherWithAll2AllV"
|
||||
return "TokenDispatcherWithMC2"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_ascend_forward_context(
|
||||
attn_metadata: Any,
|
||||
@@ -64,10 +60,12 @@ def set_ascend_forward_context(
|
||||
with_prefill: bool = True,
|
||||
in_profile_run: bool = False,
|
||||
reserved_mc2_mask: Optional[torch.Tensor] = None,
|
||||
moe_comm_method: str = "",
|
||||
moe_comm_type: Optional[MoECommType] = None,
|
||||
num_actual_tokens: Optional[int] = None,
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor: Optional[BatchDescriptor] = None):
|
||||
batch_descriptor: Optional[BatchDescriptor] = None,
|
||||
prefetch_stream: torch.npu.Stream = None,
|
||||
model_instance: torch.nn.Module = None):
|
||||
"""A context manager that stores the current forward context,
|
||||
can be attention metadata, etc.
|
||||
We add some additional param into forward_context.
|
||||
@@ -82,8 +80,13 @@ def set_ascend_forward_context(
|
||||
batch_descriptor=batch_descriptor,
|
||||
):
|
||||
forward_context = get_forward_context()
|
||||
forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
|
||||
|
||||
from vllm_ascend.ops.moe.moe_comm_method import get_moe_comm_method
|
||||
forward_context.moe_comm_type = moe_comm_type
|
||||
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
|
||||
|
||||
forward_context.with_prefill = with_prefill
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
ep_size = (get_ep_group().world_size if
|
||||
vllm_config.parallel_config.enable_expert_parallel else 1)
|
||||
|
||||
@@ -95,16 +98,63 @@ def set_ascend_forward_context(
|
||||
forward_context.fused_moe_state = fused_moe_state
|
||||
forward_context.in_profile_run = in_profile_run
|
||||
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
||||
get_token_dispatcher
|
||||
dispatcher_name = get_dispatcher_name(ep_size, with_prefill)
|
||||
dispatcher = get_token_dispatcher(dispatcher_name)
|
||||
forward_context.token_dispatcher = dispatcher
|
||||
|
||||
# NOTE: This cannot be set using set_forward_context
|
||||
# due to multiple warmups before actual capturing
|
||||
forward_context.capturing = False
|
||||
|
||||
# set for sequence parallelism, 1000 is the batch size concurrency threshold for enabling the flashcomm_v1 or sequence_parallelism feature.
|
||||
# Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
|
||||
# the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
|
||||
# the performance may degrade due to the switching of communication methods.
|
||||
sp_enabled = enable_sp(vllm_config) and \
|
||||
tp_world_size > 1 and \
|
||||
num_tokens is not None and num_tokens > 1000
|
||||
|
||||
if sp_enabled:
|
||||
pad_size = (tp_world_size -
|
||||
(num_tokens % tp_world_size)) % tp_world_size
|
||||
forward_context.pad_size = pad_size
|
||||
forward_context.sp_enabled = sp_enabled
|
||||
|
||||
# set this for rope forward_oot using
|
||||
forward_context.is_first_layer = True
|
||||
|
||||
# set layer_idx to enable optimization features that depend on this information.
|
||||
# This is only applicable to models that contain these necessary attributes.
|
||||
forward_context.layer_idx = None
|
||||
if model_instance is not None and \
|
||||
hasattr(model_instance, "model") and \
|
||||
hasattr(model_instance.model, "start_layer"):
|
||||
forward_context.layer_idx = model_instance.model.start_layer
|
||||
|
||||
# set for mlp weight prefetch
|
||||
prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
|
||||
envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
|
||||
forward_context.layer_idx is not None and \
|
||||
num_tokens is not None and num_tokens < 500
|
||||
if prefetch_mlp_enabled:
|
||||
forward_context.prefetch_stream = prefetch_stream
|
||||
forward_context.model_instance = model_instance
|
||||
forward_context.prefetch_mlp_gate_up_proj = False
|
||||
forward_context.prefetch_mlp_down_proj = False
|
||||
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
|
||||
|
||||
# TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
|
||||
# It will be improved later by implementing operator fusion through the FX graph.
|
||||
#
|
||||
# set for addrmsnorm+quant fusion.
|
||||
# this optim now just support dense models due to the specific operators used.
|
||||
# Once the necessary conditions are met, support for MOE models will also be added.
|
||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
||||
addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \
|
||||
vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3"] and \
|
||||
forward_context.layer_idx is not None
|
||||
if addrmsnorm_quant_fusion_enabled:
|
||||
forward_context.model_instance = model_instance
|
||||
forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
|
||||
forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense"
|
||||
forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled
|
||||
|
||||
if num_tokens is None and attn_metadata is not None:
|
||||
num_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
@@ -120,7 +170,6 @@ def set_ascend_forward_context(
|
||||
if num_tokens is not None:
|
||||
if num_actual_tokens is None:
|
||||
num_actual_tokens = num_tokens
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
# NOTE: token num which need to pad to when mc2
|
||||
forward_context.padded_num_tokens = math.ceil(
|
||||
max_tokens_across_dp / tp_world_size) * tp_world_size
|
||||
|
||||
@@ -39,11 +39,22 @@ class AttentionMaskBuilder:
|
||||
self,
|
||||
max_seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device = None,
|
||||
):
|
||||
# NOTE: The device argument specifies the target NPU
|
||||
# to be used for the newly added FIA operator.
|
||||
# Only pass this parameter when using the new FIA operator.
|
||||
|
||||
attn_mask = _generate_attn_mask(max_seq_len, dtype)
|
||||
|
||||
self._seq_len_cached = attn_mask.shape[0]
|
||||
self.attn_mask_cache = attn_mask
|
||||
self.device = device
|
||||
if torch.version.cann.startswith("8.3"):
|
||||
assigned_mask_dim = 2048
|
||||
self.chunked_prefill_attn_mask = torch.triu(
|
||||
torch.ones(assigned_mask_dim, assigned_mask_dim),
|
||||
diagonal=1).to(torch.int8).to(device)
|
||||
|
||||
@staticmethod
|
||||
def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
|
||||
@@ -62,28 +73,32 @@ class AttentionMaskBuilder:
|
||||
device: torch.device):
|
||||
self._update_attn_cache(max_seq_len, dtype)
|
||||
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
|
||||
).to(device)
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
def get_splitfuse_attn_mask(
|
||||
self,
|
||||
seq_lens: torch.Tensor,
|
||||
position: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
seq_lens: torch.Tensor = None,
|
||||
position: torch.Tensor = None,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
) -> torch.Tensor:
|
||||
if dtype not in [torch.float16, torch.bfloat16]:
|
||||
raise ValueError(
|
||||
"splitfuse_attn_mask now only supports bf16 and fp16")
|
||||
max_seq_len = max(seq_lens, default=0)
|
||||
self._update_attn_cache(max_seq_len, dtype)
|
||||
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
|
||||
# is not the same. Fix this in the future when kernel is ready.
|
||||
mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
|
||||
attn_mask = torch.index_select(self.attn_mask_cache,
|
||||
dim=0,
|
||||
index=position)[:, :max_seq_len]
|
||||
attn_mask *= mask_scale_factor
|
||||
return attn_mask.contiguous().to(device, non_blocking=True)
|
||||
if torch.version.cann.startswith("8.3"):
|
||||
return self.chunked_prefill_attn_mask
|
||||
else:
|
||||
if dtype not in [torch.float16, torch.bfloat16]:
|
||||
raise ValueError(
|
||||
"splitfuse_attn_mask now only supports bf16 and fp16")
|
||||
max_seq_len = max(seq_lens, default=0)
|
||||
self._update_attn_cache(max_seq_len, dtype)
|
||||
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
|
||||
# is not the same. Fix this in the future when kernel is ready.
|
||||
mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(
|
||||
dtype)
|
||||
attn_mask = torch.index_select(self.attn_mask_cache,
|
||||
dim=0,
|
||||
index=position)[:, :max_seq_len]
|
||||
attn_mask *= mask_scale_factor
|
||||
return attn_mask.contiguous().to(device, non_blocking=True)
|
||||
|
||||
def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
|
||||
if seqlen > self._seq_len_cached:
|
||||
|
||||
@@ -17,24 +17,27 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple, Type
|
||||
from typing import ClassVar, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer, AttentionType)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.utils import cdiv, direct_register_custom_op
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
maybe_save_kv_layer_to_connector,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.compilation.acl_graph import get_graph_params
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||
nd_to_nz_2d, nd_to_nz_spec)
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
|
||||
|
||||
class AscendAttentionBackend(AttentionBackend):
|
||||
@@ -52,10 +55,6 @@ class AscendAttentionBackend(AttentionBackend):
|
||||
def get_metadata_cls() -> Type["AscendMetadata"]:
|
||||
return AscendMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
|
||||
return AscendAttentionMetadataBuilder
|
||||
@@ -111,6 +110,10 @@ class AscendAttentionBackend(AttentionBackend):
|
||||
key_caches[dst_indices] = key_caches[src_indices]
|
||||
value_caches[dst_indices] = value_caches[src_indices]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_block_size() -> list[int]:
|
||||
return [64]
|
||||
|
||||
|
||||
class AscendAttentionState(Enum):
|
||||
PrefillNoCache = 0
|
||||
@@ -155,48 +158,50 @@ class AscendMetadata:
|
||||
|
||||
# *************************** Other Properties *************************** #
|
||||
enable_dbo_across_dp: bool = False
|
||||
is_only_prefill: bool = False
|
||||
|
||||
|
||||
class AscendAttentionMetadataBuilder:
|
||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
# Does this backend/builder reorder the batch?
|
||||
# If not, set this to None. Otherwise set it to the query
|
||||
# length that will be pulled into the front of the batch.
|
||||
reorder_batch_threshold: ClassVar[int] = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.device = device
|
||||
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
|
||||
vllm_config.cache_config.block_size)
|
||||
self.max_num_blocks_per_req = cdiv(
|
||||
self.model_config.max_model_len,
|
||||
AscendAttentionBackend.get_supported_block_size()[0])
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
def reorder_batch(self, input_batch,
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
model: Optional[nn.Module] = None,
|
||||
):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||
num_reqs
|
||||
+ 1]
|
||||
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
block_table[:num_reqs, :self.max_num_blocks_per_req] = (
|
||||
block_table[:num_reqs])
|
||||
|
||||
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
|
||||
num_actual_tokens].to(
|
||||
self.device,
|
||||
non_blocking=
|
||||
True)
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
attn_mask = common_attn_metadata.attn_mask
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||
@@ -225,8 +230,25 @@ class AscendAttentionMetadataBuilder:
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state,
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
|
||||
is_only_prefill=common_attn_metadata.is_only_prefill)
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
|
||||
return attn_metadata
|
||||
|
||||
def build_for_graph_capture(
|
||||
self,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
|
||||
):
|
||||
if attn_state == AscendAttentionState.DecodeOnly:
|
||||
attn_metadata = self.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Currently we only support building dummy metadata for DecodeOnly state"
|
||||
)
|
||||
|
||||
attn_metadata.attn_state = attn_state
|
||||
return attn_metadata
|
||||
|
||||
|
||||
@@ -265,20 +287,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self.key_cache = None
|
||||
self.value_cache = None
|
||||
|
||||
def _repeat_kv(self, hidden_states: torch.Tensor,
|
||||
n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, None, :, :].expand(
|
||||
num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(num_key_value_heads * n_rep, slen,
|
||||
head_dim)
|
||||
|
||||
def _forward_prefill_no_cache(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@@ -304,34 +312,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
mask = torch_npu.npu_format_cast(mask.contiguous(),
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
if self.sliding_window is not None and \
|
||||
attn_metadata.attn_mask.shape[0] > self.sliding_window:
|
||||
|
||||
key = self._repeat_kv(key, self.num_heads // self.num_kv_heads)
|
||||
value = self._repeat_kv(value, self.num_heads // self.num_kv_heads)
|
||||
|
||||
output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout="TND",
|
||||
pre_tokens=self.sliding_window,
|
||||
scale=self.scale,
|
||||
actual_seq_lengths=attn_metadata.seq_lens,
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens)
|
||||
output = output.view(num_tokens, self.num_heads, self.head_size)
|
||||
else:
|
||||
torch_npu._npu_flash_attention(query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seq_len=attn_metadata.seq_lens,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output)
|
||||
torch_npu._npu_flash_attention(query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seq_len=attn_metadata.seq_lens,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output)
|
||||
assert output is not None
|
||||
return output[:num_tokens, :, :]
|
||||
|
||||
@@ -372,7 +361,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
# seq_lens_tensor needs to be transferred to the device for 310P.
|
||||
attn_metadata.seq_lens = \
|
||||
attn_metadata.seq_lens.to(device=query.device)
|
||||
if self.sliding_window is not None:
|
||||
if self.sliding_window is not None and attn_metadata.seq_lens.shape[
|
||||
0] == query.size(0):
|
||||
batch_size = attn_metadata.seq_lens.shape[0]
|
||||
block_size = 128
|
||||
query = query.view(batch_size, 1, self.num_heads * self.head_size)
|
||||
@@ -399,16 +389,53 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
output = output.view(batch_size, self.num_heads, self.head_size)
|
||||
else:
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
graph_params = get_graph_params()
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
num_tokens = query.shape[0]
|
||||
if forward_context.capturing:
|
||||
stream = torch_npu.npu.current_stream()
|
||||
|
||||
event = torch.npu.ExternalEvent()
|
||||
event.wait(stream)
|
||||
event.reset(stream)
|
||||
graph_params.events[num_tokens].append(event)
|
||||
|
||||
graph_params.attn_params[num_tokens].append((
|
||||
query,
|
||||
self.key_cache,
|
||||
self.value_cache,
|
||||
self.num_kv_heads,
|
||||
self.num_heads,
|
||||
self.scale,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.seq_lens,
|
||||
output,
|
||||
))
|
||||
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
handle = torch.npu.graph_task_group_end(stream)
|
||||
graph_params.handles[num_tokens].append(handle)
|
||||
else:
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
return output
|
||||
|
||||
def _forward_v1_style(
|
||||
@@ -449,18 +476,43 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
attn_metadata.seq_lens = \
|
||||
attn_metadata.seq_lens.to(device=query.device)
|
||||
|
||||
torch_npu._npu_paged_attention_splitfuse(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
mask=attn_metadata.attn_mask,
|
||||
block_table=attn_metadata.block_tables,
|
||||
seq_len=attn_metadata.query_lens,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
if torch.version.cann.startswith("8.3"):
|
||||
# TODO:The npu_fused_infer_attention_score op is planned to
|
||||
# be utilized in a wider range in upcoming versions.
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
|
||||
output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
block_table=attn_metadata.block_tables,
|
||||
input_layout="TND",
|
||||
block_size=block_size,
|
||||
actual_seq_lengths=attn_metadata.query_start_loc[1:],
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale=self.scale,
|
||||
sparse_mode=3,
|
||||
)
|
||||
else:
|
||||
torch_npu._npu_paged_attention_splitfuse(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
mask=attn_metadata.attn_mask,
|
||||
block_table=attn_metadata.block_tables,
|
||||
seq_len=attn_metadata.query_lens,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
@@ -554,12 +606,18 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
output)
|
||||
# Normal V1 situation.
|
||||
else:
|
||||
if torch.version.cann.startswith("8.3"):
|
||||
# npu_fused_infer_attention_score does not support cases
|
||||
# where query.shape[0] != attn_metadata.query_start_loc[-1].
|
||||
# Thus we need unpad it here.
|
||||
num_tokens = attn_metadata.query_start_loc[-1]
|
||||
query = query[:num_tokens]
|
||||
output = self._forward_v1_style(query, attn_metadata, output)
|
||||
|
||||
# to make in-place change to the output tensor
|
||||
if hasattr(layer, 'quant_method') and use_kv_cache_int8:
|
||||
output = output.view(num_tokens, self.num_heads, self.head_size)
|
||||
ori_output[:, :, :] = output[:num_tokens, :, :]
|
||||
ori_output[:num_tokens, :, :] = output[:num_tokens, :, :]
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
|
||||
|
||||
@@ -570,8 +628,11 @@ def unified_ascend_attention_with_output(
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self.impl.forward(self,
|
||||
@@ -582,6 +643,7 @@ def unified_ascend_attention_with_output(
|
||||
attn_metadata,
|
||||
output,
|
||||
trace_flag=False)
|
||||
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||
return
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar
|
||||
from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
|
||||
TypeVar)
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
@@ -12,15 +13,17 @@ from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.utils import cdiv, round_down
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
maybe_save_kv_layer_to_connector,
|
||||
split_decodes_and_prefills,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
||||
from vllm_ascend.utils import npu_prefetch
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
|
||||
@@ -164,6 +167,9 @@ M = TypeVar("M", bound=AscendMLAMetadata)
|
||||
|
||||
|
||||
class AscendMLAMetadataBuilder:
|
||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.NEVER
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
@@ -171,6 +177,8 @@ class AscendMLAMetadataBuilder:
|
||||
|
||||
# _attn_mask_builder = None
|
||||
def __init__(self,
|
||||
kv_cache_spec,
|
||||
layer_names,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
metadata_cls: Optional[AscendMLAMetadata] = None):
|
||||
@@ -185,7 +193,16 @@ class AscendMLAMetadataBuilder:
|
||||
self.block_size - 1) // self.block_size
|
||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.decode_threshold = 1
|
||||
if self.speculative_config:
|
||||
spec_token_num = self.speculative_config.num_speculative_tokens
|
||||
self.decode_threshold += spec_token_num
|
||||
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
|
||||
npu_fused_infer_attention_score TND layout's limit of 16, \
|
||||
got {self.decode_threshold}"
|
||||
|
||||
self.reorder_batch_threshold = self.decode_threshold
|
||||
|
||||
if self.chunked_prefill_enabled:
|
||||
self.chunked_prefill_workspace_size = min(
|
||||
@@ -265,6 +282,7 @@ class AscendMLAMetadataBuilder:
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
) -> AscendMLAMetadata:
|
||||
@@ -272,7 +290,6 @@ class AscendMLAMetadataBuilder:
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
# TODO(xyx): remove the if condition after mla supports torch mode speculative decoding
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
@@ -284,11 +301,7 @@ class AscendMLAMetadataBuilder:
|
||||
device = self.device
|
||||
|
||||
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
||||
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
|
||||
num_actual_tokens].to(
|
||||
device,
|
||||
non_blocking=
|
||||
True)
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
input_positions = common_attn_metadata.positions[:
|
||||
num_actual_tokens].long(
|
||||
)
|
||||
@@ -376,11 +389,12 @@ class AscendMLAMetadataBuilder:
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
|
||||
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
|
||||
max_seq_lens = seq_lens[:num_decodes].max().item()
|
||||
seq_lens = seq_lens[:num_decode_tokens]
|
||||
seq_lens = seq_lens[:num_decodes]
|
||||
input_positions = input_positions[:num_decode_tokens]
|
||||
block_table = block_table[:num_decode_tokens, ...]
|
||||
block_table = block_table[:num_decodes, ...]
|
||||
seq_lens_list = seq_lens.tolist()
|
||||
|
||||
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
|
||||
@@ -481,17 +495,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
self.enable_prefetch = ascend_config.enable_prefetch
|
||||
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
||||
self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.ring_mla_mask_size = 512
|
||||
self.prefill_mask = None
|
||||
|
||||
# Adapt torch air graph mode with spec decoding.
|
||||
speculative_config = vllm_config.speculative_config
|
||||
if speculative_config is not None:
|
||||
self.spec_token_num = speculative_config.num_speculative_tokens
|
||||
assert self.spec_token_num > 0
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
|
||||
def _v_up_proj(self, x):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
@@ -663,84 +672,47 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.v_head_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
query = torch.cat((q_nope, q_pe), dim=-1)
|
||||
key = torch.cat((k_nope, k_pe), dim=-1)
|
||||
torch_npu._npu_flash_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=attn_metadata.attn_mask,
|
||||
seq_len=attn_metadata.prefill.context_lens,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_heads,
|
||||
out=attn_output)
|
||||
elif self.chunked_prefill_for_mla:
|
||||
attn_lse = torch.empty(self.num_heads,
|
||||
num_tokens,
|
||||
dtype=torch.float32,
|
||||
device=q_nope.device)
|
||||
if self.prefill_mask is None:
|
||||
self.prefill_mask = torch.triu(
|
||||
torch.ones(self.ring_mla_mask_size,
|
||||
self.ring_mla_mask_size,
|
||||
device=q_nope.device,
|
||||
dtype=q_nope.dtype), 1)
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=value,
|
||||
mask=self.prefill_mask,
|
||||
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
|
||||
dtype=torch.int32),
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=None,
|
||||
prev_lse=None,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="mask_type_triu",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_first_ring",
|
||||
output=attn_output,
|
||||
softmax_lse=attn_lse)
|
||||
attn_output, attn_lse = self._compute_prefill_context( \
|
||||
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
||||
else:
|
||||
query = torch.cat((q_nope, q_pe), dim=-1)
|
||||
attn_output_torch = torch.empty(num_tokens,
|
||||
self.num_heads * self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
# current requests is chunked in prefill, disable flash attention with chunked prefill
|
||||
vanilla_chunked_prefill_mla(
|
||||
output=attn_output_torch,
|
||||
query=query,
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
block_tables=attn_metadata.prefill.block_table,
|
||||
query_lens=attn_metadata.prefill.query_lens,
|
||||
context_lens=attn_metadata.prefill.context_lens,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
max_query_len=attn_metadata.prefill.max_query_len,
|
||||
max_context_len=attn_metadata.prefill.max_seq_lens,
|
||||
nope_dim=self.qk_nope_head_dim,
|
||||
rope_dim=self.qk_rope_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
scale=self.scale,
|
||||
alibi_slopes=None,
|
||||
causal=True)
|
||||
attn_lse = torch.empty(self.num_heads,
|
||||
num_tokens,
|
||||
dtype=torch.float32,
|
||||
device=q_nope.device)
|
||||
if self.prefill_mask is None:
|
||||
if q_nope.dtype == torch.float16:
|
||||
mask_value = torch.finfo(torch.float32).min
|
||||
else:
|
||||
mask_value = 1
|
||||
prefill_mask = torch.triu(
|
||||
torch.ones(self.ring_mla_mask_size,
|
||||
self.ring_mla_mask_size,
|
||||
device=q_nope.device,
|
||||
dtype=q_nope.dtype), 1)
|
||||
self.prefill_mask = torch.where(prefill_mask == 1, mask_value,
|
||||
0).to(q_nope.dtype)
|
||||
torch_npu.atb.npu_ring_mla(q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=value,
|
||||
mask=self.prefill_mask,
|
||||
seqlen=torch.tensor(
|
||||
attn_metadata.prefill.query_lens,
|
||||
dtype=torch.int32),
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=None,
|
||||
prev_lse=None,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="mask_type_triu",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_first_ring",
|
||||
output=attn_output,
|
||||
softmax_lse=attn_lse)
|
||||
attn_output, attn_lse = self._compute_prefill_context( \
|
||||
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
||||
|
||||
attn_output = attn_output.reshape(
|
||||
[num_tokens, self.num_heads * self.v_head_dim])
|
||||
if attn_metadata.attn_state in [
|
||||
AscendAttentionState.ChunkedPrefill,
|
||||
AscendAttentionState.SpecDecoding,
|
||||
AscendAttentionState.PrefillCacheHit
|
||||
] and not self.chunked_prefill_for_mla:
|
||||
attn_output = attn_output_torch
|
||||
return attn_output
|
||||
|
||||
def exec_kv_decode(
|
||||
@@ -785,7 +757,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||
kv_no_split = kv_no_split.view(
|
||||
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
|
||||
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
|
||||
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
kv_no_split,
|
||||
self.kv_a_layernorm.weight,
|
||||
@@ -840,8 +812,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.qk_rope_head_dim)
|
||||
input_layout = "BNSD"
|
||||
|
||||
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
|
||||
assert num_tokens % self.spec_token_num == 0
|
||||
if attn_metadata.attn_state in [
|
||||
AscendAttentionState.SpecDecoding,
|
||||
AscendAttentionState.ChunkedPrefill
|
||||
] and self.speculative_config is not None:
|
||||
# Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
|
||||
input_layout = "TND"
|
||||
# [bs * q_seq_len, num_heads_per_rank, dim]
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, -1)
|
||||
@@ -887,8 +862,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
current_ms_metadata.before_comm_event.wait()
|
||||
return self._v_up_proj(attn_output)
|
||||
|
||||
def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
|
||||
need_gather_q_kv):
|
||||
def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
|
||||
attn_metadata, need_gather_q_kv):
|
||||
# MLA Preprocess:
|
||||
# 1. Perform q_a_proj and q_a_layernorm to obtain q_c
|
||||
# 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
|
||||
@@ -917,6 +892,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
|
||||
decode_preprocess_res = None
|
||||
prefill_preprocess_res = None
|
||||
if has_prefill:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
# Preprocess for decode tokens
|
||||
if has_decode:
|
||||
decode_q_c = q_c[:num_decode_tokens]
|
||||
@@ -963,6 +940,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer_name,
|
||||
hidden_states: torch.Tensor, # query in unified attn
|
||||
kv_cache: Tuple[torch.Tensor],
|
||||
attn_metadata: M,
|
||||
@@ -989,7 +967,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
|
||||
# MLA Preprocess
|
||||
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
|
||||
hidden_states, kv_cache, attn_metadata, need_gather_q_kv)
|
||||
layer_name, hidden_states, kv_cache, attn_metadata,
|
||||
need_gather_q_kv)
|
||||
|
||||
if decode_preprocess_res is not None:
|
||||
# MLA Preprocess for decoding
|
||||
@@ -1047,4 +1026,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
is_force_scatter=self.enable_shared_expert_dp)[0]
|
||||
current_ms_metadata.after_comm_event.record()
|
||||
del o_proj_input
|
||||
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
if has_prefill:
|
||||
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
|
||||
return output_padded
|
||||
|
||||
986
vllm_ascend/attention/sfa_v1.py
Normal file
986
vllm_ascend/attention/sfa_v1.py
Normal file
@@ -0,0 +1,986 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
|
||||
TypeVar)
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl)
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.utils import cdiv, round_down
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
|
||||
class AscendSFABackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ASCEND_SFA"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
return AscendSFAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls():
|
||||
return AscendSFAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
|
||||
head_size: int) -> tuple[int, ...]:
|
||||
return (num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["AscendSFAImpl"]:
|
||||
return AscendSFAImpl
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendSFAPrefillMetadata:
|
||||
""" Prefill Specific Metadata for Ascend"""
|
||||
|
||||
@dataclass
|
||||
class ChunkedContextMetadata:
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling chunked prefill
|
||||
cu_seq_lens: torch.Tensor
|
||||
starts: torch.Tensor
|
||||
seq_tot: list[int]
|
||||
max_seq_lens: list[int]
|
||||
workspace: torch.Tensor
|
||||
chunk_seq_lens: torch.Tensor
|
||||
|
||||
attn_mask: torch.Tensor
|
||||
query_lens: list[int]
|
||||
seq_lens: list[int]
|
||||
|
||||
context_lens: torch.Tensor
|
||||
input_positions: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
max_query_len: int
|
||||
max_seq_lens: int
|
||||
sin: torch.Tensor
|
||||
cos: torch.Tensor
|
||||
chunked_context: Optional[ChunkedContextMetadata] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendSFADecodeMetadata:
|
||||
# Input positions for rotrary embeddings since for MLA the rotary
|
||||
# position embeddings are applied inside the attention backend
|
||||
input_positions: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
max_seq_lens: int
|
||||
seq_lens_list: list[int]
|
||||
actual_seq_lengths_q: torch.Tensor
|
||||
sin: torch.Tensor
|
||||
cos: torch.Tensor
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendSFAMetadata:
|
||||
"""Metadata for MLACommon.
|
||||
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
slot_mapping: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
block_tables: torch.Tensor
|
||||
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling prefill decode split
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
|
||||
# For logging.
|
||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||
|
||||
query_lens: Optional[list[int]] = None
|
||||
# The dimension of the attention heads
|
||||
head_dim: Optional[int] = None
|
||||
attn_mask: torch.Tensor = None
|
||||
# chunked prefill by default if no attn_states passed
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
|
||||
decode: Optional[AscendSFADecodeMetadata] = None
|
||||
prefill: Optional[AscendSFAPrefillMetadata] = None
|
||||
enable_dbo_across_dp: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
# supported_head_sizes = AscendMLABackend.get_supported_head_sizes()
|
||||
# if self.head_dim is not None and self.head_dim \
|
||||
# not in supported_head_sizes:
|
||||
# raise ValueError(
|
||||
# f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
# f"received {self.head_dim}.")
|
||||
|
||||
def split_metadata_for_multistream(
|
||||
self,
|
||||
ms_split_config: MSAttentionMetadataSplitConfig,
|
||||
) -> list["AscendSFAMetadata"]:
|
||||
"""Split metadata for multi-stream with AscendSFAMetadata"""
|
||||
return model_input_split_v1_mla_attn(
|
||||
ms_split_config=ms_split_config,
|
||||
attn_metadata=self,
|
||||
_metadata_cls=AscendMLAMetadata,
|
||||
)
|
||||
|
||||
|
||||
M = TypeVar("M", bound=AscendSFAMetadata)
|
||||
|
||||
|
||||
class AscendSFAMetadataBuilder:
|
||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.NEVER
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
|
||||
# _attn_mask_builder = None
|
||||
def __init__(self,
|
||||
kv_cache_spec,
|
||||
layer_names,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
metadata_cls: Optional[AscendSFAMetadata] = None):
|
||||
self.metadata_cls: Optional[AscendSFAMetadata] = metadata_cls \
|
||||
if metadata_cls is not None else AscendSFAMetadata # type: ignore
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.device = device
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.max_blocks = (vllm_config.model_config.max_model_len +
|
||||
self.block_size - 1) // self.block_size
|
||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.decode_threshold = 1
|
||||
if self.speculative_config:
|
||||
spec_token_num = self.speculative_config.num_speculative_tokens
|
||||
self.decode_threshold += spec_token_num
|
||||
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
|
||||
npu_fused_infer_attention_score TND layout's limit of 16, \
|
||||
got {self.decode_threshold}"
|
||||
|
||||
if self.chunked_prefill_enabled:
|
||||
self.chunked_prefill_workspace_size = min(
|
||||
# Max sure there is enough for 8 full length request or at least
|
||||
# 4 pages of cache per request
|
||||
max(8 * self.model_config.max_model_len,
|
||||
4 * scheduler_config.max_num_seqs * self.block_size),
|
||||
# For long-context models try not to over-allocate limiting
|
||||
# kv-cache space, limiting it to 64k tokens,
|
||||
# which would result in the workspace being:
|
||||
# 2*(576)*(64*1024) = 144mb
|
||||
# (assuming 576 MLA head dim, and fp16)
|
||||
# which would result in up-projected context being
|
||||
# 2*(192*128)*(64*1024) = 3gb
|
||||
# (assuming 192 QK head dim, 128 heads, and fp16)
|
||||
128 * 1024)
|
||||
assert self.chunked_prefill_workspace_size >= \
|
||||
scheduler_config.max_num_seqs * self.block_size
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(self.chunked_prefill_workspace_size,
|
||||
self.model_config.get_head_size()),
|
||||
dtype=self.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
||||
self.cos_cache = None
|
||||
self.sin_cache = None
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
# We now want to reorder the batch so that the "decode" requests are at
|
||||
# the front and the "prefill" requests are at the using the least amount
|
||||
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
|
||||
# where attention is likely memory-bound and "prefill" to mean requests
|
||||
# where attention is likely compute-bound, TODO(lucas): figure out a
|
||||
# better naming here)
|
||||
decodes = []
|
||||
prefills = []
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
if num_tokens <= self.decode_threshold:
|
||||
decodes.append(i)
|
||||
else:
|
||||
prefills.append(i)
|
||||
|
||||
# We hope that this is fairly minimal since decodes
|
||||
# should be around for a number of iterations so hopefully they are
|
||||
# relatively stationary (and new request are generally appended to the
|
||||
# persistent batch so already should be at the back)
|
||||
# To achieve this we loop over the decodes in descending order and
|
||||
# the prefills in ascending order. We swap decodes from the "back"
|
||||
# i.e. past where the last decode should be in the reodorered with
|
||||
# prefills from the front of the batch.
|
||||
# `decodes` and `prefills` are already in ascending order just based on
|
||||
# the above loop
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
first_prefill = 0
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the decode is at the "back" of the batch, i, we can swap it
|
||||
# with the prefill closest to the front of the batch
|
||||
if decodes[num_decodes - i] >= num_decodes:
|
||||
input_batch.swap_states(prefills[first_prefill],
|
||||
decodes[num_decodes - i])
|
||||
first_prefill += 1
|
||||
modified_batch = True
|
||||
else:
|
||||
break
|
||||
|
||||
# Save for next `build` call
|
||||
# TODO(lucas): this is a bit of a hack, we should probably have a
|
||||
# better way of doing this
|
||||
return modified_batch
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
) -> AscendSFAMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
|
||||
|
||||
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
||||
# function. We should avoid GPU -> CPU sync as much as possible because
|
||||
# it blocks on all previous kernels.
|
||||
device = self.device
|
||||
|
||||
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:
|
||||
num_actual_tokens].to(
|
||||
device,
|
||||
non_blocking=True)
|
||||
input_positions = common_attn_metadata.positions[:
|
||||
num_actual_tokens].long(
|
||||
)
|
||||
|
||||
if self.cos_cache is None:
|
||||
self.cos_cache = model.model.layers[
|
||||
0].self_attn.rotary_emb.cos_cached
|
||||
self.sin_cache = model.model.layers[
|
||||
0].self_attn.rotary_emb.sin_cached
|
||||
if self.cos_cache.dtype != self.model_config.dtype: # type: ignore
|
||||
self.cos_cache = self.cos_cache.to( # type: ignore
|
||||
self.model_config.dtype) # type: ignore
|
||||
self.sin_cache = self.sin_cache.to( # type: ignore
|
||||
self.model_config.dtype) # type: ignore
|
||||
|
||||
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
query_lens = query_seq_lens_cpu[:num_reqs]
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||
num_computed_tokens_cpu = (seq_lens - query_lens)
|
||||
|
||||
prefill_metadata = None
|
||||
chunked_context_metadata = None
|
||||
if num_prefills > 0:
|
||||
reqs_start = num_decodes # prefill_start
|
||||
tokens_start = num_decode_tokens
|
||||
max_query_len = query_lens[reqs_start:].max().item()
|
||||
max_seq_lens = seq_lens[reqs_start:].max().item()
|
||||
prefill_query_start_loc = query_start_loc[
|
||||
reqs_start:] - query_start_loc[reqs_start]
|
||||
|
||||
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
|
||||
max_context_len_cpu = context_lens_cpu.max().item()
|
||||
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
||||
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
|
||||
max_context_chunk = (self.chunked_prefill_workspace_size //
|
||||
num_prefills_with_context_cpu)
|
||||
max_context_chunk = round_down(max_context_chunk,
|
||||
self.block_size)
|
||||
|
||||
assert max_context_chunk > 0
|
||||
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
|
||||
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
|
||||
.unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
|
||||
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
|
||||
chunk_starts + max_context_chunk)
|
||||
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
||||
cu_seq_lens_cpu = torch.zeros(num_chunks,
|
||||
num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
pin_memory=True)
|
||||
torch.cumsum(chunk_seq_lens,
|
||||
dim=1,
|
||||
out=cu_seq_lens_cpu[:, 1:],
|
||||
dtype=torch.int32)
|
||||
chunked_context_metadata = \
|
||||
AscendSFAPrefillMetadata.ChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||
starts=chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
)
|
||||
prefill_input_positions = input_positions[tokens_start:]
|
||||
cos = self.cos_cache[
|
||||
prefill_input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
sin = self.sin_cache[
|
||||
prefill_input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
actual_query_lens = torch.tensor(query_lens[reqs_start:],
|
||||
dtype=torch.int32).npu()
|
||||
query_lens_prefill_sfa = torch.cumsum(actual_query_lens,
|
||||
dim=0).to(torch.int32)
|
||||
seq_lens_prefill_sfa = seq_lens[reqs_start:].to(torch.int32).npu()
|
||||
prefill_metadata = AscendSFAPrefillMetadata(
|
||||
attn_mask=common_attn_metadata.attn_mask,
|
||||
query_lens=query_lens_prefill_sfa,
|
||||
seq_lens=seq_lens_prefill_sfa,
|
||||
context_lens=seq_lens[reqs_start:],
|
||||
input_positions=prefill_input_positions,
|
||||
block_table=block_table[reqs_start:, ...],
|
||||
max_query_len=max_query_len,
|
||||
max_seq_lens=max_seq_lens,
|
||||
query_start_loc=prefill_query_start_loc,
|
||||
chunked_context=chunked_context_metadata,
|
||||
sin=sin,
|
||||
cos=cos,
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
|
||||
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].to(
|
||||
torch.int32).npu()
|
||||
max_seq_lens = seq_lens[:num_decodes].max().item()
|
||||
seq_lens = seq_lens[:num_decodes].to(torch.int32).npu()
|
||||
input_positions = input_positions[:num_decode_tokens]
|
||||
block_table = block_table[:num_decodes, ...]
|
||||
seq_lens_list = seq_lens.tolist()
|
||||
|
||||
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
sin = self.sin_cache[input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
|
||||
decode_metadata = AscendSFADecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens_list,
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
sin=sin,
|
||||
cos=cos)
|
||||
|
||||
return self.metadata_cls( # type: ignore
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
query_lens=query_lens.tolist(),
|
||||
slot_mapping=slot_mapping,
|
||||
head_dim=self.model_config.get_head_size(),
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
attn_mask=common_attn_metadata.attn_mask,
|
||||
attn_state=common_attn_metadata.attn_state,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
query_start_loc=query_start_loc,
|
||||
block_tables=block_table,
|
||||
seq_lens=seq_lens,
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
|
||||
|
||||
class PrefillSFAPreprocessResult(NamedTuple):
|
||||
q_nope: Optional[torch.Tensor] = None
|
||||
q_pe: Optional[torch.Tensor] = None
|
||||
k_nope: Optional[torch.Tensor] = None
|
||||
k_pe: Optional[torch.Tensor] = None
|
||||
topk_indices: Optional[torch.Tensor] = None
|
||||
query_states: Optional[torch.Tensor] = None
|
||||
key_states: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class DecodeSFAPreprocessResult(NamedTuple):
|
||||
q_nope: Optional[torch.Tensor] = None
|
||||
q_pe: Optional[torch.Tensor] = None
|
||||
# nope_cache: Optional[torch.Tensor] = None
|
||||
# rope_cache: Optional[torch.Tensor] = None
|
||||
topk_indices: Optional[torch.Tensor] = None
|
||||
query_states: Optional[torch.Tensor] = None
|
||||
key_states: Optional[torch.Tensor] = None
|
||||
bsz: Optional[int] = None
|
||||
|
||||
|
||||
class AscendSFAImpl(MLAAttentionImpl):
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
# MLA Args
|
||||
self.q_lora_rank = kwargs['q_lora_rank']
|
||||
self.kv_lora_rank = kwargs['kv_lora_rank']
|
||||
self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
|
||||
self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
|
||||
self.qk_head_dim = kwargs['qk_head_dim']
|
||||
self.v_head_dim = kwargs['v_head_dim']
|
||||
self.rotary_emb = kwargs['rotary_emb']
|
||||
self.q_proj = kwargs['q_proj']
|
||||
self.kv_b_proj = kwargs['kv_b_proj']
|
||||
self.o_proj = kwargs['o_proj']
|
||||
self.indexer = kwargs['indexer']
|
||||
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
||||
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
||||
self.q_a_proj = kwargs.get('q_a_proj', None)
|
||||
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_rank = self.num_heads // self.tp_size
|
||||
if self.q_a_proj is not None:
|
||||
self.q_b_proj = self.q_proj
|
||||
else:
|
||||
self.q_b_proj = None
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
self.enable_prefetch = ascend_config.enable_prefetch
|
||||
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.ring_mla_mask_size = 512
|
||||
self.prefill_mask = None
|
||||
|
||||
# indexer param
|
||||
self.dim = self.indexer.dim
|
||||
self.n_heads: int = self.indexer.n_heads # 64
|
||||
self.head_dim: int = self.indexer.head_dim # 128
|
||||
self.index_topk: int = self.indexer.index_topk # 2048
|
||||
self.wq_b = self.indexer.wq_b
|
||||
self.wk = self.indexer.wk
|
||||
self.weights_proj = self.indexer.weights_proj
|
||||
self.k_norm = self.indexer.k_norm
|
||||
self.softmax_scale = self.indexer.softmax_scale
|
||||
|
||||
# Adapt torch air graph mode with spec decoding.
|
||||
speculative_config = vllm_config.speculative_config
|
||||
if speculative_config is not None:
|
||||
self.spec_token_num = speculative_config.num_speculative_tokens
|
||||
assert self.spec_token_num > 0
|
||||
|
||||
self.cp_size = 1
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
|
||||
def get_layer_weight(layer):
|
||||
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
|
||||
for attr in WEIGHT_NAMES:
|
||||
if hasattr(layer, attr):
|
||||
return getattr(layer, attr)
|
||||
raise AttributeError(
|
||||
f"Layer '{layer}' has no recognized weight attribute:"
|
||||
f" {WEIGHT_NAMES}.")
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
|
||||
# NOTE: This should only be used offline, since it's O(N^3)
|
||||
eye = torch.eye(layer.input_size_per_partition,
|
||||
dtype=act_dtype,
|
||||
device=get_layer_weight(layer).device)
|
||||
dequant_weights = layer.quant_method.apply(layer,
|
||||
eye,
|
||||
bias=None)
|
||||
del eye
|
||||
# standardize to (output, input)
|
||||
return dequant_weights.T
|
||||
return layer.weight
|
||||
|
||||
# we currently do not have quantized bmm's which are needed for
|
||||
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
|
||||
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}")
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
self.kv_b_proj_w_k, self.kv_b_proj_w_v = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
# Convert from (L, N, V) to (N, L, V)
|
||||
self.kv_b_proj_w_v = self.kv_b_proj_w_v.transpose(0, 1).contiguous()
|
||||
# Convert from (L, N, P) to (N, P, L)
|
||||
self.kv_b_proj_w_k = self.kv_b_proj_w_k.permute(1, 2, 0).contiguous()
|
||||
|
||||
# Waiting for BMM NZ support
|
||||
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
||||
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
||||
|
||||
def _sfa_preprocess(self, hidden_states, kv_cache, attn_metadata,
|
||||
need_gather_q_kv):
|
||||
# SFA Preprocess:
|
||||
# 1. Perform q_a_proj and q_a_layernorm to obtain q_c
|
||||
# 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
|
||||
# 3. If need_gather_q_kv, perform all_gather.
|
||||
# 4. Preprocess decode tokens, write kv cache and get:
|
||||
# decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope
|
||||
# 5. Preprocess prefill tokens, write kv cache and get:
|
||||
# prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
if need_gather_q_kv:
|
||||
# q_c = get_tp_group().all_gather(q_c, 0)
|
||||
# kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
|
||||
hidden_states = get_tp_group().all_gather(hidden_states, 0)
|
||||
# hidden_states_decode = hidden_states[:num_decode_tokens]
|
||||
# if self.q_a_proj is not None:
|
||||
# npu_prefetch(self.q_a_proj.weight,
|
||||
# hidden_states,
|
||||
# enabled=self.enable_prefetch)
|
||||
# ckq = self.q_a_proj(hidden_states) # q down
|
||||
# q_c = self.q_a_layernorm(ckq) # q down layernorm
|
||||
# else:
|
||||
# q_c = hidden_states
|
||||
|
||||
# kv_no_split = self.kv_a_proj_with_mqa(hidden_states) # c_kv
|
||||
# Process for shared_expert_dp
|
||||
|
||||
decode_preprocess_res = None
|
||||
prefill_preprocess_res = None
|
||||
# Preprocess for decode tokens
|
||||
if has_decode:
|
||||
q_len = 1
|
||||
hidden_states_decode = hidden_states[:num_decode_tokens]
|
||||
decode_kq = self.q_a_proj(hidden_states_decode) # q down
|
||||
decode_q_c = self.q_a_layernorm(decode_kq) # q down layernorm
|
||||
decode_kv_no_split = self.kv_a_proj_with_mqa(
|
||||
hidden_states_decode) # c_kv
|
||||
|
||||
# decode_q_c = q_c[:num_decode_tokens]
|
||||
decode_slot_mapping = attn_metadata.slot_mapping[:
|
||||
num_decode_tokens]
|
||||
# decode_kv_no_split = decode_kv_no_split[:num_decode_tokens]
|
||||
|
||||
decode_q = self.q_b_proj(decode_q_c)
|
||||
bsz, _ = decode_q.shape
|
||||
decode_q = decode_q.view(bsz, self.num_heads, 1, self.qk_head_dim)
|
||||
decode_q_nope, decode_q_pe = torch.split(
|
||||
decode_q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
|
||||
dim=-1)
|
||||
decode_q_nope = decode_q_nope.view(
|
||||
-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1)
|
||||
decode_q_nope = (torch.matmul(decode_q_nope,
|
||||
self.kv_b_proj_w_k).transpose(
|
||||
1,
|
||||
0).view(bsz, q_len,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank))
|
||||
|
||||
# stream2 kv
|
||||
key_cache = kv_cache[0]
|
||||
value_cache = kv_cache[1]
|
||||
cos = attn_metadata.decode.cos
|
||||
sin = attn_metadata.decode.sin
|
||||
cos_q, sin_q = cos, sin
|
||||
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
|
||||
decode_kv_no_split = decode_kv_no_split.unsqueeze(1).unsqueeze(1)
|
||||
decode_k_rope, decode_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
decode_kv_no_split,
|
||||
self.kv_a_layernorm.weight,
|
||||
cos,
|
||||
sin,
|
||||
decode_slot_mapping.to(torch.int64),
|
||||
value_cache,
|
||||
key_cache,
|
||||
c_kv_scale=None,
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||
cache_mode='PA') # adapter NZ
|
||||
# nz_block_size = 16
|
||||
# KVCACHE_NZ_DIM = 16
|
||||
# decode_k_nope = decode_k_nope.view(block_num, 1, self.kv_lora_rank // nz_block_size, block_size, nz_block_size)
|
||||
# decode_k_rope = decode_k_rope.view(block_num, 1, self.qk_rope_head_dim // KVCACHE_NZ_DIM, block_size, KVCACHE_NZ_DIM)
|
||||
|
||||
decode_q_pe = torch_npu.npu_interleave_rope(decode_q_pe, cos,
|
||||
sin) # BNSD
|
||||
|
||||
decode_q_nope = decode_q_nope.view(bsz, self.num_heads,
|
||||
self.kv_lora_rank)
|
||||
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
|
||||
|
||||
topk_indices = self.indexer_select(hidden_states_decode,
|
||||
decode_q_c,
|
||||
attn_metadata=attn_metadata,
|
||||
kv_cache=kv_cache)
|
||||
|
||||
query_states = (decode_q_nope, decode_q_pe)
|
||||
key_states = (decode_k_nope, decode_k_rope)
|
||||
decode_preprocess_res = DecodeSFAPreprocessResult(
|
||||
q_nope=decode_q_nope,
|
||||
q_pe=decode_q_pe,
|
||||
# nope_cache = nope_cache,
|
||||
# rope_cache = rope_cache,
|
||||
topk_indices=topk_indices,
|
||||
query_states=query_states,
|
||||
key_states=key_states,
|
||||
bsz=bsz,
|
||||
)
|
||||
|
||||
# Preprocess for prefill tokens
|
||||
if has_prefill:
|
||||
bsz = 1
|
||||
|
||||
hidden_states_prefill = hidden_states[
|
||||
num_decode_tokens:num_actual_tokens]
|
||||
prefill_kq = self.q_a_proj(hidden_states_prefill) # q down
|
||||
prefill_q_c = self.q_a_layernorm(prefill_kq) # q down layernorm
|
||||
prefill_kv_no_split = self.kv_a_proj_with_mqa(
|
||||
hidden_states_prefill) # c_kv
|
||||
|
||||
# prefill_q_c = q_c[
|
||||
# num_decode_tokens:num_actual_tokens]
|
||||
prefill_slot_mapping = attn_metadata.slot_mapping[
|
||||
num_decode_tokens:num_actual_tokens]
|
||||
# decode_kv_no_split = decode_kv_no_split[:num_decode_tokens]
|
||||
|
||||
prefill_slot_mapping = attn_metadata.slot_mapping[
|
||||
num_decode_tokens:num_actual_tokens]
|
||||
# prefill_kv_no_split = kv_no_split[
|
||||
# num_decode_tokens:num_actual_tokens]
|
||||
# prefill_qr = prefill_q_c[num_decode_tokens:num_actual_tokens]
|
||||
prefill_qr = prefill_q_c
|
||||
prefill_q = self.q_b_proj(prefill_qr)
|
||||
prefill_q = prefill_q.view(-1, self.num_heads, self.qk_head_dim)
|
||||
prefill_q_nope, prefill_q_pe = torch.split(
|
||||
prefill_q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
|
||||
dim=-1)
|
||||
prefill_q_nope = prefill_q_nope.view(
|
||||
-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1)
|
||||
prefill_q_nope = (torch.matmul(prefill_q_nope,
|
||||
self.kv_b_proj_w_k).transpose(
|
||||
1,
|
||||
0).view(-1, self.num_heads,
|
||||
self.kv_lora_rank))
|
||||
prefill_q_pe = prefill_q_pe.unsqueeze(2)
|
||||
|
||||
# stream2 kv
|
||||
|
||||
nope_cache = kv_cache[0]
|
||||
rope_cache = kv_cache[1]
|
||||
cos = attn_metadata.prefill.cos
|
||||
sin = attn_metadata.prefill.sin
|
||||
cos_q, sin_q = cos, sin
|
||||
|
||||
# cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
# sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
|
||||
prefill_q_pe = torch_npu.npu_interleave_rope(
|
||||
prefill_q_pe, cos_q, sin_q) # BNSD
|
||||
prefill_q_pe = prefill_q_pe.squeeze(2) #BSH
|
||||
# q[..., self.qk_nope_head_dim:] = prefill_q_pe # TODO:????
|
||||
|
||||
prefill_latent_cache = prefill_kv_no_split # (B,S,N,D)
|
||||
prefill_k_pe, prefill_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
prefill_latent_cache.view(
|
||||
-1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
|
||||
self.kv_a_layernorm.weight,
|
||||
cos.view(-1, 1, 1, self.qk_rope_head_dim),
|
||||
sin.view(-1, 1, 1, self.qk_rope_head_dim),
|
||||
prefill_slot_mapping.to(torch.int64),
|
||||
rope_cache,
|
||||
nope_cache,
|
||||
k_rope_scale=None,
|
||||
c_kv_scale=None,
|
||||
k_rope_offset=None,
|
||||
c_kv_offset=None,
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||
cache_mode="PA")
|
||||
|
||||
topk_indices = self.indexer_select(x=hidden_states_prefill,
|
||||
qr=prefill_qr,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
query_states = (prefill_q_nope, prefill_q_pe)
|
||||
key_states = (prefill_k_nope, prefill_k_pe)
|
||||
prefill_preprocess_res = PrefillSFAPreprocessResult(
|
||||
q_nope=prefill_q_nope,
|
||||
q_pe=prefill_q_pe,
|
||||
topk_indices=topk_indices,
|
||||
k_nope=prefill_k_nope,
|
||||
k_pe=prefill_k_pe,
|
||||
query_states=query_states,
|
||||
key_states=key_states,
|
||||
)
|
||||
|
||||
return decode_preprocess_res, prefill_preprocess_res
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor, # query in unified attn
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
attn_metadata: M,
|
||||
need_gather_q_kv: bool = False,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
assert attn_metadata.num_decodes is not None and \
|
||||
attn_metadata.num_prefills is not None and \
|
||||
attn_metadata.num_decode_tokens is not None
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
output = output[:num_actual_tokens, ...]
|
||||
o_proj_input_shape = (num_actual_tokens,
|
||||
self.num_heads * self.v_head_dim)
|
||||
o_proj_input = torch.empty(o_proj_input_shape,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
# SFA Preprocess
|
||||
decode_preprocess_res, prefill_preprocess_res = self._sfa_preprocess(
|
||||
hidden_states, kv_cache, attn_metadata, need_gather_q_kv)
|
||||
|
||||
if decode_preprocess_res is not None:
|
||||
# bsz, q_len, _, _ = query_states[0].shape
|
||||
decode_attn_output = self.apply_attention_fusion(
|
||||
query_states=decode_preprocess_res.query_states,
|
||||
key_states=decode_preprocess_res.key_states,
|
||||
attn_metadata=attn_metadata,
|
||||
topk_indices=decode_preprocess_res.topk_indices)
|
||||
o_proj_input[:num_decode_tokens] = decode_attn_output
|
||||
|
||||
if prefill_preprocess_res is not None:
|
||||
prefill_attn_output = self.apply_attention_fusion(
|
||||
query_states=prefill_preprocess_res.query_states,
|
||||
key_states=prefill_preprocess_res.key_states,
|
||||
attn_metadata=attn_metadata,
|
||||
topk_indices=prefill_preprocess_res.topk_indices)
|
||||
o_proj_input[num_decode_tokens:] = prefill_attn_output
|
||||
|
||||
output[...] = self.mla_epilog(o_proj_input, absorb=True)
|
||||
return output
|
||||
|
||||
def apply_attention_fusion(self, query_states, key_states, topk_indices,
|
||||
attn_metadata: M):
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
q_nope, q_pe = query_states
|
||||
k_nope, k_rope = key_states
|
||||
|
||||
if attn_metadata.prefill is not None:
|
||||
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
|
||||
slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention(
|
||||
query=q_nope,
|
||||
key=k_nope,
|
||||
value=k_nope,
|
||||
sparse_indices=topk_indices,
|
||||
scale_value=self.scale,
|
||||
sparse_block_size=1,
|
||||
block_table=prefill_metadata.block_table,
|
||||
actual_seq_lengths_query=prefill_metadata.query_lens,
|
||||
actual_seq_lengths_kv=prefill_metadata.seq_lens,
|
||||
query_rope=q_pe,
|
||||
key_rope=k_rope,
|
||||
layout_query="TND",
|
||||
layout_kv="PA_BSND",
|
||||
sparse_mode=3,
|
||||
)
|
||||
|
||||
elif attn_metadata.decode is not None:
|
||||
decode_metadata = attn_metadata.decode
|
||||
|
||||
slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention(
|
||||
query=q_nope,
|
||||
key=k_nope,
|
||||
value=k_nope,
|
||||
sparse_indices=topk_indices,
|
||||
scale_value=self.scale,
|
||||
sparse_block_size=1,
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
actual_seq_lengths_query=decode_metadata.actual_seq_lengths_q,
|
||||
actual_seq_lengths_kv=decode_metadata.seq_lens,
|
||||
query_rope=q_pe,
|
||||
key_rope=k_rope,
|
||||
layout_query="TND",
|
||||
layout_kv="PA_BSND",
|
||||
sparse_mode=3,
|
||||
)
|
||||
slc_fa_fusion = slc_fa_fusion.squeeze(1)
|
||||
|
||||
slc_fa_fusion = slc_fa_fusion.transpose(0, 1)
|
||||
|
||||
# input shape [N//attn_tp_size, T(bs*q_len), D]
|
||||
# output shape [T(bs*q_len), N//attn_tp_size, D]
|
||||
attn_output = torch.matmul(slc_fa_fusion,
|
||||
self.kv_b_proj_w_v).transpose(1, 0).reshape(
|
||||
-1, self.num_heads * self.v_head_dim)
|
||||
# Note: Considering the fusion rules of TBMM, attn_output shape requires a 3-dim shape, and
|
||||
# with appropriate tensor stride for the later 'view' operation if oproj_tp_size > 1.
|
||||
# after reshape: [T(bs*q_len), 1, N//attn_tp_size*D]
|
||||
# attn_output = attn_output.reshape(-1, self.num_heads * self.v_head_dim)
|
||||
|
||||
return attn_output
|
||||
|
||||
def mla_epilog(self,
|
||||
attn_output: torch.Tensor = None,
|
||||
absorb: bool = False):
|
||||
# TODO: need to check
|
||||
attn_output = self.o_proj(attn_output.reshape(attn_output.shape[0],
|
||||
-1),
|
||||
is_prefill=True,
|
||||
is_force_scatter=False)
|
||||
|
||||
return attn_output
|
||||
|
||||
def indexer_select(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
qr: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
attn_metadata: M,
|
||||
):
|
||||
if attn_metadata.prefill is not None:
|
||||
cos = attn_metadata.prefill.cos
|
||||
sin = attn_metadata.prefill.sin
|
||||
actual_seq_lengths_query = attn_metadata.prefill.query_lens
|
||||
actual_seq_lengths_key = attn_metadata.prefill.seq_lens
|
||||
block_table = attn_metadata.prefill.block_table
|
||||
elif attn_metadata.decode is not None:
|
||||
cos = attn_metadata.decode.cos
|
||||
sin = attn_metadata.decode.sin
|
||||
actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q
|
||||
actual_seq_lengths_key = attn_metadata.decode.seq_lens
|
||||
block_table = attn_metadata.decode.block_table
|
||||
|
||||
cos_q, sin_q = cos, sin
|
||||
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
|
||||
# q process in new stream
|
||||
q = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
|
||||
q = q.view(-1, self.n_heads, self.head_dim) # [b,s,64,128]
|
||||
q_pe, q_nope = torch.split(
|
||||
q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
|
||||
dim=-1) # [b,s,64,64+64]
|
||||
|
||||
q_pe = q_pe.unsqueeze(2)
|
||||
q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q)
|
||||
q_pe = q_pe.squeeze(2)
|
||||
q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128]
|
||||
|
||||
k_proj = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
|
||||
k = self.k_norm(k_proj).unsqueeze(1)
|
||||
k_pe, k_nope = torch.split(
|
||||
k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
|
||||
dim=-1) # [b,s,64+64]
|
||||
|
||||
k_pe = k_pe.unsqueeze(2)
|
||||
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
|
||||
k_pe = k_pe.squeeze(2)
|
||||
|
||||
k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]
|
||||
|
||||
if kv_cache is not None:
|
||||
torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),
|
||||
attn_metadata.slot_mapping.view(
|
||||
-1, 1),
|
||||
k.view(-1,
|
||||
k.shape[-1])) # b, s, n, d
|
||||
|
||||
weights = self.weights_proj(x)
|
||||
|
||||
topk_indices = torch.ops.custom.npu_lightning_indexer(
|
||||
query=q,
|
||||
key=kv_cache[2],
|
||||
weights=weights,
|
||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||
actual_seq_lengths_key=actual_seq_lengths_key,
|
||||
block_table=block_table,
|
||||
layout_query="TND",
|
||||
layout_key="PA_BSND",
|
||||
sparse_count=2048,
|
||||
sparse_mode=3)
|
||||
return topk_indices
|
||||
@@ -1,7 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -21,6 +25,13 @@ class AscendCommonAttentionMetadata:
|
||||
"""(batch_size,), the length of each request including both computed tokens
|
||||
and newly scheduled tokens"""
|
||||
|
||||
seq_lens: torch.Tensor
|
||||
"""same to seq_lens_cpu, for compatibility with some new attn metadata
|
||||
(such as GDN)."""
|
||||
|
||||
num_computed_tokens_cpu: torch.Tensor
|
||||
"""(batch_size,), the number of computed tokens for each request"""
|
||||
|
||||
num_reqs: int
|
||||
"""Number of requests"""
|
||||
num_actual_tokens: int
|
||||
@@ -34,7 +45,7 @@ class AscendCommonAttentionMetadata:
|
||||
|
||||
block_table_tensor: torch.Tensor
|
||||
|
||||
slot_mapping_cpu: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
actual_seq_lengths_q: list[int]
|
||||
|
||||
@@ -93,3 +104,34 @@ def split_decodes_and_prefills(
|
||||
num_decode_tokens = query_start_loc[first_prefill].item()
|
||||
num_prefill_tokens = num_tokens - num_decode_tokens
|
||||
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
|
||||
|
||||
|
||||
def wait_for_kv_layer_from_connector(layer_name: str):
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
return
|
||||
|
||||
connector = get_kv_transfer_group()
|
||||
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if attn_metadata is None:
|
||||
return
|
||||
# TODO: assert ascendMetadata
|
||||
connector.wait_for_layer_load(layer_name)
|
||||
|
||||
|
||||
def maybe_save_kv_layer_to_connector(
|
||||
layer_name: str,
|
||||
kv_cache_layer: List[torch.Tensor],
|
||||
):
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
return
|
||||
|
||||
connector = get_kv_transfer_group()
|
||||
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if attn_metadata is None:
|
||||
return
|
||||
# TODO: assert ascendMetadata
|
||||
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
|
||||
|
||||
@@ -3,10 +3,12 @@
|
||||
|
||||
import dataclasses
|
||||
from contextlib import ExitStack
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.cuda_graph import CUDAGraphOptions
|
||||
@@ -15,7 +17,8 @@ from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import weak_ref_tensors
|
||||
|
||||
from ..utils import weak_ref_tensors
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -35,10 +38,10 @@ class ACLGraphWrapper:
|
||||
|
||||
The workflow of this wrapper in the aclgraph dispatching is as follows:
|
||||
1. At initialization, a runtime mode is assigned to the wrapper (FULL or
|
||||
PIECEWISE).
|
||||
2. At runtime, the wrapper receives a runtime_mode and a
|
||||
PIECEWISE).
|
||||
2. At runtime, the wrapper receives a runtime_mode and a
|
||||
batch_descriptor(key) from the forward context and blindly trust them
|
||||
for aclgraph dispatching.
|
||||
for aclgraph dispatching.
|
||||
3. If runtime_mode is NONE or runtime_mode does not match the mode of the
|
||||
wrapper, just call the runnable directly.
|
||||
4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
|
||||
@@ -47,9 +50,9 @@ class ACLGraphWrapper:
|
||||
|
||||
Note: ACLGraphWrapper does not store persistent buffers or copy any
|
||||
runtime inputs into that buffers for replay. We assume implementing them
|
||||
is done outside of the wrapper. That is because we do not make any
|
||||
is done outside of the wrapper. That is because we do not make any
|
||||
assumption on the dynamic shape (batch size) of the runtime inputs, as a
|
||||
trade-off for staying orthogonal to compilation logic. Nevertheless,
|
||||
trade-off for staying orthogonal to compilation logic. Nevertheless,
|
||||
tracing and checking the input addresses to be consistent during replay is
|
||||
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
|
||||
"""
|
||||
@@ -146,6 +149,7 @@ class ACLGraphWrapper:
|
||||
patch("torch.npu.empty_cache", lambda: None))
|
||||
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
forward_context.capturing = True
|
||||
with torch.npu.graph(aclgraph, pool=self.graph_pool):
|
||||
# `output` is managed by pytorch's aclgraph pool
|
||||
output = self.runnable(*args, **kwargs)
|
||||
@@ -183,3 +187,74 @@ class ACLGraphWrapper:
|
||||
logger.info_once("Replaying aclgraph")
|
||||
entry.aclgraph.replay()
|
||||
return entry.output
|
||||
|
||||
|
||||
def update_attn_params(update_stream, forward_context, runtime_shape):
|
||||
graph_params = get_graph_params()
|
||||
# FIXME: Behold! We are using a temporary hack here to update the args
|
||||
# for each layer's attention op in the graph.
|
||||
for key, param, handle, event in zip(
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
num_heads,
|
||||
scale,
|
||||
block_table,
|
||||
seq_lens,
|
||||
output,
|
||||
) = param
|
||||
# block_table = forward_context.attn_metadata[key].block_tables
|
||||
seq_lens = forward_context.attn_metadata[key].seq_lens
|
||||
|
||||
with torch.npu.stream(update_stream):
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphParams:
|
||||
events: dict[int, list[torch.npu.ExternalEvent]]
|
||||
workspaces: dict[int, torch.Tensor]
|
||||
handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
|
||||
attn_params: dict[int, list[tuple]]
|
||||
|
||||
|
||||
_graph_params: Optional[GraphParams] = None
|
||||
|
||||
|
||||
def set_graph_params(aclgraph_capture_sizes: set[int]):
|
||||
global _graph_params
|
||||
if _graph_params is not None:
|
||||
raise ValueError("Graph parameters have already been set!")
|
||||
_graph_params = GraphParams(
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: None
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
)
|
||||
|
||||
|
||||
def get_graph_params():
|
||||
return _graph_params
|
||||
|
||||
@@ -20,14 +20,19 @@ from typing import Type, Union
|
||||
|
||||
from vllm.config import SchedulerConfig
|
||||
|
||||
MAX_INT = 2147483647
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendSchedulerConfig(SchedulerConfig):
|
||||
enable_chunked_prefill: bool = False
|
||||
max_long_partial_prefills: int = MAX_INT
|
||||
long_prefill_token_threshold: int = MAX_INT
|
||||
policy: str = "fcfs"
|
||||
num_scheduler_steps: int = 1
|
||||
scheduler_cls: Union[str, Type[object]] = (
|
||||
"vllm_ascend.core.scheduler.AscendScheduler")
|
||||
enable_pd_transfer: bool = False
|
||||
decode_max_num_seqs: int = 0
|
||||
|
||||
@classmethod
|
||||
def initialize_from_config(
|
||||
@@ -41,10 +46,13 @@ class AscendSchedulerConfig(SchedulerConfig):
|
||||
}
|
||||
# Override default values into original SchedulerConfig
|
||||
scheduler_config["enable_chunked_prefill"] = False
|
||||
scheduler_config["max_long_partial_prefills"] = None
|
||||
scheduler_config["long_prefill_token_threshold"] = None
|
||||
scheduler_config["policy"] = "fcfs"
|
||||
scheduler_config["num_scheduler_steps"] = 1
|
||||
scheduler_config["scheduler_cls"] = (
|
||||
"vllm_ascend.core.scheduler.AscendScheduler")
|
||||
scheduler_config["enable_pd_transfer"] = False
|
||||
scheduler_config["decode_max_num_seqs"] = 0
|
||||
# Override params in original SchedulerConfig with params in ascend_scheduler_config
|
||||
for k, _ in scheduler_config.items():
|
||||
if hasattr(ascend_scheduler_config, k):
|
||||
@@ -65,20 +73,36 @@ class AscendSchedulerConfig(SchedulerConfig):
|
||||
"max_num_batched_tokens and makes vLLM reject longer "
|
||||
"sequences. Please increase max_num_batched_tokens or "
|
||||
"decrease max_model_len.")
|
||||
# concurrent partial prefills. Default is inf
|
||||
if self.max_long_partial_prefills is None:
|
||||
self.max_long_partial_prefills = MAX_INT
|
||||
self.long_prefill_token_threshold = MAX_INT
|
||||
|
||||
if self.long_prefill_token_threshold is None or \
|
||||
self.long_prefill_token_threshold <= 0:
|
||||
if self.max_model_len is None:
|
||||
self.long_prefill_token_threshold = MAX_INT
|
||||
else:
|
||||
self.long_prefill_token_threshold = \
|
||||
max(1, int(self.max_model_len * 0.04))
|
||||
|
||||
if self.max_long_partial_prefills < 0:
|
||||
raise ValueError(
|
||||
f"max_long_partial_prefills must be non-negative, but got "
|
||||
f"{self.max_long_partial_prefills}")
|
||||
if self.long_prefill_token_threshold < 0:
|
||||
raise ValueError(
|
||||
f"long_prefill_token_threshold must be non-negative, but got "
|
||||
f"{self.long_prefill_token_threshold}")
|
||||
|
||||
if self.policy != "fcfs":
|
||||
raise NotImplementedError(
|
||||
f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
|
||||
)
|
||||
if self.is_multimodal_model:
|
||||
raise NotImplementedError(
|
||||
"currently AscendScheduler only supports LLM models.")
|
||||
if self.num_scheduler_steps > 1:
|
||||
raise NotImplementedError(
|
||||
"currently AscendScheduler doesn't support multi-step.")
|
||||
if self.send_delta_data:
|
||||
raise NotImplementedError(
|
||||
"currently AscendScheduler doesn't support send_delta_data.")
|
||||
if self.delay_factor > 0:
|
||||
if getattr(self, "scheduler_delay_factor", 0) > 0:
|
||||
raise NotImplementedError(
|
||||
"currently AscendScheduler doesn't support scheduler_delay_factor."
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ from vllm.distributed.kv_events import KVEventBatch
|
||||
from vllm.logger import logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs
|
||||
@@ -31,13 +32,6 @@ from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
else:
|
||||
KVCacheBlocks = None
|
||||
|
||||
|
||||
class AscendScheduler(Scheduler):
|
||||
"""This Scheduler extends vllm's original v1 scheduler
|
||||
@@ -58,6 +52,15 @@ class AscendScheduler(Scheduler):
|
||||
self.scheduled_req_ids: set[str] = set()
|
||||
self.running: list[Request] = []
|
||||
|
||||
self.finished_prefill_reqs: deque[Request] = deque()
|
||||
enable_pd_transfer = getattr(self.scheduler_config,
|
||||
'enable_pd_transfer', False)
|
||||
decode_max_num_seqs = getattr(self.scheduler_config,
|
||||
'decode_max_num_seqs', 0)
|
||||
self.phase = "" if not enable_pd_transfer else "prefill"
|
||||
self.decode_max_num_running_reqs = max(self.max_num_running_reqs,
|
||||
decode_max_num_seqs)
|
||||
|
||||
def schedule(self) -> SchedulerOutput:
|
||||
if self.scheduler_config.chunked_prefill_enabled:
|
||||
return super().schedule()
|
||||
@@ -66,12 +69,14 @@ class AscendScheduler(Scheduler):
|
||||
scheduled_running_reqs: list[Request] = []
|
||||
preempted_reqs: list[Request] = []
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
req_to_new_block_ids: dict[str, list[list[int]]] = {}
|
||||
else:
|
||||
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
|
||||
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
|
||||
num_scheduled_tokens: dict[str, int] = {}
|
||||
token_budget = self.max_num_scheduled_tokens
|
||||
|
||||
# Encoder-related.
|
||||
scheduled_encoder_inputs: dict[str, list[int]] = {}
|
||||
encoder_budget = self.max_num_encoder_input_tokens
|
||||
|
||||
# Spec decode-related.
|
||||
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
|
||||
|
||||
@@ -85,9 +90,33 @@ class AscendScheduler(Scheduler):
|
||||
# and put back at the head of the waiting queue later
|
||||
skipped_waiting_requests: deque[Request] = deque()
|
||||
|
||||
if self.phase == "prefill":
|
||||
remaining_running_reqs = []
|
||||
for request in self.running:
|
||||
# move request has finished prefill to finished_prefill_reqs
|
||||
if request.num_tokens > request.num_prompt_tokens:
|
||||
self.finished_prefill_reqs.append(request)
|
||||
else:
|
||||
remaining_running_reqs.append(request)
|
||||
self.running = remaining_running_reqs
|
||||
# all request prefilled, change phase to decode
|
||||
if not self.waiting and not self.running:
|
||||
self.phase = "decode"
|
||||
# Skip long prompt requests in prefill stage.
|
||||
# long_prefill_budget is float('inf') if not use.
|
||||
if self.vllm_config.scheduler_config.long_prefill_token_threshold == 0:
|
||||
long_prefill_budget = float('inf')
|
||||
long_prefill_token_threshold = float('inf')
|
||||
else:
|
||||
long_prefill_budget = self.vllm_config.scheduler_config.max_long_partial_prefills
|
||||
long_prefill_token_threshold = self.vllm_config.scheduler_config.long_prefill_token_threshold
|
||||
|
||||
# Schedule prefill requests first.
|
||||
while self.waiting and token_budget > 0:
|
||||
if len(self.running) == self.max_num_running_reqs:
|
||||
if len(self.running) == (self.decode_max_num_running_reqs
|
||||
if self.phase == "decode" else
|
||||
self.max_num_running_reqs):
|
||||
|
||||
break
|
||||
|
||||
request = self.waiting[0]
|
||||
@@ -139,6 +168,9 @@ class AscendScheduler(Scheduler):
|
||||
num_new_local_computed_tokens = 0
|
||||
num_computed_tokens = request.num_computed_tokens
|
||||
|
||||
encoder_inputs_to_schedule = None
|
||||
new_encoder_budget = encoder_budget
|
||||
|
||||
# P/D: loading remote KV, do not allocate for new work.
|
||||
if load_kv_async:
|
||||
assert num_external_computed_tokens > 0
|
||||
@@ -176,6 +208,17 @@ class AscendScheduler(Scheduler):
|
||||
assert num_new_tokens > 0
|
||||
blocks = new_computed_blocks.blocks[0]
|
||||
|
||||
# Schedule encoder inputs.
|
||||
if request.has_encoder_inputs:
|
||||
(encoder_inputs_to_schedule, num_new_tokens,
|
||||
new_encoder_budget) = self._try_schedule_encoder_inputs(
|
||||
request, num_computed_tokens, num_new_tokens,
|
||||
encoder_budget)
|
||||
if num_new_tokens == 0 or len(
|
||||
encoder_inputs_to_schedule) == 0:
|
||||
# The request cannot be scheduled.
|
||||
break
|
||||
|
||||
watermark = getattr(self.scheduler_config, "watermark", 0.01)
|
||||
if not self._check_watermark_for_prefill(request, num_new_tokens,
|
||||
blocks, watermark):
|
||||
@@ -183,6 +226,11 @@ class AscendScheduler(Scheduler):
|
||||
skip_cur_request()
|
||||
continue
|
||||
|
||||
if num_new_tokens > long_prefill_token_threshold \
|
||||
and long_prefill_budget <= 0:
|
||||
skip_cur_request()
|
||||
continue
|
||||
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens + num_external_computed_tokens,
|
||||
@@ -227,26 +275,41 @@ class AscendScheduler(Scheduler):
|
||||
|
||||
if self.lora_config and request.lora_request:
|
||||
scheduled_loras.add(request.lora_request.lora_int_id)
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
req_to_new_block_ids[request.request_id] = (
|
||||
self.kv_cache_manager.get_block_ids(request.request_id))
|
||||
else:
|
||||
req_to_new_blocks[
|
||||
request.request_id] = self.kv_cache_manager.get_blocks(
|
||||
request.request_id)
|
||||
|
||||
req_to_new_blocks[
|
||||
request.request_id] = self.kv_cache_manager.get_blocks(
|
||||
request.request_id)
|
||||
# Update request info.
|
||||
num_scheduled_tokens[request.request_id] = num_new_tokens
|
||||
token_budget -= num_new_tokens
|
||||
if num_new_tokens > long_prefill_token_threshold:
|
||||
long_prefill_budget -= 1
|
||||
request.status = RequestStatus.RUNNING
|
||||
request.num_computed_tokens = num_computed_tokens
|
||||
# Count the number of prefix cached tokens.
|
||||
if request.num_cached_tokens < 0:
|
||||
request.num_cached_tokens = num_computed_tokens
|
||||
|
||||
# Encoder-related.
|
||||
if encoder_inputs_to_schedule:
|
||||
scheduled_encoder_inputs[request.request_id] = (
|
||||
encoder_inputs_to_schedule)
|
||||
# Allocate the encoder cache.
|
||||
for i in encoder_inputs_to_schedule:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
encoder_budget = new_encoder_budget
|
||||
|
||||
# Put back any skipped requests at the head of the waiting queue
|
||||
if skipped_waiting_requests:
|
||||
self.waiting.extendleft(skipped_waiting_requests)
|
||||
|
||||
if self.phase == "decode":
|
||||
while len(
|
||||
self.running
|
||||
) < self.decode_max_num_running_reqs and self.finished_prefill_reqs:
|
||||
request = self.finished_prefill_reqs.popleft()
|
||||
self.running.append(request)
|
||||
|
||||
# If no prefill requests are scheduled,
|
||||
# Schedule decode requests next.
|
||||
if len(self.scheduled_req_ids) == 0:
|
||||
@@ -267,6 +330,16 @@ class AscendScheduler(Scheduler):
|
||||
num_new_tokens = min(
|
||||
num_new_tokens,
|
||||
self.max_model_len - request.num_computed_tokens)
|
||||
|
||||
# Schedule encoder inputs.
|
||||
encoder_inputs_to_schedule = None
|
||||
new_encoder_budget = encoder_budget
|
||||
if request.has_encoder_inputs:
|
||||
(encoder_inputs_to_schedule, num_new_tokens,
|
||||
new_encoder_budget) = self._try_schedule_encoder_inputs(
|
||||
request, request.num_computed_tokens, num_new_tokens,
|
||||
encoder_budget)
|
||||
|
||||
# Check that adding the request still respects the max_loras
|
||||
# constraint.
|
||||
if self.lora_config and request.lora_request and (
|
||||
@@ -322,11 +395,7 @@ class AscendScheduler(Scheduler):
|
||||
# Schedule the request.
|
||||
scheduled_running_reqs.append(request)
|
||||
self.scheduled_req_ids.add(request.request_id)
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
req_to_new_block_ids[request.request_id] = (
|
||||
new_blocks.get_block_ids())
|
||||
else:
|
||||
req_to_new_blocks[request.request_id] = new_blocks
|
||||
req_to_new_blocks[request.request_id] = new_blocks
|
||||
num_scheduled_tokens[request.request_id] = num_new_tokens
|
||||
token_budget -= num_new_tokens
|
||||
req_index += 1
|
||||
@@ -342,6 +411,15 @@ class AscendScheduler(Scheduler):
|
||||
scheduled_spec_decode_tokens[request.request_id] = (
|
||||
request.spec_token_ids)
|
||||
|
||||
# Encoder-related.
|
||||
if encoder_inputs_to_schedule:
|
||||
scheduled_encoder_inputs[request.request_id] = (
|
||||
encoder_inputs_to_schedule)
|
||||
# Allocate the encoder cache.
|
||||
for i in encoder_inputs_to_schedule:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
encoder_budget = new_encoder_budget
|
||||
|
||||
# Record scheduled LoRA requests.
|
||||
if self.lora_config and request.lora_request:
|
||||
scheduled_loras.add(request.lora_request.lora_int_id)
|
||||
@@ -350,7 +428,9 @@ class AscendScheduler(Scheduler):
|
||||
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
|
||||
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
|
||||
assert token_budget >= 0
|
||||
assert len(self.running) <= self.max_num_running_reqs
|
||||
assert len(
|
||||
self.running
|
||||
) <= self.decode_max_num_running_reqs if self.phase == "decode" else self.max_num_running_reqs
|
||||
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
|
||||
scheduled_running_reqs) <= len(self.running)
|
||||
|
||||
@@ -365,67 +445,36 @@ class AscendScheduler(Scheduler):
|
||||
any_request, len(self.running)))
|
||||
|
||||
# Construct the scheduler output.
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
new_reqs_data = [
|
||||
NewRequestData.from_request(
|
||||
req, req_to_new_block_ids[req.request_id])
|
||||
for req in scheduled_new_reqs
|
||||
]
|
||||
cached_reqs_data = self._make_cached_request_data(
|
||||
scheduled_running_reqs, scheduled_resumed_reqs,
|
||||
num_scheduled_tokens, scheduled_spec_decode_tokens,
|
||||
req_to_new_block_ids)
|
||||
else:
|
||||
new_reqs_data = [
|
||||
NewRequestData.from_request(
|
||||
req, req_to_new_blocks[req.request_id].get_block_ids())
|
||||
for req in scheduled_new_reqs
|
||||
]
|
||||
new_reqs_data = [
|
||||
NewRequestData.from_request(
|
||||
req, req_to_new_blocks[req.request_id].get_block_ids())
|
||||
for req in scheduled_new_reqs
|
||||
]
|
||||
|
||||
cached_reqs_data = self._make_cached_request_data(
|
||||
scheduled_running_reqs, scheduled_resumed_reqs,
|
||||
num_scheduled_tokens, scheduled_spec_decode_tokens,
|
||||
req_to_new_blocks)
|
||||
cached_reqs_data = self._make_cached_request_data(
|
||||
scheduled_running_reqs, scheduled_resumed_reqs,
|
||||
num_scheduled_tokens, scheduled_spec_decode_tokens,
|
||||
req_to_new_blocks)
|
||||
scheduled_cached_reqs = cached_reqs_data
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=new_reqs_data,
|
||||
scheduled_cached_reqs=scheduled_cached_reqs,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=num_common_prefix_blocks,
|
||||
# finished_req_ids is an existing state in the scheduler,
|
||||
# instead of being newly scheduled in this step.
|
||||
# It contains the request IDs that are finished in between
|
||||
# the previous and the current steps.
|
||||
finished_req_ids=self.finished_req_ids, # type: ignore
|
||||
free_encoder_input_ids=self.encoder_cache_manager.
|
||||
get_freed_ids(),
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
else:
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=new_reqs_data,
|
||||
scheduled_cached_reqs=scheduled_cached_reqs,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=num_common_prefix_blocks,
|
||||
# finished_req_ids is an existing state in the scheduler,
|
||||
# instead of being newly scheduled in this step.
|
||||
# It contains the request IDs that are finished in between
|
||||
# the previous and the current steps.
|
||||
finished_req_ids=self.finished_req_ids, # type: ignore
|
||||
free_encoder_mm_hashes=self.encoder_cache_manager.
|
||||
get_freed_mm_hashes(),
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=new_reqs_data,
|
||||
scheduled_cached_reqs=scheduled_cached_reqs,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
||||
scheduled_encoder_inputs=scheduled_encoder_inputs,
|
||||
num_common_prefix_blocks=num_common_prefix_blocks,
|
||||
# finished_req_ids is an existing state in the scheduler,
|
||||
# instead of being newly scheduled in this step.
|
||||
# It contains the request IDs that are finished in between
|
||||
# the previous and the current steps.
|
||||
finished_req_ids=self.finished_req_ids, # type: ignore
|
||||
free_encoder_mm_hashes=self.encoder_cache_manager.
|
||||
get_freed_mm_hashes(),
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
# NOTE(Kuntai): this function is designed for multiple purposes:
|
||||
# 1. Plan the KV cache store
|
||||
|
||||
@@ -26,3 +26,8 @@ KVConnectorFactory.register_connector(
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeConnectorV1", "vllm_ascend.distributed.mooncake_connector",
|
||||
"MooncakeConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeConnectorStoreV1",
|
||||
"vllm_ascend.distributed.mooncake.mooncake_store_connector_v1",
|
||||
"MooncakeConnectorV1")
|
||||
|
||||
457
vllm_ascend/distributed/cpu_offload_connector.py
Normal file
457
vllm_ascend/distributed/cpu_offload_connector.py
Normal file
@@ -0,0 +1,457 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.utils import logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
|
||||
|
||||
from vllm_ascend.distributed.cpu_offload_manager.metadata import (
|
||||
MetadataServer, MetadataServerProc, MLAConfig)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
gpu_block_ids: list[int]
|
||||
cpu_block_ids: list[int]
|
||||
num_scheduled_tokens: int
|
||||
num_computed_tokens: int
|
||||
num_gpu_computed_tokens: int
|
||||
num_cpu_computed_tokens: int
|
||||
|
||||
def update(self, other: "ReqMeta"):
|
||||
self.gpu_block_ids.extend(other.gpu_block_ids)
|
||||
self.cpu_block_ids.extend(other.cpu_block_ids)
|
||||
self.num_scheduled_tokens = other.num_scheduled_tokens
|
||||
self.num_computed_tokens = other.num_computed_tokens
|
||||
self.num_gpu_computed_tokens = other.num_gpu_computed_tokens
|
||||
self.num_cpu_computed_tokens = other.num_cpu_computed_tokens
|
||||
|
||||
|
||||
@dataclass
|
||||
class CPUOffloadingConnectorMetadata(KVConnectorMetadata):
|
||||
requests: dict[str, ReqMeta]
|
||||
finished_req_ids: set[str]
|
||||
|
||||
|
||||
class CPUOffloadingConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
if not vllm_config.cache_config.enable_prefix_caching:
|
||||
self.connector_scheduler: Optional[
|
||||
CPUOffloadingConnectorScheduler] = None
|
||||
self.connector_worker: Optional[
|
||||
CPUOffloadingConnectorWorker] = None
|
||||
elif role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler = CPUOffloadingConnectorScheduler(
|
||||
vllm_config)
|
||||
self.connector_worker = None
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = CPUOffloadingConnectorWorker(vllm_config)
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
|
||||
def bind_connector_metadata(
|
||||
self, connector_metadata: KVConnectorMetadata) -> None:
|
||||
if self.connector_worker is not None:
|
||||
assert isinstance(connector_metadata,
|
||||
CPUOffloadingConnectorMetadata)
|
||||
self.connector_worker.bind_connector_metadata(connector_metadata)
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.clear_connector_metadata()
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
if self.connector_worker is not None:
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
if self.connector_worker is not None:
|
||||
self.connector_worker.start_load_kv()
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
if self.connector_worker is not None:
|
||||
self.connector_worker.wait_for_layer_load()
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
pass
|
||||
|
||||
def wait_for_save(self):
|
||||
pass
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
assert self.connector_worker is not None
|
||||
return self.connector_worker.get_finished(), None
|
||||
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
if self.connector_scheduler is not None:
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
return 0, False
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
if self.connector_scheduler is not None:
|
||||
return self.connector_scheduler.update_state_after_alloc(request)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
if self.connector_scheduler is not None:
|
||||
return self.connector_scheduler.build_connector_meta(
|
||||
scheduler_output)
|
||||
return KVConnectorMetadata()
|
||||
|
||||
def request_finished(
|
||||
self, request: "Request",
|
||||
block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
if self.connector_scheduler is not None:
|
||||
self.connector_scheduler.request_finished(request)
|
||||
return True, None
|
||||
|
||||
|
||||
class CPUOffloadingConnectorScheduler:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
logger.info("init CPUOffloadingConnectorScheduler")
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.use_mla = vllm_config.model_config.use_mla
|
||||
self.num_gpu_computed_tokens: dict[str, int] = {}
|
||||
self.num_cpu_computed_tokens: dict[str, int] = {}
|
||||
self.allocated_req_ids: set[str] = set()
|
||||
self.finished_req_ids: list[str] = []
|
||||
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
|
||||
self.zmq_rpc_client.call("post_init")
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"swap_in_threshold", 0)
|
||||
else:
|
||||
self.swap_in_threshold = 0
|
||||
logger.info(f"swap_in_threshold: {self.swap_in_threshold}")
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, ori_request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
request = copy.deepcopy(ori_request)
|
||||
request.get_hash_new_full_blocks = None
|
||||
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call(
|
||||
"get_matched_num_and_touch", request)
|
||||
self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens
|
||||
self.num_cpu_computed_tokens[
|
||||
request.request_id] = num_cpu_computed_tokens
|
||||
if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold:
|
||||
return num_cpu_computed_tokens - num_computed_tokens, load_async
|
||||
else:
|
||||
return 0, load_async
|
||||
|
||||
def update_state_after_alloc(self, request: "Request"):
|
||||
self.allocated_req_ids.add(request.request_id)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
num_tokens = {}
|
||||
# process scheduled_new_reqs
|
||||
for req in scheduler_output.scheduled_new_reqs:
|
||||
req_id = req.req_id
|
||||
num_tokens[req_id] = (
|
||||
req.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
|
||||
# process scheduled_cached_reqs
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
for idx, req_id in enumerate(cached_reqs.req_ids):
|
||||
num_tokens[req_id] = (
|
||||
cached_reqs.num_computed_tokens[idx] +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
|
||||
unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() -
|
||||
self.allocated_req_ids -
|
||||
scheduler_output.num_scheduled_tokens.keys())
|
||||
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots",
|
||||
num_tokens,
|
||||
unallocated_req_ids)
|
||||
metadata = CPUOffloadingConnectorMetadata(
|
||||
requests={},
|
||||
finished_req_ids=set(self.finished_req_ids),
|
||||
)
|
||||
for req in scheduler_output.scheduled_new_reqs:
|
||||
req_id = req.req_id
|
||||
gpu_block_ids = req.block_ids[0]
|
||||
metadata.requests[req_id] = ReqMeta(
|
||||
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
|
||||
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
|
||||
num_scheduled_tokens=scheduler_output.
|
||||
num_scheduled_tokens[req_id],
|
||||
num_computed_tokens=req.num_computed_tokens,
|
||||
num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id],
|
||||
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id])
|
||||
|
||||
for idx, req_id in enumerate(cached_reqs.req_ids):
|
||||
gpu_block_ids = cached_reqs.new_block_ids[idx]
|
||||
metadata.requests[req_id] = ReqMeta(
|
||||
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
|
||||
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
|
||||
num_scheduled_tokens=scheduler_output.
|
||||
num_scheduled_tokens[req_id],
|
||||
num_computed_tokens=cached_reqs.num_computed_tokens[idx],
|
||||
num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
|
||||
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx])
|
||||
self.num_gpu_computed_tokens.clear()
|
||||
self.num_cpu_computed_tokens.clear()
|
||||
self.allocated_req_ids.clear()
|
||||
self.finished_req_ids.clear()
|
||||
return metadata
|
||||
|
||||
def request_finished(self, ori_request: "Request"):
|
||||
request = copy.deepcopy(ori_request)
|
||||
request.get_hash_new_full_blocks = None
|
||||
self.finished_req_ids.append(request.request_id)
|
||||
# inform metadata server to record request, and free it after finish sending
|
||||
self.zmq_rpc_client.call("record_request_cache_and_free_slots",
|
||||
request)
|
||||
|
||||
|
||||
class CPUOffloadingConnectorWorker:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
logger.info("init CPUOffloadingConnectorWorker")
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.pp_rank = get_pp_group().rank_in_group
|
||||
self.tp_group = get_tp_group()
|
||||
self.tp_rank = self.tp_group.rank_in_group
|
||||
self.tp_world_size = self.tp_group.world_size
|
||||
self.use_mla = vllm_config.model_config.use_mla
|
||||
|
||||
self.requests: dict[str, ReqMeta] = {}
|
||||
self.load_stream = torch.npu.Stream()
|
||||
self.save_stream = torch.npu.Stream()
|
||||
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
|
||||
self.load_block_mapping: list[tuple[int, int]] = []
|
||||
self.save_input_queue: queue.Queue[tuple[str, ReqMeta]] = queue.Queue()
|
||||
self.save_output_queue: queue.Queue[str] = queue.Queue()
|
||||
self.save_thread = threading.Thread(target=self._save_listener)
|
||||
self.save_thread.start()
|
||||
self.done_sending_count: defaultdict[str, int] = defaultdict(int)
|
||||
|
||||
# start metadata server to init cpu_kv_cache_manager and handle rpc requests
|
||||
# all dp shared the same metadata server, only start the process on data_rank 0
|
||||
if vllm_config.parallel_config.data_parallel_rank == 0 and self.tp_rank == 0 and self.pp_rank == 0:
|
||||
config = VllmConfig()
|
||||
config.cache_config = vllm_config.cache_config
|
||||
config.parallel_config = vllm_config.parallel_config
|
||||
config.kv_transfer_config = vllm_config.kv_transfer_config
|
||||
self.init_metadata_server(config)
|
||||
self._wait_for_metadata_process_start()
|
||||
|
||||
def init_metadata_server(self, vllm_config: VllmConfig):
|
||||
self.metadata_thread = threading.Thread(
|
||||
target=MetadataServerProc.run_metadata_server,
|
||||
args=(vllm_config, ),
|
||||
)
|
||||
self.metadata_thread.daemon = True
|
||||
self.metadata_thread.start()
|
||||
|
||||
def _wait_for_metadata_process_start(self):
|
||||
# TODO: wait for metadata server to start, add a rpc to check if ready
|
||||
while True:
|
||||
try:
|
||||
if self.zmq_rpc_client.call("ready"):
|
||||
break
|
||||
except Exception as e:
|
||||
logger.info(f"wait for metadata server to start, error: {e}")
|
||||
time.sleep(1)
|
||||
|
||||
def bind_connector_metadata(
|
||||
self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
|
||||
for req_id, req in connector_metadata.requests.items():
|
||||
if req_id in self.requests:
|
||||
self.requests[req_id].update(req)
|
||||
req = self.requests[req_id]
|
||||
else:
|
||||
self.requests[req_id] = req
|
||||
for i in range(req.num_gpu_computed_tokens // self.block_size,
|
||||
req.num_computed_tokens // self.block_size):
|
||||
self.load_block_mapping.append(
|
||||
(req.cpu_block_ids[i], req.gpu_block_ids[i]))
|
||||
for req_id in connector_metadata.finished_req_ids:
|
||||
if req_id in self.requests:
|
||||
self.save_input_queue.put((req_id, self.requests[req_id]))
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
self.load_block_mapping.clear()
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]):
|
||||
self.gpu_kv_caches = kv_caches
|
||||
model_config = self.vllm_config.model_config
|
||||
mla_config: Optional[MLAConfig] = None
|
||||
if model_config.use_mla:
|
||||
mla_config = MLAConfig(
|
||||
model_config.hf_text_config.kv_lora_rank,
|
||||
model_config.hf_text_config.qk_rope_head_dim)
|
||||
self.cpu_kv_caches = list(
|
||||
self.zmq_rpc_client.call(
|
||||
"init_cpu_kv_caches",
|
||||
self.pp_rank,
|
||||
self.tp_rank,
|
||||
get_kv_cache_spec(self.vllm_config),
|
||||
mla_config,
|
||||
).values())
|
||||
|
||||
def start_load_kv(self) -> None:
|
||||
self.current_layer = 0
|
||||
self.gpu_kv_caches_load_iter = iter(self.gpu_kv_caches.values())
|
||||
self.load_kv_layer(0)
|
||||
|
||||
def wait_for_layer_load(self) -> None:
|
||||
# TODO: Replace with `torch.npu.current_stream().wait_stream(self.load_stream)` after fixing the bug.
|
||||
self.load_stream.synchronize()
|
||||
self.current_layer += 1
|
||||
self.load_kv_layer(self.current_layer)
|
||||
|
||||
def load_kv_layer(self, layer: int):
|
||||
if layer == len(self.gpu_kv_caches):
|
||||
return
|
||||
gpu_kv_caches = next(self.gpu_kv_caches_load_iter)
|
||||
cpu_kv_caches = self.cpu_kv_caches[layer]
|
||||
with torch.npu.stream(self.load_stream):
|
||||
for cpu_block_id, gpu_block_id in self.load_block_mapping:
|
||||
for gpu_layer_part, cpu_layer_part in zip(
|
||||
gpu_kv_caches, cpu_kv_caches):
|
||||
gpu_layer_part[gpu_block_id].copy_(
|
||||
cpu_layer_part[cpu_block_id], non_blocking=True)
|
||||
|
||||
def get_finished(self) -> set[str]:
|
||||
done_sending: set[str] = set()
|
||||
while True:
|
||||
try:
|
||||
id = self.save_output_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
done_sending.add(id)
|
||||
for id in done_sending:
|
||||
del self.requests[id]
|
||||
if self.tp_world_size == 1:
|
||||
return done_sending
|
||||
if self.tp_rank == 0:
|
||||
for req_id in done_sending:
|
||||
self.done_sending_count[req_id] += 1
|
||||
other_ranks_finished_ids: list[str] = []
|
||||
for i in range(1, self.tp_world_size):
|
||||
other_ranks_finished_ids.extend(
|
||||
self.tp_group.recv_object(src=i))
|
||||
for req_id in other_ranks_finished_ids:
|
||||
self.done_sending_count[req_id] += 1
|
||||
all_done_sending: set[str] = set()
|
||||
for req_id in list(self.done_sending_count.keys()):
|
||||
if self.done_sending_count[req_id] == self.tp_world_size:
|
||||
del self.done_sending_count[req_id]
|
||||
all_done_sending.add(req_id)
|
||||
# release cpu_kv_cache after request sending finished
|
||||
# to avoid rpc blocking, use thread to call rpc asynchronously
|
||||
sending_finished_thread = threading.Thread(
|
||||
target=self._sending_finished, args=(all_done_sending, ))
|
||||
sending_finished_thread.daemon = True
|
||||
sending_finished_thread.start()
|
||||
|
||||
return all_done_sending
|
||||
else:
|
||||
self.tp_group.send_object(done_sending, dst=0)
|
||||
return done_sending
|
||||
|
||||
def _sending_finished(self, all_done_sending):
|
||||
for req_id in all_done_sending:
|
||||
logger.debug(f"call cache_and_free_slots for req_id: {req_id}")
|
||||
self.zmq_rpc_client.call("cache_and_free_slots", req_id)
|
||||
|
||||
def _save_listener(self):
|
||||
save_block_mapping = []
|
||||
while True:
|
||||
req_id, req = self.save_input_queue.get()
|
||||
for i in range(
|
||||
req.num_cpu_computed_tokens // self.block_size,
|
||||
min((req.num_computed_tokens + req.num_scheduled_tokens) //
|
||||
self.block_size, len(req.cpu_block_ids))):
|
||||
save_block_mapping.append(
|
||||
(req.gpu_block_ids[i], req.cpu_block_ids[i]))
|
||||
with torch.npu.stream(self.save_stream):
|
||||
# MLA: kv_layer is tuple[tensor, tensor] means (rope, nope).
|
||||
# non-MLA: kv_layer is list[tensor], typically means [k, v].
|
||||
if self.use_mla:
|
||||
start, step = self.tp_rank, self.tp_world_size
|
||||
else:
|
||||
start, step = 0, 1
|
||||
for i in range(start, len(save_block_mapping), step):
|
||||
gpu_block_id, cpu_block_id = save_block_mapping[i]
|
||||
for cpu_kv_caches, gpu_kv_caches in zip(
|
||||
self.cpu_kv_caches, self.gpu_kv_caches.values()):
|
||||
for cpu_layer_part, gpu_layer_part in zip(
|
||||
cpu_kv_caches, gpu_kv_caches):
|
||||
cpu_layer_part[cpu_block_id].copy_(
|
||||
gpu_layer_part[gpu_block_id],
|
||||
non_blocking=True)
|
||||
self.save_stream.synchronize()
|
||||
self.save_output_queue.put(req_id)
|
||||
save_block_mapping.clear()
|
||||
|
||||
|
||||
# Copied from vllm_ascend/worker/model_runner_v1.py.
|
||||
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
||||
forward_ctx = vllm_config.compilation_config.static_forward_context
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
use_mla = vllm_config.model_config.use_mla
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
for layer_name, attn_module in forward_ctx.items():
|
||||
if isinstance(attn_module, FusedMoE):
|
||||
continue
|
||||
assert isinstance(attn_module, Attention)
|
||||
if attn_module.attn_type == AttentionType.DECODER:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=attn_module.dtype,
|
||||
use_mla=use_mla)
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
continue
|
||||
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown attention type: {attn_module.attn_type}")
|
||||
return kv_cache_spec
|
||||
@@ -0,0 +1,202 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
from vllm.utils import logger, sha256
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
PrefixCachingMetrics)
|
||||
from vllm.v1.core.single_type_kv_cache_manager import \
|
||||
get_manager_for_kv_cache_spec
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
class CPUCacheStats:
|
||||
|
||||
def __init__(self, enable_prefix_caching: bool, log_stats: bool = False):
|
||||
self.enable_prefix_caching = enable_prefix_caching
|
||||
self.log_stats = log_stats
|
||||
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
|
||||
self.cpu_prefix_cache_metrics = PrefixCachingMetrics()
|
||||
self.time_sec = int(time.time())
|
||||
|
||||
def log(self):
|
||||
current_time_sec = int(time.time())
|
||||
# Log the prefix cache hit rate every 10 seconds.
|
||||
if current_time_sec - self.time_sec >= 10:
|
||||
self.time_sec = current_time_sec
|
||||
logger.info("CPU Prefix cache hit rate: %.1f%%",
|
||||
self.cpu_prefix_cache_metrics.hit_rate * 100)
|
||||
|
||||
def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
|
||||
"""Get (and reset) the prefix cache stats.
|
||||
Returns:
|
||||
The current prefix caching stats, or None if logging is disabled.
|
||||
"""
|
||||
if not self.log_stats:
|
||||
return None
|
||||
stats = self.prefix_cache_stats
|
||||
self.prefix_cache_stats = PrefixCacheStats()
|
||||
return stats
|
||||
|
||||
def update(self, num_tokens, num_computed_tokens):
|
||||
# Note the function is called by scheduler
|
||||
if self.log_stats and self.enable_prefix_caching:
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.requests += 1
|
||||
self.prefix_cache_stats.queries += num_tokens
|
||||
self.prefix_cache_stats.hits += num_computed_tokens
|
||||
|
||||
def set_cache_stats(self, num_tokens, num_computed_tokens):
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.hits = num_computed_tokens
|
||||
self.prefix_cache_stats.queries = num_tokens
|
||||
self.prefix_cache_stats.requests = 1
|
||||
|
||||
|
||||
class CPUKVCacheManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
num_cpu_blocks: int,
|
||||
caching_hash_algo: str = "builtin",
|
||||
use_eagle: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
) -> None:
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.num_cpu_blocks = num_cpu_blocks
|
||||
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
|
||||
self.use_eagle = use_eagle
|
||||
self.block_pool = BlockPool(self.num_cpu_blocks, True,
|
||||
enable_kv_cache_events)
|
||||
self.single_type_manager = get_manager_for_kv_cache_spec(
|
||||
kv_cache_spec=kv_cache_spec,
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_group_id=0,
|
||||
)
|
||||
# Record kv block hashes, avoid redundant computation.
|
||||
self.req_to_block_hashes: defaultdict[
|
||||
str, list[BlockHash]] = defaultdict(list)
|
||||
# Record blocks touched in get_matched_num_and_touch().
|
||||
self.req_to_computed_blocks: defaultdict[
|
||||
str, list[KVCacheBlock]] = defaultdict(list)
|
||||
# Record the request that failed to allocate.
|
||||
self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool)
|
||||
self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int)
|
||||
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True,
|
||||
log_stats=True)
|
||||
# Record request that will be free after finish sending
|
||||
self.req_to_free: defaultdict[str, Request] = defaultdict(Request)
|
||||
|
||||
def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]:
|
||||
# When the request requires prompt logprobs, we skip prefix caching.
|
||||
if (request.sampling_params.prompt_logprobs is not None):
|
||||
return 0, False
|
||||
request_id = request.request_id
|
||||
# The block hashes for the request may already be computed
|
||||
# if the scheduler has tried to schedule the request before.
|
||||
block_hashes = self.req_to_block_hashes[request_id]
|
||||
if not block_hashes:
|
||||
block_hashes = request.block_hashes
|
||||
self.req_to_block_hashes[request_id] = block_hashes
|
||||
max_cache_hit_length = request.num_tokens - 1
|
||||
computed_blocks = self.single_type_manager.find_longest_cache_hit(
|
||||
block_hashes=block_hashes,
|
||||
max_length=max_cache_hit_length,
|
||||
kv_cache_group_ids=[0],
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_spec=self.single_type_manager.kv_cache_spec,
|
||||
use_eagle=self.use_eagle,
|
||||
)
|
||||
num_computed_tokens = len(computed_blocks[0]) * self.block_size
|
||||
self.req_to_computed_blocks[request_id] = computed_blocks[0]
|
||||
# We should touch these blocks in the concurrent scenarios.
|
||||
self.block_pool.touch(computed_blocks)
|
||||
|
||||
# cup prefix cache status set and log
|
||||
assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None
|
||||
self.cpu_cache_stats.set_cache_stats(request.num_tokens,
|
||||
num_computed_tokens)
|
||||
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(
|
||||
self.cpu_cache_stats.prefix_cache_stats)
|
||||
self.cpu_cache_stats.log()
|
||||
|
||||
return num_computed_tokens, False
|
||||
|
||||
def _release_ahead_touch(self, request_id: str):
|
||||
computed_blocks = self.req_to_computed_blocks[request_id]
|
||||
if computed_blocks:
|
||||
self.single_type_manager.block_pool.free_blocks(
|
||||
reversed(computed_blocks))
|
||||
self.req_to_computed_blocks.pop(request_id, None)
|
||||
|
||||
def allocate_slots(self, req_to_num_tokens: dict[str, int],
|
||||
unallocated_req_ids: set[str]) -> dict[str, list[int]]:
|
||||
for request_id in unallocated_req_ids:
|
||||
self._free_slots(request_id)
|
||||
req_to_new_blocks = {}
|
||||
for request_id, num_tokens in req_to_num_tokens.items():
|
||||
if self.req_failed_to_allocate[request_id]:
|
||||
continue
|
||||
new_computed_blocks = self.req_to_computed_blocks[request_id]
|
||||
num_blocks_to_allocate = (
|
||||
self.single_type_manager.get_num_blocks_to_allocate(
|
||||
request_id=request_id,
|
||||
num_tokens=num_tokens,
|
||||
new_computed_blocks=new_computed_blocks,
|
||||
))
|
||||
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
|
||||
self._release_ahead_touch(request_id)
|
||||
self.req_failed_to_allocate[request_id] = True
|
||||
continue
|
||||
# Append the new computed blocks to the request blocks until now to
|
||||
# avoid the case where the new blocks cannot be allocated.
|
||||
self.single_type_manager.save_new_computed_blocks(
|
||||
request_id, new_computed_blocks)
|
||||
# Allocate new blocks but do not cache now.
|
||||
new_blocks = self.single_type_manager.allocate_new_blocks(
|
||||
request_id, num_tokens)
|
||||
self.req_to_num_tokens[request_id] = num_tokens
|
||||
# No need to release ref_cnt because we use officially.
|
||||
self.req_to_computed_blocks.pop(request_id, None)
|
||||
req_to_new_blocks[request_id] = [
|
||||
block.block_id for block in new_computed_blocks + new_blocks
|
||||
]
|
||||
return req_to_new_blocks
|
||||
|
||||
def record_request_cache_and_free_slots(self, request: Request):
|
||||
logger.debug(
|
||||
f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager"
|
||||
)
|
||||
self.req_to_free[request.request_id] = request
|
||||
|
||||
def cache_and_free_slots(self, request_id: str):
|
||||
logger.debug(
|
||||
f"Cache and free slots for request {request_id} in cpu_kv_cache_manager"
|
||||
)
|
||||
if request_id not in self.req_to_free:
|
||||
logger.Error(
|
||||
f"request {request_id} not in req_to_free, maybe bug!")
|
||||
return
|
||||
request = self.req_to_free[request_id]
|
||||
if not self.req_failed_to_allocate[request_id]:
|
||||
self.single_type_manager.cache_blocks(
|
||||
request,
|
||||
self.req_to_num_tokens[request_id],
|
||||
)
|
||||
self._free_slots(request_id)
|
||||
logger.debug(
|
||||
f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
|
||||
del self.req_to_free[request_id]
|
||||
|
||||
def _free_slots(self, request_id: str):
|
||||
# This function is designed to be reentrant.
|
||||
self._release_ahead_touch(request_id)
|
||||
self.single_type_manager.free(request_id)
|
||||
self.req_to_block_hashes.pop(request_id, None)
|
||||
self.req_to_computed_blocks.pop(request_id, None)
|
||||
self.req_failed_to_allocate.pop(request_id, None)
|
||||
self.req_to_num_tokens.pop(request_id, None)
|
||||
269
vllm_ascend/distributed/cpu_offload_manager/metadata.py
Normal file
269
vllm_ascend/distributed/cpu_offload_manager/metadata.py
Normal file
@@ -0,0 +1,269 @@
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import vllm.envs as envs
|
||||
import zmq
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.utils import get_dtype_size, logger, make_zmq_socket
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from vllm_ascend.distributed.cpu_offload_manager.cpu_kv_cache_manager import \
|
||||
CPUKVCacheManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLAConfig:
|
||||
nope_dim: int
|
||||
rope_dim: int
|
||||
|
||||
|
||||
def get_cpu_offload_connector(vllm_config: VllmConfig) -> KVTransferConfig:
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
kv_transfer_config = vllm_config.kv_transfer_config
|
||||
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
|
||||
return kv_transfer_config
|
||||
elif kv_transfer_config.kv_connector == "MultiConnector":
|
||||
ktcs = kv_transfer_config.kv_connector_extra_config.get(
|
||||
"connectors")
|
||||
for ktc in ktcs:
|
||||
kv_transfer_config = KVTransferConfig(**ktc)
|
||||
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
|
||||
return kv_transfer_config
|
||||
return None
|
||||
|
||||
|
||||
class MetadataServer:
|
||||
METADATA_SERVER_ADDRESS = f"ipc://{envs.VLLM_RPC_BASE_PATH}/metadata.ipc"
|
||||
DEFAULT_CPU_SWAP_SPACE_GB = 800
|
||||
|
||||
class ZMQRPCClient:
|
||||
|
||||
def __init__(self, identity=f"worker-{os.getpid()}"):
|
||||
logger.info(f"metadata client for worker {identity} started")
|
||||
self.ctx = zmq.Context() # type: ignore
|
||||
self.socket = make_zmq_socket(
|
||||
self.ctx,
|
||||
MetadataServer.METADATA_SERVER_ADDRESS,
|
||||
zmq.DEALER, # type: ignore
|
||||
bind=False,
|
||||
identity=identity.encode(),
|
||||
linger=0)
|
||||
|
||||
def call(self, func_name: str, *args, **kwargs) -> Any:
|
||||
request = (func_name, args, kwargs)
|
||||
self.socket.send(b"", zmq.SNDMORE) # type: ignore
|
||||
self.socket.send(pickle.dumps(request))
|
||||
_ = self.socket.recv()
|
||||
response = pickle.loads(self.socket.recv())
|
||||
result, error = response
|
||||
if error:
|
||||
logger.exception(f"call metadata sever error: {error}")
|
||||
raise error
|
||||
if func_name == "init_cpu_kv_caches":
|
||||
(memory_dict, layer_size, layer_dtype, mla_config) = result
|
||||
# shared_memory_dict is recorded in self to close
|
||||
self.shared_memory_dict = memory_dict
|
||||
result = {}
|
||||
for key, shm in memory_dict.items():
|
||||
tensor = torch.frombuffer(
|
||||
shm.buf, dtype=layer_dtype).reshape(layer_size)
|
||||
if mla_config is not None:
|
||||
tensor = tensor.split(
|
||||
[mla_config.nope_dim, mla_config.rope_dim], dim=-1)
|
||||
result[key] = tensor
|
||||
return result
|
||||
|
||||
def __del__(self):
|
||||
# will be finalized by outer process
|
||||
self.socket.close()
|
||||
self.ctx.term()
|
||||
if hasattr(self, 'shared_memory_dict'):
|
||||
for shm in self.shared_memory_dict.values():
|
||||
shm.close()
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
self.world_size = vllm_config.parallel_config.world_size
|
||||
self.pipeline_parallel_size = vllm_config.parallel_config.pipeline_parallel_size
|
||||
kv_transfer_config = get_cpu_offload_connector(vllm_config)
|
||||
assert kv_transfer_config is not None
|
||||
available_memory_gb = kv_transfer_config.get_from_extra_config(
|
||||
"cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB)
|
||||
self.available_memory = available_memory_gb * 1024 * 1024 * 1024
|
||||
logger.info(f"cpu swap space: {self.available_memory} bytes")
|
||||
self.ctx = zmq.Context() # type: ignore
|
||||
self.socket = make_zmq_socket(
|
||||
self.ctx,
|
||||
MetadataServer.METADATA_SERVER_ADDRESS,
|
||||
zmq.ROUTER, # type: ignore
|
||||
bind=True,
|
||||
linger=0)
|
||||
self.functions: dict[str, Callable] = {
|
||||
"init_cpu_kv_caches": self.init_cpu_kv_caches,
|
||||
"post_init": self.post_init,
|
||||
"ready": self.ready,
|
||||
}
|
||||
self.shared_memory = {} # type: ignore
|
||||
self.num_cpu_blocks = -1
|
||||
|
||||
@staticmethod
|
||||
def _safe_create_shared_memory(name: str, size: int) -> SharedMemory:
|
||||
try:
|
||||
existing_shm = SharedMemory(name=name, create=False)
|
||||
existing_shm.close()
|
||||
existing_shm.unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return SharedMemory(name=name, create=True, size=size)
|
||||
|
||||
def ready(self):
|
||||
return True
|
||||
|
||||
def init_cpu_kv_caches(
|
||||
self,
|
||||
pp_rank: int,
|
||||
tp_rank: int,
|
||||
kv_cache_specs: dict[str, AttentionSpec],
|
||||
mla_config: MLAConfig,
|
||||
) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype,
|
||||
MLAConfig]:
|
||||
logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}")
|
||||
# follow the assumption that each layer has the same spec
|
||||
layer = next(iter(kv_cache_specs.values()))
|
||||
assert all([
|
||||
layer.page_size_bytes == any.page_size_bytes
|
||||
for any in kv_cache_specs.values()
|
||||
])
|
||||
# mla shares the same kv cache among different tp
|
||||
if layer.use_mla:
|
||||
tp_rank = 0
|
||||
if (pp_rank, tp_rank) in self.shared_memory:
|
||||
return self.shared_memory[(pp_rank, tp_rank)]
|
||||
available_memory = self.available_memory
|
||||
shared_memory_dict = {}
|
||||
if layer.use_mla:
|
||||
available_memory //= self.pipeline_parallel_size
|
||||
available_memory //= len(kv_cache_specs)
|
||||
num_blocks = available_memory // layer.page_size_bytes
|
||||
layer_size = (num_blocks, layer.block_size, layer.num_kv_heads,
|
||||
layer.head_size) # type: ignore
|
||||
else:
|
||||
available_memory //= self.world_size
|
||||
available_memory //= len(kv_cache_specs)
|
||||
num_blocks = available_memory // layer.page_size_bytes
|
||||
layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads,
|
||||
layer.head_size) # type: ignore
|
||||
nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype)
|
||||
for layer_name in kv_cache_specs.keys():
|
||||
# only this format can share during ZeroMQ+pickle
|
||||
shared_memory_dict[
|
||||
layer_name] = MetadataServer._safe_create_shared_memory(
|
||||
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
|
||||
if layer.use_mla:
|
||||
assert mla_config is not None
|
||||
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
|
||||
self.shared_memory[(pp_rank,
|
||||
tp_rank)] = (shared_memory_dict, layer_size,
|
||||
layer.dtype, mla_config)
|
||||
else:
|
||||
self.shared_memory[(pp_rank,
|
||||
tp_rank)] = (shared_memory_dict, layer_size,
|
||||
layer.dtype, None)
|
||||
if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks:
|
||||
self.num_cpu_blocks = num_blocks
|
||||
self.layer = layer
|
||||
return self.shared_memory[(pp_rank, tp_rank)]
|
||||
|
||||
def post_init(self):
|
||||
# different processors in data parallel may call multiple times
|
||||
if hasattr(self, 'cpu_block_manager'):
|
||||
return
|
||||
# do shared_memory() at least once
|
||||
logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}")
|
||||
assert self.num_cpu_blocks >= 0
|
||||
self.cpu_block_manager = CPUKVCacheManager(self.layer,
|
||||
self.num_cpu_blocks)
|
||||
self.functions.update({
|
||||
"get_matched_num_and_touch":
|
||||
self.cpu_block_manager.get_matched_num_and_touch,
|
||||
"allocate_slots":
|
||||
self.cpu_block_manager.allocate_slots,
|
||||
"record_request_cache_and_free_slots":
|
||||
self.cpu_block_manager.record_request_cache_and_free_slots,
|
||||
"cache_and_free_slots":
|
||||
self.cpu_block_manager.cache_and_free_slots,
|
||||
})
|
||||
|
||||
def serve_step(self):
|
||||
client_id = self.socket.recv()
|
||||
_ = self.socket.recv()
|
||||
raw_msg = self.socket.recv()
|
||||
try:
|
||||
func_name, args, kwargs = pickle.loads(raw_msg)
|
||||
except Exception as e:
|
||||
response = (None, Exception(f"Invalid request: {str(e)}"))
|
||||
else:
|
||||
if func_name in self.functions:
|
||||
try:
|
||||
result = self.functions[func_name](*args, **kwargs)
|
||||
response = (result, None) # type: ignore
|
||||
except Exception as e:
|
||||
logger.exception(f"metadata execute error: {e}")
|
||||
response = (None, e) # type: ignore
|
||||
else:
|
||||
response = (None, NameError(f"Function {func_name} not found"))
|
||||
self.socket.send(client_id, zmq.SNDMORE) # type: ignore
|
||||
self.socket.send(b"", zmq.SNDMORE) # type: ignore
|
||||
self.socket.send(pickle.dumps(response))
|
||||
|
||||
def shutdown(self):
|
||||
self.socket.close()
|
||||
self.ctx.term()
|
||||
socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace(
|
||||
"ipc://", "")
|
||||
if os.path.exists(socket_path):
|
||||
os.remove(socket_path)
|
||||
for cached in self.shared_memory.values():
|
||||
for shm in cached[0].values():
|
||||
shm.close()
|
||||
shm.unlink()
|
||||
|
||||
|
||||
class MetadataServerProc:
|
||||
|
||||
@staticmethod
|
||||
def run_metadata_server(vllm_config: VllmConfig):
|
||||
if (not vllm_config.cache_config.enable_prefix_caching
|
||||
or get_cpu_offload_connector(vllm_config) is None):
|
||||
return
|
||||
|
||||
shutdown_requested = False
|
||||
|
||||
def _signal_handler(signum, frame):
|
||||
nonlocal shutdown_requested
|
||||
if not shutdown_requested:
|
||||
shutdown_requested = True
|
||||
raise SystemExit()
|
||||
|
||||
# Either SIGTERM or SIGINT will terminate the worker
|
||||
# signal.signal(signal.SIGTERM, _signal_handler)
|
||||
# signal.signal(signal.SIGINT, _signal_handler)
|
||||
metadata_server: Optional[MetadataServer] = None
|
||||
try:
|
||||
metadata_server = MetadataServer(vllm_config)
|
||||
logger.info("Metadata server started.")
|
||||
while True:
|
||||
metadata_server.serve_step()
|
||||
except SystemExit:
|
||||
logger.info("Metadata server exiting.")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"Metadata server error: {e}.")
|
||||
raise e
|
||||
finally:
|
||||
if metadata_server is not None:
|
||||
metadata_server.shutdown()
|
||||
@@ -1,4 +1,5 @@
|
||||
import contextlib
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
@@ -17,6 +18,7 @@ import torch
|
||||
import zmq
|
||||
from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist,
|
||||
LLMException, LLMRole)
|
||||
from vllm import envs
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
@@ -184,6 +186,7 @@ class LLMDataDistCMgrConnectorScheduler():
|
||||
self.port = dp_rank_local * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT
|
||||
|
||||
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
|
||||
self._reqs_need_send: dict[str, float] = {}
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
@@ -248,7 +251,12 @@ class LLMDataDistCMgrConnectorScheduler():
|
||||
meta.add_new_req(request_id=req_id,
|
||||
local_block_ids=block_ids,
|
||||
kv_transfer_params=req.kv_transfer_params)
|
||||
|
||||
meta.reqs_to_send = copy.deepcopy(self._reqs_need_send)
|
||||
|
||||
# Clear the list once workers start the transfers
|
||||
self._reqs_need_recv.clear()
|
||||
self._reqs_need_send.clear()
|
||||
|
||||
return meta
|
||||
|
||||
@@ -275,6 +283,9 @@ class LLMDataDistCMgrConnectorScheduler():
|
||||
if delay_free_blocks:
|
||||
logger.info("Delaying free of %d blocks for request %s",
|
||||
len(computed_block_ids), request.request_id)
|
||||
# Prefill request on remote. It will be read from D upon completion
|
||||
self._reqs_need_send[request.request_id] = time.perf_counter(
|
||||
) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
|
||||
return delay_free_blocks, dict(
|
||||
do_remote_prefill=True,
|
||||
do_remote_decode=False,
|
||||
@@ -341,6 +352,7 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
os.environ["HCCL_DETERMINISTIC"] = "true"
|
||||
self.done_receiving_counts: defaultdict[str,
|
||||
set[int]] = defaultdict(set)
|
||||
self.reqs_to_send: dict[str, float] = {}
|
||||
|
||||
def listen_for_agent_metadata_req(self, event: threading.Event):
|
||||
assert self.local_agent_metadata is not None
|
||||
@@ -375,16 +387,13 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
)
|
||||
elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
|
||||
finished_req_id = decode_msg[0]
|
||||
decode_tp_rank = decode_msg[1]
|
||||
decode_tp_size = decode_msg[2]
|
||||
with self.thread_lock:
|
||||
if self._increment_task_count(finished_req_id,
|
||||
decode_tp_rank,
|
||||
decode_tp_size):
|
||||
logger.debug(
|
||||
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
|
||||
)
|
||||
logger.debug(
|
||||
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
|
||||
)
|
||||
if finished_req_id in self.reqs_to_send:
|
||||
self.finished_reqs.add(finished_req_id)
|
||||
del self.reqs_to_send[finished_req_id]
|
||||
sock.send_multipart(
|
||||
(identity, b"", b"receiving decode finished"))
|
||||
else:
|
||||
@@ -392,24 +401,6 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
|
||||
)
|
||||
|
||||
def _increment_task_count(self, request_id: str, tp_rank: int,
|
||||
decode_tp_size: int):
|
||||
if request_id not in self.done_receiving_counts:
|
||||
self.done_receiving_counts[request_id] = set()
|
||||
if tp_rank in self.done_receiving_counts[request_id]:
|
||||
logger.warning(
|
||||
f"Received duplicate done signal for request {request_id} "
|
||||
f"from tp rank {tp_rank}. Ignoring.")
|
||||
return False
|
||||
self.done_receiving_counts[request_id].add(tp_rank)
|
||||
if len(self.done_receiving_counts[request_id]) == decode_tp_size:
|
||||
self.done_receiving_counts.pop(request_id)
|
||||
logger.info("All transfers completed for request: "
|
||||
f"{request_id}. Total ranks: "
|
||||
f"{decode_tp_size}.")
|
||||
return True
|
||||
return False
|
||||
|
||||
def init_llm_datadist(self):
|
||||
assert self.local_agent_metadata is not None
|
||||
llm_config = LLMConfig()
|
||||
@@ -502,8 +493,11 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
assert self.local_agent_metadata is not None
|
||||
kv_cache_dtype = first_kv_cache.dtype
|
||||
self.use_mla: bool = first_kv_cache_tuple[0].size(
|
||||
-1) != first_kv_cache_tuple[1].size(-1)
|
||||
-1) != first_kv_cache_tuple[1].size(-1) and len(
|
||||
first_kv_cache_tuple) == 2
|
||||
self.use_sfa: bool = len(first_kv_cache_tuple) == 3
|
||||
# MLA case. [2 (k_normed, k_pe), num_blocks, ...]
|
||||
# SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...]
|
||||
# MHA case. [2 (k and v), num_blocks, ...]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 3 # [block_size, latent_dim]
|
||||
@@ -549,6 +543,58 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
|
||||
)
|
||||
elif self.use_sfa:
|
||||
cache_k_normed_addr_list = []
|
||||
cache_k_pe_addr_list = []
|
||||
cache_k_idx_addr_list = []
|
||||
k_normed = None
|
||||
k_pe = None
|
||||
k_idx = None
|
||||
for cache_or_caches in kv_caches.values():
|
||||
assert len(cache_or_caches) > 1
|
||||
k_normed, k_pe, k_idx = cache_or_caches[0], cache_or_caches[
|
||||
1], cache_or_caches[2]
|
||||
cache_k_normed_addr_list.append(k_normed.data_ptr())
|
||||
cache_k_pe_addr_list.append(k_pe.data_ptr())
|
||||
cache_k_idx_addr_list.append(k_idx.data_ptr())
|
||||
self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list,
|
||||
cache_k_idx_addr_list)
|
||||
|
||||
cache_desc_k_normed = CacheDesc(
|
||||
len(self.cache_addr[0]), [*k_normed.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_desc_k_pe = CacheDesc(
|
||||
len(self.cache_addr[1]), [*k_pe.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_desc_k_idx = CacheDesc(
|
||||
len(self.cache_addr[2]), [*k_idx.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_key_k_normed = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=0)
|
||||
cache_key_k_pe = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=1)
|
||||
cache_key_k_idx = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=2)
|
||||
self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe,
|
||||
cache_desc_k_idx)
|
||||
self.cache_key = (cache_key_k_normed, cache_key_k_pe,
|
||||
cache_key_k_idx)
|
||||
try:
|
||||
cache_k_normed = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
|
||||
cache_k_pe = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
|
||||
cache_k_idx = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[2], self.cache_addr[2], self.cache_key[2])
|
||||
self.cache = (cache_k_normed, cache_k_pe, cache_k_idx)
|
||||
logger.info("LLMDataDistWorker: End of register Paged Cache.")
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
|
||||
)
|
||||
else:
|
||||
for cache_or_caches in kv_caches.values():
|
||||
for cache in cache_or_caches:
|
||||
@@ -605,6 +651,7 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
|
||||
for future in futures:
|
||||
future.add_done_callback(handle_exception)
|
||||
self.reqs_to_send.update(metadata.reqs_to_send)
|
||||
|
||||
def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
|
||||
assert self.local_agent_metadata is not None
|
||||
@@ -767,24 +814,24 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
cluster_id = self.add_remote_agent(metadata)
|
||||
return cluster_id
|
||||
|
||||
def send_finish_to_remote(self, host: str, port: int, request_id):
|
||||
url = f"tcp://{host}:{port}"
|
||||
logger.debug(f"Sending finished to remote: {url}")
|
||||
msg_encoder = msgspec.msgpack.Encoder()
|
||||
msg_send = msg_encoder.encode([
|
||||
LLMDataDistCMgrEvent.ReqForFinished,
|
||||
[request_id, self.tp_rank, self.tp_size]
|
||||
])
|
||||
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
|
||||
try:
|
||||
sock.send(msg_send)
|
||||
logger.debug(
|
||||
f"Request id {request_id} finished message send to remote {url}"
|
||||
)
|
||||
_ = sock.recv()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to send reqest_id {request_id} to prefill: {e}")
|
||||
def send_finish_to_remote(self, host: str, ports: list[int], request_id):
|
||||
for port in ports:
|
||||
url = f"tcp://{host}:{port}"
|
||||
logger.debug(f"Sending finished to remote: {url}")
|
||||
msg_encoder = msgspec.msgpack.Encoder()
|
||||
msg_send = msg_encoder.encode(
|
||||
[LLMDataDistCMgrEvent.ReqForFinished, [request_id]])
|
||||
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
|
||||
try:
|
||||
sock.send(msg_send)
|
||||
logger.debug(
|
||||
f"Request id {request_id} finished message send to remote {url}"
|
||||
)
|
||||
_ = sock.recv()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to send reqest_id {request_id} to prefill: {e}"
|
||||
)
|
||||
|
||||
def _read_blocks(
|
||||
self,
|
||||
@@ -834,6 +881,38 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
raise RuntimeError(
|
||||
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
||||
)
|
||||
elif self.use_sfa:
|
||||
remote_cache_key_k_normed = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=0)
|
||||
remote_cache_key_k_pe = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=1)
|
||||
remote_cache_key_k_idx = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=2)
|
||||
logger.info("Try pull blocks from remote server")
|
||||
try:
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_normed,
|
||||
self.cache[0], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_pe,
|
||||
self.cache[1], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_idx,
|
||||
self.cache[2], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe} {remote_cache_key_k_idx}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
|
||||
)
|
||||
except LLMException:
|
||||
raise RuntimeError(
|
||||
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
||||
)
|
||||
else:
|
||||
remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
|
||||
logger.info("Try pull blocks from remote server")
|
||||
@@ -851,7 +930,10 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
raise RuntimeError(
|
||||
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
||||
)
|
||||
self.send_finish_to_remote(remote_ip, remote_port, request_id)
|
||||
remote_ports = list(
|
||||
range(remote_port + self.tp_rank,
|
||||
remote_port + int(remote_tp_size), self.tp_size))
|
||||
self.send_finish_to_remote(remote_ip, remote_ports, request_id)
|
||||
with self.thread_lock:
|
||||
self.finished_reqs.add(request_id)
|
||||
|
||||
@@ -859,8 +941,19 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
"""Get the finished recving and sending requuests."""
|
||||
import copy
|
||||
now = time.perf_counter()
|
||||
with self.thread_lock:
|
||||
while self.reqs_to_send:
|
||||
req_id, expires = next(iter(self.reqs_to_send.items()))
|
||||
if now < expires:
|
||||
break
|
||||
logger.warning(
|
||||
"Some requests in prefill node fail to receive KV Cache transfer done signal. "
|
||||
"If a greater mean TTFT is acceptable, you can 'export VLLM_NIXL_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
|
||||
)
|
||||
if req_id in self.reqs_to_send:
|
||||
self.finished_reqs.add(req_id)
|
||||
del self.reqs_to_send[req_id]
|
||||
req_ids_to_ret = copy.deepcopy(self.finished_reqs)
|
||||
self.finished_reqs.clear()
|
||||
if self.llm_datadist_role == LLMRole.PROMPT:
|
||||
@@ -891,4 +984,4 @@ def zmq_ctx(socket_type: Any,
|
||||
yield socket
|
||||
finally:
|
||||
if ctx is not None:
|
||||
ctx.destroy(linger=0)
|
||||
ctx.destroy(linger=0)
|
||||
@@ -1,556 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.distributed.communication_op import \
|
||||
data_parallel_reduce_scatter
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
||||
|
||||
|
||||
class MoECommMethod(ABC):
|
||||
"""Base class for MoE communication methods."""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
self.moe_config = moe_config
|
||||
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Prepare the MoE communication method.
|
||||
|
||||
This method is called before quant_method.apply to prepare the
|
||||
communication method. It can be used to initialize any necessary
|
||||
resources or configurations.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""Finalize the MoE communication method.
|
||||
|
||||
This method is called after quant_method.apply to finalize the
|
||||
communication method. It can be used to clean up any resources or
|
||||
configurations.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def permute(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
num_experts: int,
|
||||
apply_a8_quantization: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
|
||||
"""Pre-process before MLP.
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Tensor of shape (num_tokens, hidden_size)
|
||||
topk_ids (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
|
||||
topk_weights (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
|
||||
expert_map (torch.Tensor): Tensor of shape (global_num_experts, )
|
||||
Mapping from global expert IDs to local expert IDs.
|
||||
num_experts (int): Number of local experts (experts on this device).
|
||||
apply_a8_quantization (bool): Whether to apply A8 quantization (W4A8 and W8A8).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing:
|
||||
- permuted_hidden_states (torch.Tensor): Tensor of shape
|
||||
(num_tokens * top_k_num, hidden_size) after permuting
|
||||
hidden_states based on topk_ids.
|
||||
- expert_tokens (torch.Tensor): Tensor of shape (num_experts, )
|
||||
Number of tokens assigned to each expert.
|
||||
- dynamic_scale (torch.Tensor, optional): Tensor of shape (num_experts, )
|
||||
Dynamic scale for each expert, used for quantization.
|
||||
- group_list_type (int): Type of group list, 0 for `cumsum`
|
||||
and 1 for `count`. This is mainly for `npu_grouped_matmul`
|
||||
to determine how to handle the output.
|
||||
Raises:
|
||||
NotImplementedError: If the method is not implemented in the subclass.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unpermute(self, mlp_output: torch.Tensor,
|
||||
hidden_states: torch.Tensor) -> None:
|
||||
"""Post-process after MLP.
|
||||
|
||||
Args:
|
||||
mlp_output (torch.Tensor): Tensor of shape
|
||||
(num_tokens * top_k_num, hidden_size) after MLP.
|
||||
hidden_states (torch.Tensor): Tensor of shape
|
||||
(num_tokens, hidden_size) to be updated with the final output.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class AllGatherCommImpl(MoECommMethod):
|
||||
"""This implementation is the same as NativeAllGatherCommImpl,
|
||||
but uses NPU-specific ops for better performance.
|
||||
|
||||
This implementation should be compatible with all scenarios, and
|
||||
thus it is the default implementation for MoE communication methods.
|
||||
It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
|
||||
and `torch_npu.npu_moe_token_unpermute` for post-processing
|
||||
to handle the token-to-expert mapping and communication efficiently.
|
||||
|
||||
NOTE(Yizhou): TBH, it is really weird that we were supposed to use
|
||||
`torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
|
||||
or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
|
||||
for pre-processing and post-processing, respectively.
|
||||
But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
|
||||
use `torch_npu.npu_moe_token_unpermute` instead.
|
||||
This is a workaround and should be removed after the issue is fixed.
|
||||
"""
|
||||
|
||||
def prepare(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""When DP size > 1, pad the hidden states and router logits for communication."""
|
||||
if self.moe_config.dp_size > 1:
|
||||
forward_context = get_forward_context()
|
||||
max_tokens_across_dp = forward_context.max_tokens_across_dp
|
||||
|
||||
self.num_tokens = hidden_states.shape[0]
|
||||
pad_size = max_tokens_across_dp - self.num_tokens
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
hidden_states = self.moe_config.dp_group.all_gather(
|
||||
hidden_states, 0)
|
||||
router_logits = self.moe_config.dp_group.all_gather(
|
||||
router_logits, 0)
|
||||
|
||||
return hidden_states, router_logits
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""When DP size > 1, reduce-scatter the hidden states to get the final output.
|
||||
|
||||
When TP size > 1, all-reduce the hidden states to get the final output.
|
||||
"""
|
||||
if self.moe_config.dp_size > 1:
|
||||
hidden_states = data_parallel_reduce_scatter(hidden_states, dim=0)
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
|
||||
if reduce_results and (self.moe_config.tp_size > 1
|
||||
or self.moe_config.ep_size > 1):
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def permute(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
expert_map: torch.Tensor, # noqa: F841
|
||||
num_experts: int,
|
||||
apply_a8_quantization: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
|
||||
num_tokens = hidden_states.shape[0]
|
||||
|
||||
self.topk_weights = topk_weights
|
||||
self.topk_ids = topk_ids
|
||||
|
||||
first_expert_idx = 0
|
||||
if expert_map is not None:
|
||||
# FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
||||
# So we need to filter out invalid tokens by zeroing their weights.
|
||||
# This is a workaround and should be removed after the issue is fixed
|
||||
mask = expert_map[topk_ids] != -1
|
||||
# NOTE: This is equivalent to self.topk_weights[~mask] = 0.0,
|
||||
# but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph
|
||||
self.topk_weights = torch.where(mask, topk_weights, 0.0)
|
||||
|
||||
first_expert_idx = self.moe_config.ep_rank * num_experts
|
||||
last_expert_idx = first_expert_idx + num_experts
|
||||
|
||||
permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
|
||||
torch_npu.npu_moe_init_routing_v2(
|
||||
hidden_states,
|
||||
topk_ids,
|
||||
active_num=num_tokens * self.moe_config.experts_per_token,
|
||||
expert_num=self.moe_config.num_experts,
|
||||
expert_tokens_num_type=1, # Only support `count` mode now
|
||||
expert_tokens_num_flag=True, # Output `expert_tokens`
|
||||
active_expert_range=[first_expert_idx, last_expert_idx],
|
||||
quant_mode=-1,
|
||||
))
|
||||
self.expanded_row_idx = expanded_row_idx
|
||||
permuted_hidden_states = permuted_hidden_states
|
||||
|
||||
group_list_type = 1 # `count` mode
|
||||
|
||||
return permuted_hidden_states, expert_tokens, None, group_list_type
|
||||
|
||||
def unpermute(self, mlp_output: torch.Tensor,
|
||||
hidden_states: torch.Tensor) -> None:
|
||||
hidden_states[:] = torch_npu.npu_moe_token_unpermute(
|
||||
permuted_tokens=mlp_output,
|
||||
sorted_indices=self.expanded_row_idx,
|
||||
probs=self.topk_weights)
|
||||
|
||||
|
||||
class NativeAllGatherCommImpl(AllGatherCommImpl):
|
||||
"""This implementation should be compatible with all scenarios.
|
||||
|
||||
Note that this implementation purely consists of native PyTorch ops
|
||||
and does not use any NPU-specific ops. So the performance may not be optimal.
|
||||
But it is a good fallback for scenarios where NPU-specific ops are not available.
|
||||
"""
|
||||
|
||||
def permute(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
num_experts: int,
|
||||
apply_a8_quantization: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
|
||||
num_tokens = hidden_states.shape[0]
|
||||
|
||||
# Generate token indices and flatten
|
||||
token_indices = torch.arange(num_tokens,
|
||||
device=hidden_states.device,
|
||||
dtype=torch.int64)
|
||||
token_indices = (token_indices.unsqueeze(1).expand(
|
||||
-1, self.moe_config.experts_per_token).reshape(-1))
|
||||
|
||||
# Flatten token-to-expert mappings and map to local experts
|
||||
weights_flat = topk_weights.view(-1)
|
||||
experts_flat = topk_ids.view(-1)
|
||||
local_experts_flat = (expert_map[experts_flat]
|
||||
if expert_map is not None else experts_flat)
|
||||
|
||||
# Filter valid token-expert pairs
|
||||
mask = local_experts_flat != -1
|
||||
# FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
||||
# So we need to filter out invalid tokens by zeroing their weights.
|
||||
# This is a workaround and should be removed after the issue is fixed
|
||||
filtered_weights = torch.where(mask, weights_flat,
|
||||
torch.zeros_like(weights_flat)).to(
|
||||
topk_weights.dtype)
|
||||
filtered_experts = torch.where(
|
||||
mask,
|
||||
local_experts_flat,
|
||||
torch.full_like(local_experts_flat, num_experts),
|
||||
).to(topk_ids.dtype)
|
||||
|
||||
# Sort by local expert IDs
|
||||
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
|
||||
self.sorted_token_indices = token_indices[sort_indices]
|
||||
self.sorted_weights = filtered_weights[sort_indices]
|
||||
|
||||
# Compute token counts with minlength of num_experts
|
||||
# This is equivalent to but faster than:
|
||||
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
|
||||
token_counts = torch.zeros(num_experts + 1,
|
||||
device=hidden_states.device,
|
||||
dtype=torch.int64)
|
||||
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
|
||||
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
|
||||
expert_tokens = token_counts[:num_experts]
|
||||
|
||||
# Rearrange hidden_states
|
||||
permuted_hidden_states = hidden_states[self.sorted_token_indices]
|
||||
|
||||
group_list_type = 1 # `count` mode
|
||||
|
||||
return permuted_hidden_states, expert_tokens, None, group_list_type
|
||||
|
||||
def unpermute(self, mlp_output: torch.Tensor,
|
||||
hidden_states: torch.Tensor) -> None:
|
||||
mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)
|
||||
|
||||
final_hidden_states = torch.zeros_like(hidden_states)
|
||||
final_hidden_states.index_add_(0, self.sorted_token_indices,
|
||||
mlp_output)
|
||||
|
||||
hidden_states[:] = final_hidden_states
|
||||
|
||||
|
||||
class MC2CommImpl(MoECommMethod):
|
||||
"""This implementation is for the scenarios listed below:
|
||||
1. `enable_expert_parallel=True`.
|
||||
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
|
||||
3. `enable_expert_parallel=False` is not supported.
|
||||
|
||||
This implementation uses the MC2 communication method, which is optimized for
|
||||
Communication and Computation parallelism on Ascend devices.
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: Optional[FusedMoEConfig]):
|
||||
super().__init__(moe_config)
|
||||
|
||||
# NOTE: We do not need to use mc2_group's rank and world size
|
||||
# because ep_group and mc2_group basically have the same init params.
|
||||
# We only init another group because of the restriction of MC2:
|
||||
# "No other groups can be used in the same process as the MC2 group."
|
||||
self.mc2_comm_name = get_mc2_group().device_group._get_backend(
|
||||
torch.device("npu")).get_hccl_comm_name(self.moe_config.ep_rank)
|
||||
|
||||
# Feature flags
|
||||
self.enable_dispatch_v2 = hasattr(torch_npu,
|
||||
"npu_moe_distribute_dispatch_v2")
|
||||
self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3
|
||||
self.need_extra_args = self.is_ascend_a3
|
||||
self._restore_tp_across_dp()
|
||||
|
||||
def _restore_tp_across_dp(self):
|
||||
# NOTE: Since vLLM flatten tp across dp, we need to restore the original
|
||||
# tp_size and tp_rank.
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
def prepare(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""The target_pad_length is calculated in forward_context, here we pad the
|
||||
hidden states and router logits. And if TP size > 1, we also need to split
|
||||
the tensors accordingly.
|
||||
"""
|
||||
self.num_tokens, _ = hidden_states.shape
|
||||
forward_context = get_forward_context()
|
||||
self.mc2_mask = forward_context.mc2_mask
|
||||
target_pad_length = forward_context.padded_num_tokens
|
||||
pad_size = target_pad_length - self.num_tokens
|
||||
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
if self.tp_size > 1:
|
||||
split_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_mc2_mask = torch.tensor_split(self.mc2_mask,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
self.split_hidden_states = split_hidden_states
|
||||
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
self.mc2_mask = split_mc2_mask[self.tp_rank]
|
||||
|
||||
return hidden_states, router_logits
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""If TP size > 1, all-gather the hidden states to get the final output.
|
||||
|
||||
Also, unpad the hidden states if needed.
|
||||
"""
|
||||
if self.tp_size > 1:
|
||||
dist.all_gather(list(self.split_hidden_states), hidden_states,
|
||||
self.moe_config.tp_group.device_group)
|
||||
hidden_states = torch.cat(self.split_hidden_states, dim=0)
|
||||
|
||||
if self.num_tokens < hidden_states.shape[0]:
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
|
||||
return hidden_states
|
||||
|
||||
def permute(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
num_experts: int,
|
||||
apply_a8_quantization: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
|
||||
# Store tensors needed for post_process
|
||||
self.topk_ids = topk_ids
|
||||
self.topk_weights = topk_weights.to(torch.float32)
|
||||
|
||||
dispatch_kwargs = {
|
||||
"x": hidden_states,
|
||||
"expert_ids": self.topk_ids,
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": self.moe_config.num_experts,
|
||||
"global_bs": 0,
|
||||
"scales": None,
|
||||
"quant_mode": 2 if apply_a8_quantization else 0,
|
||||
"group_ep": self.mc2_comm_name,
|
||||
"ep_world_size": self.moe_config.ep_size,
|
||||
"ep_rank_id": self.moe_config.ep_rank,
|
||||
}
|
||||
|
||||
if self.need_extra_args:
|
||||
dispatch_kwargs.update({
|
||||
"group_tp": self.mc2_comm_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if self.is_ascend_a3 and self.enable_dispatch_v2:
|
||||
dispatch_kwargs.update({
|
||||
"x_active_mask": self.mc2_mask,
|
||||
})
|
||||
|
||||
dispatch = torch_npu.npu_moe_distribute_dispatch_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch
|
||||
|
||||
(
|
||||
permuted_hidden_states,
|
||||
dynamic_scale,
|
||||
self.assist_info_for_combine,
|
||||
expert_tokens,
|
||||
self.ep_recv_counts,
|
||||
self.tp_recv_counts,
|
||||
) = dispatch(**dispatch_kwargs)[:6]
|
||||
|
||||
group_list_type = 1
|
||||
|
||||
return permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type
|
||||
|
||||
def unpermute(self, mlp_output: torch.Tensor,
|
||||
hidden_states: torch.Tensor) -> None:
|
||||
combine_kwargs = {
|
||||
"expand_x": mlp_output,
|
||||
"expert_ids": self.topk_ids,
|
||||
"expert_scales": self.topk_weights,
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": self.moe_config.num_experts,
|
||||
"global_bs": 0,
|
||||
"ep_send_counts": self.ep_recv_counts,
|
||||
"group_ep": self.mc2_comm_name,
|
||||
"ep_world_size": self.moe_config.ep_size,
|
||||
"ep_rank_id": self.moe_config.ep_rank,
|
||||
}
|
||||
|
||||
if self.enable_dispatch_v2:
|
||||
combine_kwargs[
|
||||
"assist_info_for_combine"] = self.assist_info_for_combine
|
||||
else:
|
||||
combine_kwargs["expand_idx"] = self.assist_info_for_combine
|
||||
|
||||
if self.need_extra_args:
|
||||
combine_kwargs.update({
|
||||
"tp_send_counts": self.tp_recv_counts,
|
||||
"group_tp": self.mc2_comm_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if self.is_ascend_a3 and self.enable_dispatch_v2:
|
||||
combine_kwargs.update({
|
||||
"x_active_mask": self.mc2_mask,
|
||||
})
|
||||
|
||||
combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine
|
||||
|
||||
hidden_states[:] = combine(**combine_kwargs)
|
||||
|
||||
|
||||
class AlltoAllCommImpl(MoECommMethod):
|
||||
"""This implementation is for the scenarios listed below:
|
||||
1. `enable_expert_parallel=True`.
|
||||
2. `npu_grouped_matmul` is available.
|
||||
|
||||
This implementation uses all-to-all communication to exchange tokens
|
||||
between data parallel ranks before and after the MLP computation. It should
|
||||
have better performance than AllGatherCommImpl when DP size > 1.
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: Optional[FusedMoEConfig]):
|
||||
super().__init__(moe_config)
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
||||
get_token_dispatcher
|
||||
self.token_dispatcher = get_token_dispatcher(
|
||||
"TokenDispatcherWithAll2AllV")
|
||||
self._restore_tp_across_dp()
|
||||
|
||||
def _restore_tp_across_dp(self):
|
||||
# NOTE: Since vLLM flatten tp across dp, we need to restore the original
|
||||
# tp_size and tp_rank.
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
def prepare(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
self.num_tokens, _ = hidden_states.shape
|
||||
pad_size = self.tp_size - self.num_tokens
|
||||
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
if self.tp_size > 1:
|
||||
split_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
self.split_hidden_states = split_hidden_states
|
||||
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
|
||||
return hidden_states, router_logits
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""If TP size > 1, all-gather the hidden states to get the final output.
|
||||
|
||||
Also, unpad the hidden states if needed.
|
||||
"""
|
||||
if self.tp_size > 1:
|
||||
dist.all_gather(list(self.split_hidden_states), hidden_states,
|
||||
self.moe_config.tp_group.device_group)
|
||||
hidden_states = torch.cat(self.split_hidden_states, dim=0)
|
||||
|
||||
if self.num_tokens < hidden_states.shape[0]:
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
|
||||
return hidden_states
|
||||
|
||||
def permute(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
num_experts: int,
|
||||
apply_a8_quantization: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
|
||||
results = self.token_dispatcher.token_dispatch(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
None,
|
||||
log2phy=None,
|
||||
with_quant=apply_a8_quantization)
|
||||
return results["hidden_states"], results["group_list"], results[
|
||||
"dynamic_scale"], results["group_list_type"]
|
||||
|
||||
def unpermute(self, mlp_output: torch.Tensor,
|
||||
hidden_states: torch.Tensor) -> None:
|
||||
hidden_states[:] = self.token_dispatcher.token_combine(mlp_output)
|
||||
447
vllm_ascend/distributed/mooncake/config_data.py
Normal file
447
vllm_ascend/distributed/mooncake/config_data.py
Normal file
@@ -0,0 +1,447 @@
|
||||
import array
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
|
||||
KVConnectorMetadata
|
||||
from vllm.utils import cdiv, logger
|
||||
from vllm.v1.core.sched.output import NewRequestData
|
||||
|
||||
|
||||
@dataclass
|
||||
class MooncakeEngineMetadata:
|
||||
"""name of the LLM model"""
|
||||
|
||||
model_name: str
|
||||
""" world size when running under a distributed setting """
|
||||
world_size: int
|
||||
""" worker id when running under a distributed setting """
|
||||
worker_id: int
|
||||
""" the format of kv tensors """
|
||||
kv_dtype: torch.dtype
|
||||
""" the shape of kv tensors """
|
||||
""" (num_layer, 2, metadata.block_size, num_kv_head, head_size) """
|
||||
kv_shape: tuple[int, int, int, int, int]
|
||||
block_size: int = 128
|
||||
""" whether use MLA"""
|
||||
use_mla: bool = False
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class MooncakeEngineKey:
|
||||
model_name: str
|
||||
world_size: int
|
||||
worker_id: int
|
||||
chunk_hash: str
|
||||
|
||||
def __hash__(self):
|
||||
return hash((
|
||||
self.model_name,
|
||||
self.world_size,
|
||||
self.worker_id,
|
||||
self.chunk_hash,
|
||||
))
|
||||
|
||||
def to_string(self):
|
||||
return (f"{self.model_name}@{self.world_size}"
|
||||
f"@{self.worker_id}@{self.chunk_hash}")
|
||||
|
||||
def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]:
|
||||
"""Split the key into multiple keys for each layer"""
|
||||
keys = []
|
||||
for layer_id in range(num_layers):
|
||||
keys.append(
|
||||
LayerMooncakeEngineKey(
|
||||
self.model_name,
|
||||
self.world_size,
|
||||
self.worker_id,
|
||||
self.chunk_hash,
|
||||
layer_id,
|
||||
))
|
||||
return keys
|
||||
|
||||
def to_dict(self):
|
||||
# Note(Kuntai): this is used for serializing CacheEngineKey via msgpack.
|
||||
return {
|
||||
"__type__": "CacheEngineKey",
|
||||
"model_name": self.model_name,
|
||||
"world_size": self.world_size,
|
||||
"worker_id": self.worker_id,
|
||||
"chunk_hash": self.chunk_hash,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d):
|
||||
return MooncakeEngineKey(
|
||||
model_name=d["model_name"],
|
||||
world_size=d["world_size"],
|
||||
worker_id=d["worker_id"],
|
||||
chunk_hash=d["chunk_hash"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class LayerMooncakeEngineKey(MooncakeEngineKey):
|
||||
"""A key for the layer cache engine"""
|
||||
|
||||
layer_id: int
|
||||
|
||||
def __hash__(self):
|
||||
return hash((
|
||||
self.model_name,
|
||||
self.world_size,
|
||||
self.worker_id,
|
||||
self.chunk_hash,
|
||||
self.layer_id,
|
||||
))
|
||||
|
||||
def to_string(self):
|
||||
return (f"{self.model_name}@{self.world_size}"
|
||||
f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}")
|
||||
|
||||
|
||||
class ChunkedTokenDatabase():
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metadata: MooncakeEngineMetadata,
|
||||
):
|
||||
self.metadata = metadata
|
||||
|
||||
def _make_key_by_hash(self,
|
||||
chunk_hash: str,
|
||||
layer_id: Optional[int] = None):
|
||||
assert self.metadata is not None
|
||||
return MooncakeEngineKey(
|
||||
self.metadata.model_name,
|
||||
self.metadata.world_size,
|
||||
self.metadata.worker_id,
|
||||
chunk_hash,
|
||||
)
|
||||
|
||||
def _hash(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
prefix_hash: str,
|
||||
) -> str:
|
||||
# TODO: change it to a more efficient hash function
|
||||
if isinstance(tokens, torch.Tensor):
|
||||
tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes()
|
||||
elif isinstance(tokens, list):
|
||||
tokens_bytes = array.array("I", tokens).tobytes()
|
||||
return hashlib.sha256(prefix_hash.encode("ascii") +
|
||||
tokens_bytes).hexdigest()
|
||||
|
||||
def _chunk_tokens(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
) -> Iterable[Union[torch.Tensor, List[int]]]:
|
||||
"""
|
||||
Chunk the tokens into chunks of size self.metadata.block_size.
|
||||
|
||||
:param tokens: the input tokens, with shape [seq_len]
|
||||
device: the target device after chunking
|
||||
|
||||
:return: a generator of chunks of tokens, each with
|
||||
shape [metadata.block_size]
|
||||
"""
|
||||
for i in range(0, len(tokens), self.metadata.block_size):
|
||||
yield tokens[i:i + self.metadata.block_size]
|
||||
|
||||
def _prefix_hash(
|
||||
self,
|
||||
token_chunks: Iterable[Union[torch.Tensor, List[int]]],
|
||||
) -> Iterable[str]:
|
||||
prefix_hash = ''
|
||||
for token_chunk in token_chunks:
|
||||
prefix_hash = self._hash(token_chunk, prefix_hash)
|
||||
yield prefix_hash
|
||||
|
||||
def process_tokens(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> Iterable[Tuple[int, int, MooncakeEngineKey]]:
|
||||
"""Process the tokens and return the corresponding cache engine keys.
|
||||
|
||||
:param Union[torch.Tensor, List[int]] tokens: The tokens to process.
|
||||
|
||||
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
|
||||
have the same length as tokens. And the mask should ALWAYS be like
|
||||
FFFFFTTTTTTT, where True means the tokens needs to be matched,
|
||||
and the Falses will ALWAYS be at the PREFIX of the tensor.
|
||||
|
||||
:param bool make_key: Whether to make the cache engine key or not.
|
||||
If False, the hash value will be returned instead.
|
||||
|
||||
:returns: A iterable of tuples with three elements. The first element
|
||||
is the start index of the tokens for the key. The second element
|
||||
is the end index of the tokens for the key. The third element is
|
||||
the cache engine key (or hash) for the tokens.
|
||||
|
||||
:raises: ValueError if the number of Falses in the mask is not a
|
||||
multiple of the chunk size.
|
||||
"""
|
||||
if mask is not None:
|
||||
num_falses = mask.numel() - mask.long().sum().item()
|
||||
else:
|
||||
num_falses = 0
|
||||
|
||||
if num_falses % self.metadata.block_size != 0:
|
||||
raise ValueError(
|
||||
"The number of Falses in the mask is not a multiple of the chunk size."
|
||||
)
|
||||
total_len = len(tokens)
|
||||
|
||||
token_chunks = self._chunk_tokens(tokens)
|
||||
prefix_hashes = self._prefix_hash(token_chunks)
|
||||
|
||||
start_idx = 0
|
||||
for chunk_id, hash_val in enumerate(prefix_hashes):
|
||||
start_idx = chunk_id * self.metadata.block_size
|
||||
end_idx = min(start_idx + self.metadata.block_size, total_len)
|
||||
if start_idx < num_falses:
|
||||
continue
|
||||
else:
|
||||
yield start_idx, end_idx, self._make_key_by_hash(hash_val)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadSpec:
|
||||
# Number of tokens cached in vLLM
|
||||
vllm_cached_tokens: int
|
||||
# Number of tokens that are cached in mooncake
|
||||
mooncake_cached_tokens: int
|
||||
# Whether the scheduler allow us to load the tokens
|
||||
can_load: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class SaveSpec:
|
||||
# Skip already saved tokens
|
||||
skip_leading_tokens: int
|
||||
# Whether the scheduler allow us to save the tokens
|
||||
can_save: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestTracker:
|
||||
# Request id
|
||||
req_id: str
|
||||
|
||||
# The token ids that has been scheduled so far
|
||||
token_ids: list[int]
|
||||
|
||||
# The block ids that has been allocated so far
|
||||
# NOTE: allocated blocks could be more than the number of tokens
|
||||
# FIXME: need to check whether the block ids will be changed after
|
||||
# preemption
|
||||
allocated_block_ids: list[int]
|
||||
|
||||
# The number of tokens that has been savd
|
||||
num_saved_tokens: int = 0
|
||||
|
||||
@staticmethod
|
||||
def from_new_request(
|
||||
new_request: "NewRequestData",
|
||||
num_tokens_to_compute: int,
|
||||
) -> "RequestTracker":
|
||||
"""Create the request tracker from a new request.
|
||||
|
||||
Args:
|
||||
new_request (NewRequestData): the new request data.
|
||||
num_tokens_to_compute (int): the number of tokens that will
|
||||
be 'computed', including the `num_computed_tokens` (vLLM's
|
||||
local cache hit) and new tokens that will be scheduled.
|
||||
|
||||
"""
|
||||
# vLLM 0.9.0 update: request.block_ids changed from list[int] to
|
||||
# list[list[int]]
|
||||
# Need to check the type of request.block_ids
|
||||
|
||||
unfolded_block_ids = []
|
||||
|
||||
if not isinstance(new_request.block_ids[0], list):
|
||||
unfolded_block_ids = new_request.block_ids.copy()
|
||||
else:
|
||||
unfolded_block_ids = new_request.block_ids[0].copy()
|
||||
|
||||
return RequestTracker(
|
||||
req_id=new_request.req_id,
|
||||
token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].
|
||||
copy(),
|
||||
allocated_block_ids=unfolded_block_ids,
|
||||
num_saved_tokens=0,
|
||||
)
|
||||
|
||||
def update(
|
||||
self,
|
||||
new_token_ids: list[int],
|
||||
new_block_ids: Union[tuple[list[int], ...], list[int]],
|
||||
) -> None:
|
||||
"""Update the request tracker when a running request is
|
||||
scheduled again
|
||||
"""
|
||||
|
||||
self.token_ids.extend(new_token_ids)
|
||||
|
||||
if len(new_block_ids) == 0:
|
||||
new_block_ids = []
|
||||
elif isinstance(new_block_ids, tuple):
|
||||
new_block_ids = new_block_ids[0]
|
||||
elif isinstance(new_block_ids, list):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported new_block_ids type {type(new_block_ids)}")
|
||||
self.allocated_block_ids.extend(new_block_ids)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
# Request id
|
||||
req_id: str
|
||||
# Request tokens
|
||||
token_ids: torch.Tensor
|
||||
|
||||
block_ids: list[int]
|
||||
# # Slot mapping if exchange for block_id
|
||||
# slot_mapping: torch.Tensor
|
||||
# Skip save or not
|
||||
save_spec: Optional[SaveSpec] = None
|
||||
# load_spec
|
||||
load_spec: Optional[LoadSpec] = None
|
||||
|
||||
is_last_chunk: Optional[bool] = None
|
||||
|
||||
@staticmethod
|
||||
def from_request_tracker(
|
||||
tracker: RequestTracker,
|
||||
block_size: int,
|
||||
load_spec: Optional[LoadSpec] = None,
|
||||
skip_save: Optional[bool] = False,
|
||||
is_last_chunk: Optional[bool] = None,
|
||||
discard_partial_chunks: bool = True,
|
||||
) -> Optional["ReqMeta"]:
|
||||
"""Create the request metadata from a request tracker.
|
||||
|
||||
Args:
|
||||
tracker (RequestTracker): the request tracker.
|
||||
block_size (int): the block size in vLLM.
|
||||
load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
|
||||
skip_save (bool): whether to skip the save operation.
|
||||
discard_partial_chunks (bool): whether to discard partial chunks.
|
||||
|
||||
Returns:
|
||||
the request metadata if we need to perform load/save
|
||||
operations, None otherwise.
|
||||
"""
|
||||
input_token_ids = tracker.token_ids
|
||||
input_token_len = len(input_token_ids)
|
||||
|
||||
# For save operation: do not save if the following condition is met
|
||||
# 1. has already been saved before (num_saved_tokens > 0)
|
||||
# 2. number of unsaved tokens is not reached the chunk boundary
|
||||
skip_leading_tokens = tracker.num_saved_tokens
|
||||
chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) *
|
||||
block_size if discard_partial_chunks else 0)
|
||||
# Calculate number of tokens to save based on discard_partial_chunks
|
||||
# setting
|
||||
num_tokens_to_save = ((input_token_len // block_size * block_size)
|
||||
if discard_partial_chunks else input_token_len)
|
||||
|
||||
skip_save = skip_save or num_tokens_to_save < chunk_boundary
|
||||
if skip_save and load_spec is None:
|
||||
return None
|
||||
|
||||
# If we need to save, update the number of saved tokens
|
||||
if not skip_save:
|
||||
tracker.num_saved_tokens = num_tokens_to_save
|
||||
save_spec = SaveSpec(skip_leading_tokens, not skip_save)
|
||||
|
||||
# Calculate the token ids and slot mappings for load and save
|
||||
# OPTIMIZATION: pre-allocate the buffer for token ids and block ids
|
||||
token_ids = torch.tensor(input_token_ids)[:num_tokens_to_save]
|
||||
|
||||
# # For load operation: check whether the request is scheduled to load
|
||||
if load_spec is not None and load_spec.can_load:
|
||||
logger.debug(
|
||||
"Scheduled to load %d tokens for request %s",
|
||||
load_spec.mooncake_cached_tokens,
|
||||
tracker.req_id,
|
||||
)
|
||||
else:
|
||||
# Do not load if not in `can_load` state
|
||||
load_spec = None
|
||||
logger.debug(
|
||||
f"request:{tracker.req_id}, meta save spec:{save_spec}, meta load spec:{load_spec}"
|
||||
)
|
||||
return ReqMeta(
|
||||
req_id=tracker.req_id,
|
||||
token_ids=token_ids,
|
||||
block_ids=tracker.allocated_block_ids,
|
||||
save_spec=save_spec,
|
||||
load_spec=load_spec,
|
||||
is_last_chunk=is_last_chunk,
|
||||
)
|
||||
|
||||
|
||||
class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
|
||||
def __init__(self, unfinished_request_ids):
|
||||
self.requests = []
|
||||
self.unfinished_request_ids = unfinished_request_ids
|
||||
|
||||
def add_request(self, req_meta: ReqMeta) -> None:
|
||||
"""Add a request to the metadata.
|
||||
|
||||
Args:
|
||||
req_meta (ReqMeta): the request metadata.
|
||||
"""
|
||||
self.requests.append(req_meta)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LasyerMultiBlockReqMeta:
|
||||
req_id: str
|
||||
keys: List[LayerMooncakeEngineKey]
|
||||
starts: List[int]
|
||||
ends: list[int]
|
||||
block_ids: list[int]
|
||||
layer_id: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class MooncakeStoreConfig:
|
||||
local_hostname: str
|
||||
metadata_server: str
|
||||
global_segment_size: int
|
||||
local_buffer_size: int
|
||||
protocol: str
|
||||
device_name: str
|
||||
master_server_address: str
|
||||
|
||||
@staticmethod
|
||||
def from_file(file_path: str) -> "MooncakeStoreConfig":
|
||||
with open(file_path) as file:
|
||||
config = json.load(file)
|
||||
return MooncakeStoreConfig(
|
||||
local_hostname=config.get("local_hostname"),
|
||||
metadata_server=config.get("metadata_server"),
|
||||
global_segment_size=config.get("global_segment_size", 3355443200),
|
||||
local_buffer_size=config.get("local_buffer_size", 1073741824),
|
||||
protocol=config.get("protocol", "tcp"),
|
||||
device_name=config.get("device_name", ""),
|
||||
master_server_address=config.get("master_server_address"))
|
||||
|
||||
@staticmethod
|
||||
def load_from_env() -> "MooncakeStoreConfig":
|
||||
config_path = os.getenv("MOONCAKE_CONFIG_PATH")
|
||||
if not config_path:
|
||||
raise ValueError(
|
||||
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
|
||||
return MooncakeStoreConfig.from_file(config_path)
|
||||
251
vllm_ascend/distributed/mooncake/kv_transfer.py
Normal file
251
vllm_ascend/distributed/mooncake/kv_transfer.py
Normal file
@@ -0,0 +1,251 @@
|
||||
import queue
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from vllm.utils import logger
|
||||
|
||||
from vllm_ascend.distributed.mooncake.config_data import (
|
||||
ChunkedTokenDatabase, LasyerMultiBlockReqMeta)
|
||||
from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore
|
||||
|
||||
|
||||
class KVTransferThread(threading.Thread):
|
||||
|
||||
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
|
||||
local_kv_caches_base_addr: list[int],
|
||||
token_database: ChunkedTokenDatabase, block_len: list[int],
|
||||
block_size: int, ready_event: threading.Event, name: str):
|
||||
super().__init__(daemon=True, name=name)
|
||||
self.tp_rank = tp_rank
|
||||
self.tp_size = tp_size
|
||||
self.m_store = m_store
|
||||
self.ready_event = ready_event
|
||||
self.kv_caches_base_addr = local_kv_caches_base_addr
|
||||
self.block_len = block_len
|
||||
self.token_database = token_database
|
||||
self.block_size = block_size
|
||||
self.done_task_lock = threading.Lock()
|
||||
# TODO(jianzs): find a better way to detect MLA.
|
||||
self.use_mla = len(block_len) == 2
|
||||
|
||||
self.request_queue: queue.Queue[Any] = queue.Queue()
|
||||
# TODO(jianzs): make this configurable
|
||||
self.executor = ThreadPoolExecutor(max_workers=32)
|
||||
self.finished_requests: set[str] = set()
|
||||
|
||||
def prepare_value(self, start: int, end: int, block_ids: list[int]):
|
||||
addr_list = []
|
||||
size_list = []
|
||||
block_id = block_ids[start // self.block_size]
|
||||
for index, base_addr in enumerate(self.kv_caches_base_addr):
|
||||
block_len = (self.block_len[index % 2]
|
||||
if self.use_mla else self.block_len[0])
|
||||
|
||||
addr = base_addr + block_id * block_len
|
||||
length = int(block_len / self.block_size * (end - start))
|
||||
addr_list.append(addr)
|
||||
size_list.append(length)
|
||||
return addr_list, size_list, block_id
|
||||
|
||||
def prepare_value_layer(self, start: int, end: int, block_ids: list[int],
|
||||
layer_id: int):
|
||||
block_id = block_ids[start // self.block_size]
|
||||
if self.use_mla:
|
||||
addr_k = self.kv_caches_base_addr[layer_id *
|
||||
2] + block_id * self.block_len[0]
|
||||
addr_v = self.kv_caches_base_addr[layer_id * 2 +
|
||||
1] + block_id * self.block_len[1]
|
||||
length_k = int(self.block_len[0] / self.block_size * (end - start))
|
||||
length_v = int(self.block_len[1] / self.block_size * (end - start))
|
||||
size_list = [length_k, length_v]
|
||||
else:
|
||||
addr_k = self.kv_caches_base_addr[layer_id *
|
||||
2] + block_id * self.block_len[0]
|
||||
addr_v = self.kv_caches_base_addr[layer_id * 2 +
|
||||
1] + block_id * self.block_len[0]
|
||||
length = int(self.block_len[0] / self.block_size * (end - start))
|
||||
size_list = [length, length]
|
||||
addr_list = [addr_k, addr_v]
|
||||
return addr_list, size_list
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
req_id: str,
|
||||
tokens: torch.Tensor,
|
||||
block_ids: list[int],
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
is_last_chunk: Optional[bool] = None,
|
||||
) -> torch.Tensor:
|
||||
req = ({
|
||||
"req_id": req_id,
|
||||
"tokens": tokens,
|
||||
"block_ids": block_ids,
|
||||
"mask": mask,
|
||||
"is_last_chunk": is_last_chunk,
|
||||
})
|
||||
self.request_queue.put(req)
|
||||
|
||||
def get_and_clear_finished_requests(self) -> set[str]:
|
||||
"""
|
||||
Get and clear the requests that have been completed.
|
||||
Returns:
|
||||
A set of request IDs that have been completed.
|
||||
"""
|
||||
with self.done_task_lock:
|
||||
finished_requests = self.finished_requests.copy()
|
||||
self.finished_requests.clear()
|
||||
return finished_requests
|
||||
|
||||
def set_finished_request(self, req_id):
|
||||
with self.done_task_lock:
|
||||
self.finished_requests.add(req_id)
|
||||
|
||||
def run(self):
|
||||
"""Run the thread to handle KV cache transfer requests."""
|
||||
self.ready_event.set()
|
||||
while True:
|
||||
try:
|
||||
request_data = self.request_queue.get()
|
||||
if request_data is None:
|
||||
logger.warning("Received a None request!")
|
||||
self.request_queue.task_done()
|
||||
continue
|
||||
self._handle_request(request_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in KVCacheTransferThread: {e}")
|
||||
|
||||
def _handle_request(self, req_meta: dict[str, Any]):
|
||||
pass
|
||||
|
||||
|
||||
class KVCacheStoreSendingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
|
||||
local_kv_caches_base_addr: list[int],
|
||||
token_database: ChunkedTokenDatabase, block_len: list[int],
|
||||
block_size: int, ready_event: threading.Event):
|
||||
super().__init__(tp_rank,
|
||||
tp_size,
|
||||
m_store,
|
||||
local_kv_caches_base_addr,
|
||||
token_database,
|
||||
block_len,
|
||||
block_size,
|
||||
ready_event,
|
||||
name="KVCacheSendingThread")
|
||||
|
||||
def _handle_request(self, req_meta: dict[str, Any]):
|
||||
tokens = req_meta["tokens"]
|
||||
mask = req_meta["mask"]
|
||||
block_ids = req_meta["block_ids"]
|
||||
req_id = req_meta["req_id"]
|
||||
is_last_chunk = req_meta["is_last_chunk"]
|
||||
torch.npu.current_stream().synchronize()
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
addr, size, _ = self.prepare_value(start, end, block_ids)
|
||||
self.m_store.put(key, addr, size)
|
||||
if is_last_chunk:
|
||||
self.set_finished_request(req_id)
|
||||
self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreRecvingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
|
||||
local_kv_caches_base_addr: list[int],
|
||||
token_database: ChunkedTokenDatabase, block_len: list[int],
|
||||
block_size: int, ready_event: threading.Event):
|
||||
super().__init__(tp_rank,
|
||||
tp_size,
|
||||
m_store,
|
||||
local_kv_caches_base_addr,
|
||||
token_database,
|
||||
block_len,
|
||||
block_size,
|
||||
ready_event,
|
||||
name="KVCacheStoreRecvingThread")
|
||||
|
||||
def _handle_request(self, req_meta: dict[str, Any]):
|
||||
tokens = req_meta["tokens"]
|
||||
mask = req_meta["mask"]
|
||||
block_ids = req_meta["block_ids"]
|
||||
req_id = req_meta["req_id"]
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
addr, size, _ = self.prepare_value(start, end, block_ids)
|
||||
self.m_store.get(key, addr, size)
|
||||
self.set_finished_request(req_id)
|
||||
self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreLayerSendingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
|
||||
local_kv_caches_base_addr: list[int],
|
||||
token_database: ChunkedTokenDatabase, block_len: list[int],
|
||||
block_size: int, ready_event: threading.Event,
|
||||
num_layers: int):
|
||||
super().__init__(tp_rank,
|
||||
tp_size,
|
||||
m_store,
|
||||
local_kv_caches_base_addr,
|
||||
token_database,
|
||||
block_len,
|
||||
block_size,
|
||||
ready_event,
|
||||
name="KVCacheStoreLayerSendingThread")
|
||||
self.final_layer_id = num_layers - 1
|
||||
|
||||
def add_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
|
||||
self.request_queue.put(req_meta)
|
||||
|
||||
def _handle_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta):
|
||||
torch.npu.current_stream().synchronize()
|
||||
for index, key in enumerate(req_meta.keys):
|
||||
addr, size = self.prepare_value_layer(req_meta.starts[index],
|
||||
req_meta.ends[index],
|
||||
req_meta.block_ids,
|
||||
req_meta.layer_id)
|
||||
self.m_store.put(key, addr, size)
|
||||
if req_meta.layer_id == self.final_layer_id:
|
||||
self.set_finished_request(req_meta.req_id)
|
||||
self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreLayerRecvingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
|
||||
local_kv_caches_base_addr: list[int],
|
||||
token_database: ChunkedTokenDatabase, block_len: list[int],
|
||||
block_size: int, ready_event: threading.Event,
|
||||
get_event: threading.Event):
|
||||
super().__init__(tp_rank,
|
||||
tp_size,
|
||||
m_store,
|
||||
local_kv_caches_base_addr,
|
||||
token_database,
|
||||
block_len,
|
||||
block_size,
|
||||
ready_event,
|
||||
name="KVCacheStoreLayerRecvingThread")
|
||||
self.get_event = get_event
|
||||
|
||||
def add_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
|
||||
self.request_queue.put(req_meta)
|
||||
|
||||
def _handle_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta):
|
||||
for index, key in enumerate(req_meta.keys):
|
||||
addr, size = self.prepare_value_layer(req_meta.starts[index],
|
||||
req_meta.ends[index],
|
||||
req_meta.block_ids,
|
||||
req_meta.layer_id)
|
||||
self.m_store.get(key, addr, size)
|
||||
self.request_queue.task_done()
|
||||
self.get_event.set()
|
||||
489
vllm_ascend/distributed/mooncake/mooncake_engine.py
Normal file
489
vllm_ascend/distributed/mooncake/mooncake_engine.py
Normal file
@@ -0,0 +1,489 @@
|
||||
# Standard
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
from typing import Generator, List, Optional, Union
|
||||
|
||||
# Third Party
|
||||
import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import get_kv_cache_torch_dtype, logger
|
||||
|
||||
from vllm_ascend.distributed.mooncake.config_data import (
|
||||
ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata,
|
||||
MooncakeEngineMetadata)
|
||||
from vllm_ascend.distributed.mooncake.kv_transfer import (
|
||||
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
|
||||
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
|
||||
from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore
|
||||
|
||||
|
||||
class MooncakeEngine:
|
||||
#The main class for the cache engine.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
use_layerwize: bool,
|
||||
):
|
||||
model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.use_mla = False
|
||||
if (hasattr(model_config, "use_mla")
|
||||
and isinstance(model_config.use_mla, bool)
|
||||
and model_config.use_mla):
|
||||
self.use_mla = True
|
||||
self.use_layerwise = use_layerwize
|
||||
self.tp_rank = parallel_config.rank
|
||||
self.tp_size = parallel_config.tensor_parallel_size
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.current_layer = 0
|
||||
# self.use_mla = first_kv_cache_tuple[0].size(
|
||||
# -1) != first_kv_cache_tuple[1].size(-1)
|
||||
self.num_layers = model_config.get_num_layers(parallel_config)
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
num_kv_head = model_config.get_num_kv_heads(parallel_config)
|
||||
head_size = model_config.get_head_size()
|
||||
kv_dtype = get_kv_cache_torch_dtype(
|
||||
vllm_config.cache_config.cache_dtype, model_config.dtype)
|
||||
self.hidden_dim_size = num_kv_head * head_size
|
||||
if self.use_mla:
|
||||
kv_shape = (self.num_layers, 1, self.block_size, 1, head_size)
|
||||
else:
|
||||
kv_shape = (self.num_layers, 2, self.block_size, num_kv_head,
|
||||
head_size)
|
||||
self.metadata = MooncakeEngineMetadata(
|
||||
model_config.model,
|
||||
parallel_config.world_size,
|
||||
parallel_config.rank,
|
||||
kv_dtype,
|
||||
kv_shape,
|
||||
self.block_size,
|
||||
self.use_mla,
|
||||
)
|
||||
|
||||
self.token_database = ChunkedTokenDatabase(self.metadata)
|
||||
|
||||
self.m_store = Mooncakestore(parallel_config)
|
||||
|
||||
self.kv_send_thread: Optional[KVTransferThread] = None
|
||||
self.kv_recv_thread: Optional[KVTransferThread] = None
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
|
||||
first_kv_cache = first_kv_cache_tuple[0]
|
||||
|
||||
# TODO(tms): Find a more robust way to detect and handle MLA
|
||||
if self.use_mla:
|
||||
# MLA case.[num_block, block_size, 1, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 3 # [block_size, latent_dim]
|
||||
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
|
||||
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
|
||||
self.block_len = [
|
||||
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
|
||||
first_kv_cache[1].element_size() * math.prod(block_shape_pe)
|
||||
]
|
||||
logger.info(
|
||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
|
||||
self.num_blocks, block_shape_norm, block_shape_pe)
|
||||
else:
|
||||
# [num_block, block_size, num_head, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
kv_elem_size = first_kv_cache.element_size()
|
||||
block_rank = 3 # [block_size, kv_heads, head_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
self.block_len = [kv_elem_size * math.prod(block_shape)]
|
||||
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
|
||||
block_shape)
|
||||
|
||||
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
|
||||
self.use_mla, first_kv_cache.shape)
|
||||
|
||||
self.kv_caches = kv_caches
|
||||
self.m_store.set_kv_caches(kv_caches.values())
|
||||
self.kv_caches_base_addr = []
|
||||
for cache_or_caches in kv_caches.values():
|
||||
# Normalize to always be a list of caches
|
||||
if self.use_mla:
|
||||
for i, cache in enumerate(cache_or_caches, 0):
|
||||
base_addr = cache.data_ptr()
|
||||
self.kv_caches_base_addr.append(base_addr)
|
||||
else:
|
||||
cache_list = [cache_or_caches
|
||||
] if self.use_mla else cache_or_caches
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
self.kv_caches_base_addr.append(base_addr)
|
||||
|
||||
if self.use_layerwise:
|
||||
self.get_event = threading.Event()
|
||||
if self.kv_role in ['kv_producer', 'kv_both']:
|
||||
ready_event_sending = threading.Event()
|
||||
self.kv_send_thread = KVCacheStoreLayerSendingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database,
|
||||
self.block_len, self.block_size, ready_event_sending,
|
||||
self.num_layers)
|
||||
self.kv_send_thread.start()
|
||||
ready_event = threading.Event()
|
||||
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database, self.block_len,
|
||||
self.block_size, ready_event, self.get_event)
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
else:
|
||||
if self.kv_role in ['kv_producer', 'kv_both']:
|
||||
ready_event_sending = threading.Event()
|
||||
self.kv_send_thread = KVCacheStoreSendingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database,
|
||||
self.block_len, self.block_size, ready_event_sending)
|
||||
self.kv_send_thread.start()
|
||||
ready_event = threading.Event()
|
||||
self.kv_recv_thread = KVCacheStoreRecvingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database, self.block_len,
|
||||
self.block_size, ready_event)
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
|
||||
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
|
||||
self.current_layer = 0
|
||||
self.layerwise_retrievers = []
|
||||
for request in metadata.requests:
|
||||
load_spec = request.load_spec
|
||||
if load_spec is None or not load_spec.can_load: #load =0
|
||||
continue
|
||||
tokens = request.token_ids
|
||||
req_id = request.req_id
|
||||
if (load_spec.mooncake_cached_tokens % self.block_size
|
||||
!= 0) and (load_spec.mooncake_cached_tokens
|
||||
== tokens.shape[0] - 1):
|
||||
tokens = tokens[:request.load_spec.mooncake_cached_tokens + 1]
|
||||
else:
|
||||
tokens = tokens[:request.load_spec.mooncake_cached_tokens]
|
||||
masked_token_count = (request.load_spec.vllm_cached_tokens //
|
||||
self.block_size * self.block_size)
|
||||
token_mask = torch.ones_like(tokens, dtype=torch.bool)
|
||||
token_mask[:masked_token_count] = False
|
||||
if self.use_layerwise:
|
||||
layerwise_retriever = self.retrieve_layer(
|
||||
req_id,
|
||||
tokens,
|
||||
request.block_ids,
|
||||
token_mask,
|
||||
)
|
||||
next(layerwise_retriever) # first layer load
|
||||
self.layerwise_retrievers.append(layerwise_retriever)
|
||||
else:
|
||||
self.kv_recv_thread.add_request( # type: ignore[union-attr]
|
||||
req_id,
|
||||
tokens,
|
||||
request.block_ids,
|
||||
token_mask,
|
||||
)
|
||||
|
||||
def wait_for_layer_load(self) -> None:
|
||||
"""MooncakeConnector does not do layerwise saving."""
|
||||
for layerwise_retriever in self.layerwise_retrievers:
|
||||
ret_token_mask = next(layerwise_retriever)
|
||||
if self.current_layer == self.num_layers - 1:
|
||||
assert ret_token_mask is not None
|
||||
num_retrieved_tokens = ret_token_mask.sum().item()
|
||||
logger.info(f"Retrieved {num_retrieved_tokens} tokens")
|
||||
|
||||
def save_kv_layer(self,
|
||||
connector_metadata: MooncakeConnectorMetadata) -> None:
|
||||
"""MooncakeConnector does not save explicitly."""
|
||||
if self.current_layer == 0:
|
||||
self.layerwise_storers = []
|
||||
for request in connector_metadata.requests:
|
||||
save_spec = request.save_spec
|
||||
if save_spec is None or not save_spec.can_save:
|
||||
continue
|
||||
|
||||
token_ids = request.token_ids
|
||||
req_id = request.req_id
|
||||
assert isinstance(token_ids, torch.Tensor)
|
||||
assert token_ids.is_cpu
|
||||
|
||||
# TODO: whether need to remov saveThread
|
||||
# no lookup, skipmask
|
||||
skip_leading_tokens = max(
|
||||
self.lookup(token_ids, self.use_layerwise),
|
||||
save_spec.skip_leading_tokens,
|
||||
)
|
||||
if skip_leading_tokens == len(token_ids):
|
||||
if request.is_last_chunk:
|
||||
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
|
||||
req_id)
|
||||
continue # skip this request
|
||||
|
||||
skip_leading_tokens = (skip_leading_tokens // self.block_size *
|
||||
self.block_size)
|
||||
|
||||
store_mask = torch.ones_like(token_ids, dtype=torch.bool)
|
||||
store_mask[:skip_leading_tokens] = False
|
||||
logger.info(
|
||||
"Storing KV cache for %d out of %d tokens "
|
||||
"(skip_leading_tokens=%d) for request %s",
|
||||
len(token_ids) - skip_leading_tokens,
|
||||
len(token_ids),
|
||||
skip_leading_tokens,
|
||||
request.req_id,
|
||||
)
|
||||
|
||||
layerwise_storer = self.store_layer(
|
||||
req_id,
|
||||
token_ids,
|
||||
mask=store_mask,
|
||||
block_ids=request.block_ids,
|
||||
)
|
||||
self.layerwise_storers.append(layerwise_storer)
|
||||
for layerwise_storer in self.layerwise_storers:
|
||||
try:
|
||||
next(layerwise_storer)
|
||||
except Exception:
|
||||
raise
|
||||
self.current_layer = self.current_layer + 1
|
||||
|
||||
def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata):
|
||||
"""MooncakeConnector does not save explicitly."""
|
||||
for request in connector_metadata.requests:
|
||||
save_spec = request.save_spec
|
||||
if save_spec is None or not save_spec.can_save:
|
||||
continue
|
||||
|
||||
token_ids = request.token_ids
|
||||
req_id = request.req_id
|
||||
assert isinstance(token_ids, torch.Tensor)
|
||||
assert token_ids.is_cpu
|
||||
|
||||
skip_leading_tokens = max(
|
||||
self.lookup(token_ids, self.use_layerwise),
|
||||
save_spec.skip_leading_tokens,
|
||||
)
|
||||
if skip_leading_tokens == len(token_ids):
|
||||
if request.is_last_chunk:
|
||||
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
|
||||
req_id)
|
||||
continue # skip this request
|
||||
|
||||
skip_leading_tokens = (skip_leading_tokens // self.block_size *
|
||||
self.block_size)
|
||||
|
||||
store_mask = torch.ones_like(token_ids, dtype=torch.bool)
|
||||
store_mask[:skip_leading_tokens] = False
|
||||
|
||||
logger.info(
|
||||
"Storing KV cache for %d out of %d tokens "
|
||||
"(skip_leading_tokens=%d) for request %s",
|
||||
len(token_ids) - skip_leading_tokens,
|
||||
len(token_ids),
|
||||
skip_leading_tokens,
|
||||
request.req_id,
|
||||
)
|
||||
|
||||
self.kv_send_thread.add_request( # type: ignore[union-attr]
|
||||
req_id,
|
||||
token_ids,
|
||||
request.block_ids,
|
||||
store_mask,
|
||||
request.is_last_chunk,
|
||||
)
|
||||
|
||||
def retrieve_layer(
|
||||
self,
|
||||
req_id: str,
|
||||
tokens: torch.Tensor,
|
||||
block_ids: list[int],
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> Generator[Optional[torch.Tensor], None, None]:
|
||||
"""
|
||||
Retrieve the KV cache in a layerwise manner.
|
||||
|
||||
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
|
||||
|
||||
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
|
||||
have the same length as tokens. And the mask should ALWAYS be like
|
||||
FFFFFTTTTTTT, where True means the tokens needs to be matched.
|
||||
|
||||
:param **kwargs: The additional arguments for the KV transfer which
|
||||
will be passed into the npu_transfer.
|
||||
|
||||
return: A generator that yields Optional[torch.Tensor]. The tensor will
|
||||
be the boolean mask indicating which tokens are retrieved and will
|
||||
only be returned in the last iteration.
|
||||
"""
|
||||
|
||||
if mask is not None:
|
||||
num_required_tokens = torch.sum(mask).item()
|
||||
else:
|
||||
num_required_tokens = len(tokens)
|
||||
|
||||
ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu")
|
||||
|
||||
starts = []
|
||||
ends = []
|
||||
keys = []
|
||||
first_flag = True
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
starts.append(start)
|
||||
ends.append(end)
|
||||
keys.append(keys_multi_layer)
|
||||
ret_mask[start:end] = True
|
||||
|
||||
if keys:
|
||||
# Transpose the keys into layer major format
|
||||
keys = [list(row) for row in zip(*keys)] # [num_layer,block_num]
|
||||
for layer_id, keys_multi_chunk in enumerate(keys):
|
||||
if not first_flag:
|
||||
is_finish = self.get_event.wait(timeout=3) #try---cache
|
||||
if not is_finish:
|
||||
logger.info("Layerwise get failed")
|
||||
self.get_event.clear()
|
||||
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
|
||||
starts, ends, block_ids,
|
||||
layer_id)
|
||||
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
|
||||
req_meta) # type: ignore[union-attr, call-arg, arg-type]
|
||||
first_flag = False
|
||||
yield None
|
||||
else:
|
||||
# If no cache are found, we still need to yield to avoid
|
||||
# `StopIteration`
|
||||
for layer_id in range(self.num_layers):
|
||||
yield None
|
||||
|
||||
retrieved_tokens = torch.sum(ret_mask)
|
||||
logger.debug(f"Retrieved {retrieved_tokens} "
|
||||
f"out of {num_required_tokens} "
|
||||
f"out of total {len(tokens)} tokens")
|
||||
|
||||
yield ret_mask
|
||||
|
||||
def store_layer(
|
||||
self,
|
||||
req_id: str,
|
||||
tokens: torch.Tensor,
|
||||
block_ids: list[int],
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Store the KV cache in a layerwise manner.
|
||||
|
||||
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
|
||||
|
||||
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
|
||||
have the same length as tokens. And the mask should ALWAYS be like
|
||||
FFFFFTTTTTTT, where True means the tokens needs to be matched.
|
||||
|
||||
:param **kwargs: The additional arguments for the storage backend which
|
||||
will be passed into the gpu_connector.
|
||||
|
||||
return: A generator that yields None. In the first iteration, the
|
||||
generator allocates the memory objects for all layers and moves
|
||||
the KV cache of the first layer from GPU to CPU. In the next
|
||||
iterations, it moves the KV cache of layer i from GPU to the memory
|
||||
objects (on CPU) and puts the memory objects of layer i-1 to the
|
||||
storage backends. In the last iteration, it puts the memory objects
|
||||
of the last layer to the storage backends.
|
||||
"""
|
||||
|
||||
if mask is not None:
|
||||
num_stored_tokens = torch.sum(mask).item()
|
||||
else:
|
||||
num_stored_tokens = len(tokens)
|
||||
|
||||
starts = []
|
||||
ends = []
|
||||
keys = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
starts.append(start)
|
||||
ends.append(end)
|
||||
keys.append(keys_multi_layer) #[block_num,layer_num]
|
||||
|
||||
if keys:
|
||||
keys = [list(row) for row in zip(*keys)] #[layer_num,block_num]
|
||||
for layer_id, keys_multi_chunk in enumerate(keys):
|
||||
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
|
||||
starts, ends, block_ids,
|
||||
layer_id)
|
||||
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
|
||||
req_meta) # type: ignore[union-attr, call-arg, arg-type]
|
||||
yield
|
||||
else:
|
||||
for layer_id in range(self.num_layers):
|
||||
yield
|
||||
logger.debug(
|
||||
f"Stored {num_stored_tokens} out of total {len(tokens)} tokens")
|
||||
|
||||
def get_finished(self) -> tuple[set[str], set[str]]:
|
||||
done_sending = (
|
||||
self.kv_send_thread.
|
||||
get_and_clear_finished_requests( # type: ignore[union-attr]
|
||||
) if self.kv_role in ['kv_producer', 'kv_both'] else set())
|
||||
done_recving = self.kv_recv_thread.get_and_clear_finished_requests( # type: ignore[union-attr]
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Number of completed KV cache send requests: %d, receive "
|
||||
"requests: %d, tp_rank:%d", len(done_sending), len(done_recving),
|
||||
self.tp_rank)
|
||||
return done_sending, done_recving
|
||||
|
||||
def wait_layer_transfer_finish(self):
|
||||
time.sleep(10)
|
||||
pass
|
||||
|
||||
def lookup(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
use_layerwise: bool,
|
||||
) -> int:
|
||||
"""
|
||||
Checks the existence of KV cache of the tokens from the cache engine.
|
||||
|
||||
:param tokens: the input tokens, with shape [seq_len]
|
||||
|
||||
:return: An int indicating how many prefix tokens are cached.
|
||||
"""
|
||||
end = 0
|
||||
|
||||
for start, end, key in self.token_database.process_tokens(tokens):
|
||||
try:
|
||||
if use_layerwise:
|
||||
keys = []
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
for key in keys_multi_layer:
|
||||
keys.append(key.to_string())
|
||||
# batch is_exists
|
||||
ress = self.m_store.batch_exists(keys)
|
||||
res = 1
|
||||
for value in ress:
|
||||
if value != 1:
|
||||
res = 0
|
||||
break
|
||||
else:
|
||||
res = self.m_store.exists(key)
|
||||
if res == 1:
|
||||
continue
|
||||
else:
|
||||
return start
|
||||
except Exception as e:
|
||||
logger.warning(f"Remote connection failed in contains: {e}")
|
||||
return start
|
||||
|
||||
# all tokens where found, return the maximal end
|
||||
return end
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the cache engine and free all the resources"""
|
||||
self.m_store.close()
|
||||
88
vllm_ascend/distributed/mooncake/mooncake_store.py
Normal file
88
vllm_ascend/distributed/mooncake/mooncake_store.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Standard
|
||||
import os
|
||||
|
||||
# Third Party
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||
from vllm.utils import logger
|
||||
|
||||
from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey
|
||||
|
||||
from .config_data import MooncakeStoreConfig
|
||||
|
||||
METADATA_BYTES_LEN = 24
|
||||
BASE_PORT = int(os.getenv("VLLM_BASE_PORT", "8790"))
|
||||
|
||||
|
||||
class Mooncakestore():
|
||||
|
||||
def __init__(self, parallel_config: ParallelConfig):
|
||||
try:
|
||||
from mooncake.store import MooncakeDistributedStore # type: ignore
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install mooncake by following the instructions at "
|
||||
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
|
||||
"to run vLLM with MooncakeConnector.") from e
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = parallel_config.tensor_parallel_size
|
||||
dp_rank = parallel_config.data_parallel_rank_local
|
||||
all_device_ids = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None)
|
||||
if not all_device_ids:
|
||||
device_ids_list = list(
|
||||
range(dp_rank * tp_size, (dp_rank + 1) * tp_size))
|
||||
else:
|
||||
device_ids_list = list(map(int, all_device_ids.split(',')))
|
||||
assert len(device_ids_list) > tp_rank
|
||||
device_id = device_ids_list[tp_rank]
|
||||
self.config = MooncakeStoreConfig.load_from_env()
|
||||
if self.config.protocol == "ascend":
|
||||
local_hostname = self.config.local_hostname + ":" + str(BASE_PORT + int(device_id)) + \
|
||||
":npu_" + str(device_id)
|
||||
else:
|
||||
local_hostname = self.config.local_hostname
|
||||
self.store = MooncakeDistributedStore()
|
||||
ret = self.store.setup(local_hostname, self.config.metadata_server,
|
||||
self.config.global_segment_size,
|
||||
self.config.local_buffer_size,
|
||||
self.config.protocol, self.config.device_name,
|
||||
self.config.master_server_address)
|
||||
if ret != 0:
|
||||
msg = "Initialize mooncake failed."
|
||||
logger.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
def set_kv_caches(self, kvcache):
|
||||
self.kvcache = list(kvcache)
|
||||
|
||||
def exists(self, key: MooncakeEngineKey) -> bool:
|
||||
return self.store.is_exist(key.to_string()) == 1
|
||||
|
||||
def batch_exists(self, keys: list[str]) -> list[bool]:
|
||||
return self.store.batch_is_exist(keys)
|
||||
|
||||
def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
|
||||
expect_res = sum(size)
|
||||
key_str = key.to_string()
|
||||
try:
|
||||
res = self.store.batch_get_into_ascend(key_str, addr, size)
|
||||
if res[0] != expect_res:
|
||||
logger.error(f"Failed to get key: [{key_str}] .")
|
||||
except Exception:
|
||||
logger.error(f"Failed to get key: [{key_str}] .")
|
||||
return res
|
||||
|
||||
def put(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
|
||||
key_str = key.to_string()
|
||||
try:
|
||||
ret = self.store.batch_put_from_ascend(key_str, addr, size)
|
||||
if ret[0] != 0:
|
||||
logger.error(f"Failed to put key {key_str}.")
|
||||
except Exception:
|
||||
logger.error(f"Failed to put key {key_str}.")
|
||||
|
||||
return ret
|
||||
|
||||
def close(self):
|
||||
self.store.close()
|
||||
logger.info("Closed the mooncake store connection")
|
||||
484
vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py
Normal file
484
vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py
Normal file
@@ -0,0 +1,484 @@
|
||||
import threading
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import vllm.envs as envs
|
||||
import zmq
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.utils import logger, make_zmq_socket
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
|
||||
from vllm_ascend.distributed.mooncake.config_data import (
|
||||
LoadSpec, MooncakeConnectorMetadata, ReqMeta, RequestTracker)
|
||||
from vllm_ascend.distributed.mooncake.mooncake_engine import MooncakeEngine
|
||||
|
||||
|
||||
class MooncakeConnectorV1(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
||||
super().__init__(vllm_config=vllm_config, role=role)
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
|
||||
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"use_layerwise", False)
|
||||
|
||||
self.kv_caches: dict[str, torch.Tensor] = {}
|
||||
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
|
||||
self.sended_but_unfinished_reqs: set[str] = set()
|
||||
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler = MooncakeStoreConnectorV1Scheduler(
|
||||
vllm_config, self.use_layerwise)
|
||||
else:
|
||||
self.connector_worker = MooncakeEngine(
|
||||
vllm_config,
|
||||
self.use_layerwise,
|
||||
)
|
||||
|
||||
assert self.connector_worker is not None
|
||||
if vllm_config.parallel_config.rank == 0:
|
||||
self.lookup_server = MooncakeLookupServer(
|
||||
self.connector_worker, vllm_config, self.use_layerwise)
|
||||
|
||||
############################################################
|
||||
# Scheduler Side Methods
|
||||
############################################################
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.update_state_after_alloc(
|
||||
request, blocks, num_external_tokens)
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.build_connector_meta(scheduler_output)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, block_ids)
|
||||
|
||||
############################################################
|
||||
# Worker Side Methods
|
||||
############################################################
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._get_connector_metadata(),
|
||||
MooncakeConnectorMetadata)
|
||||
self.connector_worker.start_load_kv(self._get_connector_metadata())
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""MooncakeStoreConnector does not do layerwise saving."""
|
||||
if not self.use_layerwise:
|
||||
return
|
||||
self.connector_worker.wait_for_layer_load()
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
"""MooncakeStoreConnector does not save explicitly."""
|
||||
if not self.use_layerwise:
|
||||
return
|
||||
|
||||
if self.kv_role == "kv_consumer":
|
||||
# Don't do save if the role is kv_consumer
|
||||
return
|
||||
self.connector_worker.save_kv_layer(self._get_connector_metadata())
|
||||
|
||||
def wait_for_save(self):
|
||||
"""MooncakeStoreConnector does not save explicitly."""
|
||||
if self.kv_role == "kv_consumer":
|
||||
# Don't do save if the role is kv_consumer
|
||||
return
|
||||
|
||||
if self.use_layerwise:
|
||||
self.connector_worker.wait_layer_transfer_finish()
|
||||
return
|
||||
|
||||
self.connector_worker.wait_for_save(self._get_connector_metadata())
|
||||
|
||||
def get_finished(self,
|
||||
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||
"""Get the finished recving and sending requests."""
|
||||
assert self.connector_worker is not None
|
||||
meta = self._get_connector_metadata()
|
||||
done_sending, done_recving = self.connector_worker.get_finished()
|
||||
sended_and_finished: set[str] = set()
|
||||
for item in list(self.sended_but_unfinished_reqs):
|
||||
if item not in meta.unfinished_request_ids:
|
||||
sended_and_finished.add(item)
|
||||
self.sended_but_unfinished_reqs.remove(item)
|
||||
for item in done_sending:
|
||||
if item in meta.unfinished_request_ids:
|
||||
self.sended_but_unfinished_reqs.add(item)
|
||||
else:
|
||||
sended_and_finished.add(item)
|
||||
|
||||
return sended_and_finished, done_recving
|
||||
|
||||
|
||||
def get_zmq_rpc_path_mooncake(
|
||||
vllm_config: Optional["VllmConfig"] = None, ) -> str:
|
||||
base_url = envs.VLLM_RPC_BASE_PATH
|
||||
# Default to 0 if not configured
|
||||
rpc_port = 0
|
||||
if vllm_config is not None:
|
||||
rpc_port = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"mooncake_rpc_port", 0)
|
||||
logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port)
|
||||
return f"ipc://{base_url}/mooncake_rpc_port_{rpc_port}"
|
||||
|
||||
|
||||
class MooncakeStoreConnectorV1Scheduler:
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", use_layerwise):
|
||||
self.client = MooncakeLookupClient(vllm_config)
|
||||
self.use_layerwise = use_layerwise
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
# request_id -> (vllm cached tokes, mooncake cached tokens)
|
||||
self.load_specs: dict[str, LoadSpec] = {}
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
# request_id -> full_token_ids
|
||||
self._request_trackers: dict[str, RequestTracker] = {}
|
||||
# Whether to discard partial chunks
|
||||
self._discard_partial_chunks = (
|
||||
vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"discard_partial_chunks", True))
|
||||
self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {}
|
||||
self._unfinished_request_ids: set[str] = set()
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Check for external KV cache hit.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
|
||||
if self._discard_partial_chunks:
|
||||
token_block_end = len(request.prompt_token_ids
|
||||
) // self._block_size * self._block_size
|
||||
token_ids = torch.tensor(
|
||||
request.prompt_token_ids[:token_block_end])
|
||||
else:
|
||||
token_ids = torch.tensor(request.prompt_token_ids)
|
||||
|
||||
num_external_hit_tokens = self.client.lookup(token_ids)
|
||||
|
||||
if num_external_hit_tokens == request.num_tokens:
|
||||
num_external_hit_tokens -= 1
|
||||
|
||||
need_to_allocate = num_external_hit_tokens - num_computed_tokens
|
||||
|
||||
logger.info(
|
||||
"Reqid: %s, Total tokens %d, mooncake hit tokens: %d, need to load: %d",
|
||||
request.request_id,
|
||||
request.num_tokens,
|
||||
num_external_hit_tokens,
|
||||
need_to_allocate,
|
||||
)
|
||||
|
||||
if need_to_allocate <= 0:
|
||||
return 0, False
|
||||
|
||||
self.load_specs[request.request_id] = LoadSpec(
|
||||
vllm_cached_tokens=num_computed_tokens,
|
||||
mooncake_cached_tokens=num_external_hit_tokens,
|
||||
can_load=False,
|
||||
)
|
||||
|
||||
return need_to_allocate, not self.use_layerwise
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
"""
|
||||
Update KVConnector state after temporary buffer alloc.
|
||||
|
||||
For SharedStorageConnector, update _request_needs_load
|
||||
if the CacheManager this allocated blocks for us.
|
||||
"""
|
||||
local_block_ids = []
|
||||
if num_external_tokens > 0:
|
||||
local_block_ids = blocks.get_block_ids()[0]
|
||||
|
||||
self._unfinished_requests[request.request_id] = (request,
|
||||
local_block_ids)
|
||||
self._unfinished_request_ids.add(request.request_id)
|
||||
if request.request_id not in self.load_specs:
|
||||
# No KV tokens from external KV cache, return
|
||||
return
|
||||
|
||||
if num_external_tokens == 0:
|
||||
# No need to load anything
|
||||
self.load_specs[request.request_id].can_load = False
|
||||
return
|
||||
|
||||
assert (
|
||||
num_external_tokens > 0 and num_external_tokens
|
||||
== self.load_specs[request.request_id].mooncake_cached_tokens -
|
||||
self.load_specs[request.request_id].vllm_cached_tokens
|
||||
), (f"Mismatch in number of tokens: {num_external_tokens} vs "
|
||||
f"{self.load_specs[request.request_id].mooncake_cached_tokens} - "
|
||||
f"{self.load_specs[request.request_id].vllm_cached_tokens}"
|
||||
f" for request {request.request_id}")
|
||||
|
||||
self.load_specs[request.request_id].can_load = True
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
"""Attach the connector metadata to the request object.
|
||||
|
||||
This function should NOT modify other fields in the scheduler_output
|
||||
except the `kv_connector_metadata` field.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
|
||||
force_skip_save = self.kv_role == "kv_consumer"
|
||||
|
||||
for finished_req_id in scheduler_output.finished_req_ids:
|
||||
self._request_trackers.pop(finished_req_id, None)
|
||||
self._unfinished_requests.pop(finished_req_id, None)
|
||||
self._unfinished_request_ids.remove(finished_req_id)
|
||||
|
||||
meta = MooncakeConnectorMetadata(self._unfinished_request_ids)
|
||||
|
||||
for request in scheduler_output.scheduled_new_reqs:
|
||||
# Right now, we only load KV for new requests
|
||||
load_spec = self.load_specs.pop(request.req_id, None)
|
||||
num_tokens_to_compute = (
|
||||
request.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[request.req_id])
|
||||
request_tracker = RequestTracker.from_new_request(
|
||||
request, num_tokens_to_compute)
|
||||
self._request_trackers[request.req_id] = request_tracker
|
||||
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
|
||||
self._block_size * self._block_size)
|
||||
if self._discard_partial_chunks else len(
|
||||
request.prompt_token_ids))
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
request_tracker,
|
||||
self._block_size,
|
||||
load_spec=load_spec,
|
||||
skip_save=force_skip_save,
|
||||
is_last_chunk=len(request_tracker.token_ids)
|
||||
>= last_chunk_tokens_num,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
)
|
||||
if req_meta is not None:
|
||||
meta.add_request(req_meta)
|
||||
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
if isinstance(cached_reqs, list) and not force_skip_save:
|
||||
for i, req in enumerate(cached_reqs):
|
||||
request_tracker = self._request_trackers[req.req_id]
|
||||
request_tracker.update(req.new_token_ids, req.new_block_ids)
|
||||
last_chunk_tokens_num = ((len(req.prompt_token_ids) //
|
||||
self._block_size * self._block_size)
|
||||
if self._discard_partial_chunks else
|
||||
len(req.prompt_token_ids))
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
request_tracker,
|
||||
self._block_size,
|
||||
load_spec=None,
|
||||
skip_save=force_skip_save,
|
||||
is_last_chunk=len(request_tracker.token_ids)
|
||||
>= last_chunk_tokens_num,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
)
|
||||
if req_meta is not None:
|
||||
meta.add_request(req_meta)
|
||||
elif not force_skip_save:
|
||||
for i, req_id in enumerate(cached_reqs.req_ids):
|
||||
request_tracker = self._request_trackers[req_id]
|
||||
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
req_tuple = self._unfinished_requests.get(req_id)
|
||||
if req_tuple:
|
||||
request = req_tuple[0]
|
||||
num_current_tokens = len(request_tracker.token_ids)
|
||||
new_token_ids = request.all_token_ids[
|
||||
num_current_tokens:num_current_tokens + num_new_tokens]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Request {req_id} is not in _unfinished_requests, "
|
||||
f"but it is scheduled to be cached")
|
||||
new_block_ids = cached_reqs.new_block_ids[i]
|
||||
if not new_block_ids:
|
||||
continue
|
||||
request_tracker.update(new_token_ids, new_block_ids)
|
||||
# decode not save
|
||||
if len(request_tracker.token_ids) > len(
|
||||
request.prompt_token_ids):
|
||||
continue
|
||||
|
||||
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
|
||||
self._block_size * self._block_size)
|
||||
if self._discard_partial_chunks else
|
||||
len(request.prompt_token_ids))
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
request_tracker,
|
||||
self._block_size,
|
||||
load_spec=None,
|
||||
skip_save=force_skip_save,
|
||||
is_last_chunk=len(request_tracker.token_ids)
|
||||
>= last_chunk_tokens_num,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
)
|
||||
if req_meta is not None:
|
||||
meta.add_request(req_meta)
|
||||
|
||||
request_ids = [
|
||||
req.req_id for req in scheduler_output.scheduled_new_reqs
|
||||
]
|
||||
for request_id, (request,
|
||||
block_ids) in self._unfinished_requests.items():
|
||||
if request_id not in request_ids and request_id not in cached_reqs.req_ids:
|
||||
load_spec = self.load_specs.pop(request_id, None)
|
||||
if not load_spec:
|
||||
continue
|
||||
num_tokens_to_compute = load_spec.mooncake_cached_tokens
|
||||
if (num_tokens_to_compute % self._block_size
|
||||
!= 0) and (num_tokens_to_compute
|
||||
== len(request.prompt_token_ids) - 1):
|
||||
num_tokens_to_compute = num_tokens_to_compute + 1
|
||||
request_tracker = RequestTracker(
|
||||
req_id=request_id,
|
||||
token_ids=request.prompt_token_ids[:num_tokens_to_compute].
|
||||
copy(),
|
||||
allocated_block_ids=block_ids,
|
||||
num_saved_tokens=0,
|
||||
)
|
||||
|
||||
self._request_trackers[request_id] = request_tracker
|
||||
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
request_tracker,
|
||||
self._block_size,
|
||||
load_spec=load_spec,
|
||||
skip_save=None,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
)
|
||||
if req_meta is not None:
|
||||
meta.add_request(req_meta)
|
||||
return meta
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
"""
|
||||
Once a request is finished, determine whether request blocks
|
||||
should be freed now or will be sent asynchronously and freed later.
|
||||
"""
|
||||
if self.kv_role == "kv_consumer":
|
||||
return False, None
|
||||
if self._request_trackers[request.request_id].num_saved_tokens <= 0:
|
||||
return False, None
|
||||
delay_free_blocks = len(block_ids) > 0
|
||||
if delay_free_blocks:
|
||||
logger.info("Delaying free of %d blocks for request %s",
|
||||
len(block_ids), request.request_id)
|
||||
return delay_free_blocks, None
|
||||
|
||||
|
||||
class MooncakeLookupClient:
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig"):
|
||||
self.encoder = MsgpackEncoder()
|
||||
self.ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
socket_path = get_zmq_rpc_path_mooncake(vllm_config)
|
||||
self.socket = make_zmq_socket(
|
||||
self.ctx,
|
||||
socket_path,
|
||||
zmq.REQ, # type: ignore[attr-defined]
|
||||
bind=False,
|
||||
)
|
||||
|
||||
def lookup(self, token_ids: torch.Tensor) -> int:
|
||||
request = self.encoder.encode(token_ids)
|
||||
self.socket.send_multipart(request, copy=False)
|
||||
resp = self.socket.recv()
|
||||
result = int.from_bytes(resp, "big")
|
||||
return result
|
||||
|
||||
def close(self):
|
||||
self.socket.close(linger=0)
|
||||
|
||||
|
||||
class MooncakeLookupServer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mooncake_engine: MooncakeEngine,
|
||||
vllm_config: "VllmConfig",
|
||||
use_layerwise: bool,
|
||||
):
|
||||
self.decoder = MsgpackDecoder(torch.Tensor)
|
||||
self.ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
socket_path = get_zmq_rpc_path_mooncake(vllm_config)
|
||||
self.socket = make_zmq_socket(
|
||||
self.ctx,
|
||||
socket_path,
|
||||
zmq.REP, # type: ignore[attr-defined]
|
||||
bind=True,
|
||||
)
|
||||
|
||||
self.mooncake_engine = mooncake_engine
|
||||
self.running = True
|
||||
|
||||
def process_request():
|
||||
while self.running:
|
||||
frames = self.socket.recv_multipart(copy=False)
|
||||
token_ids = self.decoder.decode(frames)
|
||||
result = self.mooncake_engine.lookup(token_ids, use_layerwise)
|
||||
response = result.to_bytes(4, "big")
|
||||
self.socket.send(response)
|
||||
|
||||
self.thread = threading.Thread(target=process_request, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def close(self):
|
||||
self.socket.close(linger=0)
|
||||
# TODO: close the thread!
|
||||
@@ -11,7 +11,7 @@ from collections import defaultdict, deque
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, OrderedDict, Tuple
|
||||
|
||||
import msgspec
|
||||
import numpy as np
|
||||
@@ -19,6 +19,7 @@ import numpy.typing as npt
|
||||
import torch
|
||||
import zmq
|
||||
from mooncake.engine import TransferEngine # type: ignore
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
@@ -29,6 +30,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
@@ -67,12 +69,16 @@ class KVCacheTaskTracker:
|
||||
# intentionally delayed. Each entry is a tuple of (request_id,
|
||||
# timestamp). If a request remains in this queue for too long, it will
|
||||
# be force-freed.
|
||||
self.delayed_free_requests: deque[Tuple[str, float]] = deque()
|
||||
self.record_finished_requests: set[str] = set()
|
||||
self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()
|
||||
|
||||
def update_done_task_count(self, request_id: str):
|
||||
with self.done_task_lock:
|
||||
self.finished_requests.add(request_id)
|
||||
self._remove_delayed_requests(request_id)
|
||||
if request_id in self.delayed_free_requests:
|
||||
self._remove_delayed_requests(request_id)
|
||||
else:
|
||||
self.record_finished_requests.add(request_id)
|
||||
|
||||
def get_and_clear_finished_requests(self) -> set[str]:
|
||||
"""
|
||||
@@ -90,7 +96,10 @@ class KVCacheTaskTracker:
|
||||
def add_delayed_request(self, request_id: str, delay_start_time: float):
|
||||
"""Add a delayed free request."""
|
||||
with self.done_task_lock:
|
||||
self.delayed_free_requests.append((request_id, delay_start_time))
|
||||
if request_id not in self.record_finished_requests:
|
||||
self.delayed_free_requests[request_id] = delay_start_time
|
||||
else:
|
||||
self.record_finished_requests.discard(request_id)
|
||||
|
||||
def _retrieve_expired_requests(self):
|
||||
"""Retrieve all expired delayed requests."""
|
||||
@@ -98,10 +107,11 @@ class KVCacheTaskTracker:
|
||||
# Free delayed requests if they exceed the timeout
|
||||
current_time = time.time()
|
||||
while self.delayed_free_requests:
|
||||
request_id, delay_start_time = self.delayed_free_requests[0]
|
||||
request_id = next(iter(self.delayed_free_requests))
|
||||
delay_start_time = self.delayed_free_requests[request_id]
|
||||
if (current_time - delay_start_time
|
||||
> envs_ascend.VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT):
|
||||
self.delayed_free_requests.popleft()
|
||||
> envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT):
|
||||
self.delayed_free_requests.popitem(last=False)
|
||||
expired_requests.add(request_id)
|
||||
logger.info("Force freed request: %s", request_id)
|
||||
else:
|
||||
@@ -110,8 +120,7 @@ class KVCacheTaskTracker:
|
||||
|
||||
def _remove_delayed_requests(self, request_id: str):
|
||||
"""Remove all delayed free requests matching the given request_id."""
|
||||
self.delayed_free_requests = deque(
|
||||
(r, t) for r, t in self.delayed_free_requests if r != request_id)
|
||||
self.delayed_free_requests.pop(request_id)
|
||||
|
||||
|
||||
class KVCacheSendingThread(threading.Thread):
|
||||
@@ -230,6 +239,7 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
self.block_len = block_len
|
||||
# TODO(jianzs): find a better way to detect MLA.
|
||||
self.use_mla = len(block_len) == 2
|
||||
self.use_sfa = len(block_len) == 3
|
||||
|
||||
self.request_queue: queue.Queue[Any] = queue.Queue()
|
||||
# TODO(jianzs): make this configurable
|
||||
@@ -341,8 +351,12 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
src_list, dst_list, length_list = [], [], []
|
||||
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
|
||||
zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
|
||||
block_len = (self.block_len[k % 2]
|
||||
if self.use_mla else self.block_len[0])
|
||||
if self.use_mla:
|
||||
block_len = (self.block_len[k % 2])
|
||||
elif self.use_sfa:
|
||||
block_len = (self.block_len[k % 3])
|
||||
else:
|
||||
block_len = (self.block_len[0])
|
||||
for i, remote_block_id in enumerate(grouped_remote_block_ids):
|
||||
local_block_ids = grouped_local_block_ids[i]
|
||||
src = src_layer_base_addr + local_block_ids[0] * block_len
|
||||
@@ -559,6 +573,7 @@ class MooncakeConnectorScheduler:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
self.vllm_config = vllm_config
|
||||
self.ascend_config = get_ascend_config()
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.engine_id = engine_id
|
||||
logger.info("Initializing Mooncake Scheduler %s", engine_id)
|
||||
@@ -718,7 +733,7 @@ class MooncakeConnectorScheduler:
|
||||
assert "tp_size" in decode_parallel_config.keys()
|
||||
self._decode_tp_size = decode_parallel_config["tp_size"]
|
||||
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
|
||||
return self._decode_tp_size
|
||||
else:
|
||||
# TODO support mha and gqa
|
||||
@@ -782,10 +797,12 @@ class MooncakeConnectorWorker:
|
||||
assert len(device_ids) > self.tp_rank # type: ignore
|
||||
self.device_id = device_ids[self.tp_rank] # type: ignore
|
||||
|
||||
self._initialize(
|
||||
hostname=self.side_channel_host + ':' + '0' + ':' + 'npu_' \
|
||||
+ str(self.device_id),
|
||||
device_name=None)
|
||||
if vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
'use_ascend_direct', False):
|
||||
hostname = self.side_channel_host
|
||||
else:
|
||||
hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
|
||||
self._initialize(hostname=hostname, device_name=None)
|
||||
self.te_rpc_port = self.engine.get_rpc_port()
|
||||
|
||||
# Background thread for sending or receiving KV caches.
|
||||
@@ -837,7 +854,9 @@ class MooncakeConnectorWorker:
|
||||
|
||||
# TODO(tms): Find a more robust way to detect and handle MLA
|
||||
self.use_mla = first_kv_cache_tuple[0].size(
|
||||
-1) != first_kv_cache_tuple[1].size(-1)
|
||||
-1) != first_kv_cache_tuple[1].size(-1) and len(
|
||||
first_kv_cache_tuple) == 2
|
||||
self.use_sfa = len(first_kv_cache_tuple) == 3
|
||||
if self.use_mla:
|
||||
# MLA case.[num_block, block_size, 1, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
@@ -851,6 +870,21 @@ class MooncakeConnectorWorker:
|
||||
logger.info(
|
||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
|
||||
self.num_blocks, block_shape_norm, block_shape_pe)
|
||||
elif self.use_sfa:
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 3 # [block_size, latent_dim]
|
||||
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
|
||||
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
|
||||
block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
|
||||
self.block_len = [
|
||||
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
|
||||
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
|
||||
first_kv_cache[2].element_size() * math.prod(block_shape_k)
|
||||
]
|
||||
logger.info(
|
||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
|
||||
self.num_blocks, block_shape_norm, block_shape_pe,
|
||||
block_shape_k)
|
||||
else:
|
||||
# [num_block, block_size, num_head, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
@@ -861,8 +895,9 @@ class MooncakeConnectorWorker:
|
||||
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
|
||||
block_shape)
|
||||
|
||||
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
|
||||
self.use_mla, first_kv_cache.shape)
|
||||
logger.info(
|
||||
"Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s",
|
||||
self.use_mla, self.use_sfa, first_kv_cache.shape)
|
||||
|
||||
self.kv_caches = kv_caches
|
||||
kv_caches_base_addr = []
|
||||
@@ -874,9 +909,16 @@ class MooncakeConnectorWorker:
|
||||
region_len = self.num_blocks * self.block_len[i % 2]
|
||||
kv_caches_base_addr.append(base_addr)
|
||||
self._register(base_addr, region_len)
|
||||
elif self.use_sfa:
|
||||
for i, cache in enumerate(cache_or_caches, 0):
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len[i % 3]
|
||||
kv_caches_base_addr.append(base_addr)
|
||||
self._register(base_addr, region_len)
|
||||
else:
|
||||
cache_list = [cache_or_caches
|
||||
] if self.use_mla else cache_or_caches
|
||||
cache_list = [
|
||||
cache_or_caches
|
||||
] if self.use_mla or self.use_sfa else cache_or_caches
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len[0]
|
||||
|
||||
@@ -11,7 +11,7 @@ from vllm_ascend.ascend_config import get_ascend_config
|
||||
# Currently, mc2 op need their own group coordinator.
|
||||
_MC2: Optional[GroupCoordinator] = None
|
||||
_MLP_TP: Optional[GroupCoordinator] = None
|
||||
|
||||
_OTP: Optional[GroupCoordinator] = None
|
||||
_LMTP: Optional[GroupCoordinator] = None
|
||||
|
||||
|
||||
@@ -20,6 +20,12 @@ def get_mc2_group() -> GroupCoordinator:
|
||||
return _MC2
|
||||
|
||||
|
||||
def get_otp_group() -> GroupCoordinator:
|
||||
assert _OTP is not None, (
|
||||
"output tensor parallel group is not initialized")
|
||||
return _OTP
|
||||
|
||||
|
||||
def get_lmhead_tp_group() -> GroupCoordinator:
|
||||
assert _LMTP is not None, (
|
||||
"lm head tensor parallel group is not initialized")
|
||||
@@ -74,6 +80,20 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
||||
backend,
|
||||
group_name="mlp_tp")
|
||||
|
||||
# If oproj tensor parallel size is set, we will create a group for it.
|
||||
otp_size = get_ascend_config().oproj_tensor_parallel_size
|
||||
if otp_size is not None:
|
||||
group_ranks = []
|
||||
global _OTP
|
||||
num_oproj_tensor_parallel_groups: int = (world_size // otp_size)
|
||||
for i in range(num_oproj_tensor_parallel_groups):
|
||||
ranks = list(range(i * otp_size, (i + 1) * otp_size))
|
||||
group_ranks.append(ranks)
|
||||
_OTP = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="otp")
|
||||
|
||||
lmhead_tensor_parallel_size = get_ascend_config(
|
||||
).lmhead_tensor_parallel_size
|
||||
if lmhead_tensor_parallel_size is not None:
|
||||
@@ -117,3 +137,8 @@ def destroy_ascend_model_parallel():
|
||||
if _LMTP:
|
||||
_LMTP.destroy()
|
||||
_LMTP = None
|
||||
|
||||
global _OTP
|
||||
if _OTP:
|
||||
_OTP.destroy()
|
||||
_OTP = None
|
||||
|
||||
@@ -1,248 +0,0 @@
|
||||
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Adapts from: Megatron/megatron/core/tensor_parallel/mappings.py.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
import torch
|
||||
|
||||
|
||||
def _gather_along_first_dim(input_, group, output_split_sizes=None):
|
||||
"""Gather tensors and concatenate along the first dimension.
|
||||
|
||||
Args:
|
||||
input_tensor (torch.Tensor):
|
||||
A tensor to be gathered.
|
||||
output_split_sizes (List[int], optional):
|
||||
A list specifying the sizes of the output splits along the first dimension.
|
||||
If None, equal splitting is assumed. Default: None.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Gathered tensor.
|
||||
"""
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
dim_size = list(input_.size())
|
||||
if output_split_sizes is None:
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
torch.distributed.all_gather_into_tensor(output,
|
||||
input_.contiguous(),
|
||||
group=group)
|
||||
else:
|
||||
dim_size[0] = sum(output_split_sizes)
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
output_tensor_list = list(
|
||||
torch.split(output, output_split_sizes, dim=0))
|
||||
torch.distributed.all_gather(output_tensor_list, input_, group=group)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _gather_along_last_dim(input_, group):
|
||||
"""Gather tensors and concatenate along the last dimension."""
|
||||
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
dim_size = list(input_.size())
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
torch.distributed.all_gather_into_tensor(output,
|
||||
input_.contiguous(),
|
||||
group=group)
|
||||
tensor_list = output.chunk(world_size, dim=0)
|
||||
output = torch.cat(tensor_list, dim=-1).contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _reduce_scatter_along_first_dim(input_,
|
||||
group,
|
||||
input_split_sizes=None,
|
||||
use_global_buffer=False):
|
||||
"""Reduce-scatter the input tensor across model parallel group.
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor): The input tensor to be reduce-scattered.
|
||||
input_split_sizes (List[int], optional): A list specifying the sizes of
|
||||
the input splits along the first dimension for each rank. If None,
|
||||
equal splitting is assumed. Default: None.
|
||||
"""
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
if input_split_sizes is None:
|
||||
dim_size = list(input_.size())
|
||||
assert (
|
||||
dim_size[0] % world_size == 0
|
||||
), "First dimension of the tensor should be divisible by tensor parallel size"
|
||||
|
||||
dim_size[0] = dim_size[0] // world_size
|
||||
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
torch.distributed.reduce_scatter_tensor(output,
|
||||
input_.contiguous(),
|
||||
group=group)
|
||||
else:
|
||||
rank = torch.distributed.get_rank(group)
|
||||
input_tensor_list = list(torch.split(input_, input_split_sizes, dim=0))
|
||||
|
||||
output = torch.empty_like(input_tensor_list[rank])
|
||||
torch.distributed.reduce_scatter(output,
|
||||
input_tensor_list,
|
||||
group=group)
|
||||
return output
|
||||
|
||||
|
||||
def _reduce_scatter_along_last_dim(input_, group):
|
||||
"""Reduce-scatter tensors on the last dimension."""
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
target_shape = list(input_.size())
|
||||
target_shape[-1] = target_shape[-1] // world_size
|
||||
input_ = input_.reshape(-1, input_.shape[-1])
|
||||
split_tensors = torch.split(input_,
|
||||
split_size_or_sections=input_.shape[-1] //
|
||||
world_size,
|
||||
dim=1)
|
||||
concat_tensor = torch.cat(split_tensors, dim=0)
|
||||
output = _reduce_scatter_along_first_dim(concat_tensor,
|
||||
group).reshape(target_shape)
|
||||
return output
|
||||
|
||||
|
||||
def all_gather_last_dim_from_tensor_parallel_region(input_, group):
|
||||
"""Wrapper for autograd function: forward: AG, backward RS <last dim>"""
|
||||
return _gather_along_last_dim(input_, group)
|
||||
|
||||
|
||||
def reduce_scatter_to_sequence_parallel_region(input_,
|
||||
group,
|
||||
input_split_sizes=None):
|
||||
"""Wrapper for autograd function: forward: RS, backward AG <first dim>"""
|
||||
return _reduce_scatter_along_first_dim(input_, group, input_split_sizes)
|
||||
|
||||
|
||||
def reduce_scatter_last_dim_to_tensor_parallel_region(input_, group):
|
||||
"""Wrapper for autograd function: forward: RS, backward AG: AG <last dim>"""
|
||||
return _reduce_scatter_along_last_dim(input_, group)
|
||||
|
||||
|
||||
def gather_from_sequence_parallel_region(
|
||||
input_,
|
||||
group,
|
||||
output_split_sizes=None,
|
||||
):
|
||||
"""Wrapper for autograd function: forward: AG, backward: RS <first dim>"""
|
||||
return _gather_along_first_dim(input_, group, output_split_sizes)
|
||||
|
||||
|
||||
def all_to_all(group, input, output_split_sizes=None, input_split_sizes=None):
|
||||
world_size = torch.distributed.get_world_size(group=group)
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input
|
||||
|
||||
input = input.contiguous()
|
||||
if output_split_sizes is None:
|
||||
# Equal split (all2all)
|
||||
output = torch.empty_like(input)
|
||||
else:
|
||||
# Unequal split (all2all-v)
|
||||
output = input.new_empty(
|
||||
size=[sum(output_split_sizes)] + list(input.size()[1:]),
|
||||
dtype=input.dtype,
|
||||
device=torch.npu.current_device(),
|
||||
)
|
||||
torch.distributed.all_to_all_single(
|
||||
output,
|
||||
input,
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def all_to_all_sp2hp(input_, group):
|
||||
"""
|
||||
Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape
|
||||
[num_tokens/TP, H] to [num_tokens, H/TP].
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor):
|
||||
The input tensor which has been distributed along the sequence
|
||||
dimension.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor with shape [num_tokens, H/TP].
|
||||
|
||||
"""
|
||||
if group is None:
|
||||
return input_
|
||||
world_size = torch.distributed.get_world_size(group=group)
|
||||
tp_group = group
|
||||
input_ = input_.reshape(-1, input_.shape[-1])
|
||||
split_tensors = torch.split(input_,
|
||||
split_size_or_sections=input_.shape[-1] //
|
||||
world_size,
|
||||
dim=1)
|
||||
concat_tensor = torch.cat(split_tensors, dim=0)
|
||||
output = all_to_all(tp_group, concat_tensor)
|
||||
return output
|
||||
|
||||
|
||||
def all_to_all_hp2sp(input_, group):
|
||||
"""
|
||||
Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape
|
||||
[num_tokens, H/TP] to [num_tokens/TP, H].
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor):
|
||||
The input tensor which has been distributed along the hidden
|
||||
dimension.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor with shape [num_tokens/TP, H].
|
||||
"""
|
||||
if group is None:
|
||||
return input_
|
||||
world_size = torch.distributed.get_world_size(group=group)
|
||||
input_ = input_.reshape(-1, input_.shape[-1])
|
||||
tp_group = group
|
||||
input_exchanged = all_to_all(tp_group, input_)
|
||||
input_reshaped = input_exchanged.reshape(-1, input_exchanged.shape[-1])
|
||||
split_tensors = torch.split(
|
||||
input_reshaped,
|
||||
split_size_or_sections=input_reshaped.shape[0] // world_size,
|
||||
dim=0)
|
||||
output = torch.cat(split_tensors, dim=-1)
|
||||
return output
|
||||
@@ -131,6 +131,26 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
||||
# this feature is supported in A2, and eager mode will get better performance.
|
||||
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
|
||||
# Whether to enable FlashComm optimization when tensor parallel is enabled.
|
||||
# This feature will get better performance when concurrency is large.
|
||||
"VLLM_ASCEND_ENABLE_FLASHCOMM":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))),
|
||||
# Whether to enable MLP weight prefetch, only used in small concurrency.
|
||||
"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
|
||||
# buffer size for gate up prefetch
|
||||
"VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE":
|
||||
lambda: int(
|
||||
os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
|
||||
# buffer size for down proj prefetch
|
||||
"VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE":
|
||||
lambda: int(
|
||||
os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
|
||||
# Whether to enable dense model and general optimizations for better performance.
|
||||
# Since we modified the base parent class `linear`, this optimization is also applicable to other model types.
|
||||
# However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models.
|
||||
"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE", '0'))),
|
||||
# Whether to enable mlp optimize when tensor parallel is enabled.
|
||||
# this feature in eager mode will get better performance.
|
||||
"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE":
|
||||
@@ -139,11 +159,16 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
||||
# caused by the initialization of the Mooncake connector.
|
||||
"PHYSICAL_DEVICES":
|
||||
lambda: os.getenv("PHYSICAL_DEVICES", None),
|
||||
# Whether to enable msMonitor tool to monitor the performance of vllm-ascend.
|
||||
"MSMONITOR_USE_DAEMON":
|
||||
lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))),
|
||||
# Timeout (in seconds) for delayed KVCache block release. In the prefill
|
||||
# node, if a request is marked for delayed KV block release and the blocks
|
||||
# are not freed within this timeout, they will be forcibly released.
|
||||
"VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT":
|
||||
lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)),
|
||||
"VLLM_ASCEND_ENABLE_MLAPO":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))),
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
||||
@@ -157,4 +182,4 @@ def __getattr__(name: str):
|
||||
|
||||
|
||||
def __dir__():
|
||||
return list(env_variables.keys())
|
||||
return list(env_variables.keys())
|
||||
0
vllm_ascend/eplb/adaptor/__init__.py
Normal file
0
vllm_ascend/eplb/adaptor/__init__.py
Normal file
44
vllm_ascend/eplb/adaptor/abstract_adaptor.py
Normal file
44
vllm_ascend/eplb/adaptor/abstract_adaptor.py
Normal file
@@ -0,0 +1,44 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this adaptor.
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class EplbAdaptor():
|
||||
|
||||
def __init__(self, **args):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_rank_expert_workload(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_init_expert_map(self, num_moe_layers: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def do_update_expert_map(self, layer_id: Any,
|
||||
updated_expert_map: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def do_update_expert_weight(self, layer_id: Any,
|
||||
local_expert_to_replace: Any,
|
||||
buffer_tensor_id: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
289
vllm_ascend/eplb/adaptor/vllm_adaptor.py
Normal file
289
vllm_ascend/eplb/adaptor/vllm_adaptor.py
Normal file
@@ -0,0 +1,289 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this adaptor.
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor
|
||||
|
||||
|
||||
class VllmEplbAdaptor(EplbAdaptor):
|
||||
|
||||
def __init__(self, model, **args):
|
||||
super().__init__(**args)
|
||||
self.model = model
|
||||
self.rank_id = dist.get_rank()
|
||||
self.world_size = dist.get_world_size()
|
||||
self.param_dict = dict(self.model.named_parameters())
|
||||
if self.model.config.model_type == "qwen3_moe":
|
||||
self.num_dense_layers = 0
|
||||
self.global_expert_num = self.model.config.num_experts
|
||||
else:
|
||||
self.num_dense_layers = self.model.config.first_k_dense_replace
|
||||
self.global_expert_num = self.model.config.n_routed_experts
|
||||
self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
|
||||
self.init_redundancy_expert = get_ascend_config(
|
||||
).init_redundancy_expert
|
||||
|
||||
# TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
|
||||
if self.model.quant_config is not None:
|
||||
self.expert_weight_names = [
|
||||
"w13_weight", "w2_weight", "w13_weight_scale",
|
||||
"w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
|
||||
]
|
||||
else:
|
||||
self.expert_weight_names = ["w13_weight", "w2_weight"]
|
||||
|
||||
self.expert_map_per_layer = dict(
|
||||
) # reference to expert map on device for expert map update
|
||||
self.expert_map_per_layer_cpu = dict(
|
||||
) # copy of expert map on CPU to avoid device synchronize frequently
|
||||
for layer_idx in range(self.num_moe_layers):
|
||||
self.expert_map_per_layer[self.num_dense_layers + layer_idx] = \
|
||||
self.model.get_expert_map(self.num_dense_layers + layer_idx)
|
||||
|
||||
# TODO: here we set number of buffer tensor equal to number of expert in each laryer, which can be improved
|
||||
num_buffer_tensor = torch.where(
|
||||
self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel()
|
||||
self.buffer_tensor_list: list[list[Any]] = [
|
||||
[] for _ in range(num_buffer_tensor)
|
||||
]
|
||||
self.init_buffer_tensor(num_buffer_tensor)
|
||||
|
||||
self.expert_param_per_layer = dict()
|
||||
self.init_expert_param_per_layer()
|
||||
|
||||
self.log2phy_map_per_layer = dict()
|
||||
for layer_idx in range(self.num_moe_layers):
|
||||
self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \
|
||||
self.model.get_log2phy_map(self.num_dense_layers + layer_idx)
|
||||
|
||||
self.all_topk_ids = []
|
||||
|
||||
def init_buffer_tensor(self, num_buffer_tensor):
|
||||
for name in self.expert_weight_names:
|
||||
complete_name = "model.layers." + str(
|
||||
self.num_dense_layers) + ".mlp.experts." + name
|
||||
expert_tensor = self.param_dict[complete_name].data[
|
||||
0:num_buffer_tensor]
|
||||
buffer_tensors = torch.empty_like(expert_tensor)
|
||||
for buffer_id in range(num_buffer_tensor):
|
||||
self.buffer_tensor_list[buffer_id].append(
|
||||
buffer_tensors[buffer_id])
|
||||
|
||||
def init_expert_param_per_layer(self):
|
||||
num_local_expert = self.param_dict["model.layers." + str(self.num_dense_layers) + \
|
||||
".mlp.experts." + self.expert_weight_names[0]].data.shape[0]
|
||||
for moe_layer_id in range(self.num_moe_layers):
|
||||
layer_idx = self.num_dense_layers + moe_layer_id
|
||||
self.expert_param_per_layer[layer_idx] = list()
|
||||
for local_expert_id in range(num_local_expert):
|
||||
self.expert_param_per_layer[layer_idx].append([
|
||||
self.param_dict["model.layers." + str(layer_idx) +
|
||||
".mlp.experts." +
|
||||
name].data[local_expert_id]
|
||||
for name in self.expert_weight_names
|
||||
])
|
||||
|
||||
def get_rank_expert_workload(self) -> torch.Tensor:
|
||||
self.moe_load = self.model.get_all_moe_loads()
|
||||
return self.moe_load
|
||||
|
||||
def get_init_expert_map(self, num_moe_layers):
|
||||
expert_map = self.model.get_all_expert_map(num_moe_layers)
|
||||
if dist.is_initialized():
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
gathered = torch.empty(
|
||||
(world_size, *expert_map.shape), # [W, L, E]
|
||||
dtype=expert_map.dtype,
|
||||
device=expert_map.device)
|
||||
|
||||
dist.all_gather_into_tensor(gathered, expert_map)
|
||||
all_maps = gathered.permute(1, 0, 2)
|
||||
all_expert_maps = all_maps.cpu()
|
||||
|
||||
for layer_idx in range(num_moe_layers):
|
||||
self.expert_map_per_layer_cpu[self.num_dense_layers + layer_idx] = \
|
||||
all_expert_maps[layer_idx][self.rank_id]
|
||||
|
||||
return all_expert_maps
|
||||
|
||||
def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path):
|
||||
|
||||
try:
|
||||
expert_map_tensor, layers_num, ranks_num = self._expert_file_to_tensor(
|
||||
expert_map_path)
|
||||
expert_map_all = self.local2global(expert_map_tensor)
|
||||
except (TypeError, FileNotFoundError, OSError):
|
||||
expert_map_all = self.determine_expert_map_all()
|
||||
|
||||
for layer_idx in range(num_moe_layers):
|
||||
if self.model.config.model_type == "qwen3_moe":
|
||||
self.expert_map_per_layer_cpu[layer_idx] = \
|
||||
expert_map_all[layer_idx][self.rank_id]
|
||||
else:
|
||||
self.expert_map_per_layer_cpu[layer_idx + self.num_dense_layers] = \
|
||||
expert_map_all[layer_idx][self.rank_id]
|
||||
return expert_map_all
|
||||
|
||||
def _expert_file_to_tensor(self, expert_map_path: str):
|
||||
with open(expert_map_path, "r") as f:
|
||||
data = json.load(f)
|
||||
layers_num = data["moe_layer_count"]
|
||||
gpus_num = data["layer_list"][0]["device_count"]
|
||||
|
||||
tensor_data = []
|
||||
for layer in data["layer_list"]:
|
||||
device_data = []
|
||||
for device in layer["device_list"]:
|
||||
device_data.append(device["device_expert"])
|
||||
tensor_data.append(device_data)
|
||||
expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32)
|
||||
return expert_map_tensor, layers_num, gpus_num
|
||||
logger.error(f"failed to read expert_map_path: {expert_map_path}")
|
||||
|
||||
def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str):
|
||||
if self.rank_id == 0:
|
||||
num_local_experts = expert_maps.max() + 1
|
||||
expert_maps_local = self.global2local(expert_maps,
|
||||
num_local_experts)
|
||||
|
||||
expert_maps_list = expert_maps_local.tolist()
|
||||
record: dict[str, Any] = {
|
||||
"moe_layer_count": len(expert_maps_list),
|
||||
"layer_list": []
|
||||
}
|
||||
|
||||
for layer_idx, layer_data in enumerate(expert_maps_list):
|
||||
layer_record: dict[str, Any] = {
|
||||
"layer_id": layer_idx,
|
||||
"device_count": len(layer_data),
|
||||
"device_list": []
|
||||
}
|
||||
|
||||
for device_idx, experts in enumerate(layer_data):
|
||||
device_record = {
|
||||
"device_id": device_idx,
|
||||
"device_expert": experts
|
||||
}
|
||||
layer_record["device_list"].append(device_record)
|
||||
|
||||
record["layer_list"].append(layer_record)
|
||||
|
||||
with open(expert_map_record_path, "w") as f:
|
||||
json.dump(record, f, indent=4)
|
||||
|
||||
def do_update_expert_map(self, layer_id, updated_expert_map):
|
||||
self.expert_map_per_layer[layer_id] = updated_expert_map.clone()
|
||||
self.expert_map_per_layer_cpu[layer_id] = updated_expert_map.clone()
|
||||
|
||||
def do_update_expert_weight(self, layer_id, local_expert_to_replace,
|
||||
buffer_tensor_id):
|
||||
for expert_tensor, buffer_tensor in zip(
|
||||
self.expert_param_per_layer[layer_id][local_expert_to_replace],
|
||||
self.buffer_tensor_list[buffer_tensor_id]):
|
||||
expert_tensor = buffer_tensor.clone()
|
||||
logger.debug(f"Expert tensor shape is :{expert_tensor.shape}")
|
||||
|
||||
def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
|
||||
if self.log2phy_map_per_layer[layer_id] is not None:
|
||||
self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map)
|
||||
|
||||
def global2local(self, placement: torch.Tensor,
|
||||
E_local: int) -> torch.Tensor:
|
||||
|
||||
L, G, _ = placement.shape
|
||||
device = placement.device
|
||||
|
||||
pt_local = torch.full((L, G, E_local),
|
||||
fill_value=-1,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
valid = placement >= 0
|
||||
l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
|
||||
|
||||
slot_idx = placement[l_idx, g_idx, k_idx]
|
||||
|
||||
pt_local[l_idx, g_idx, slot_idx] = k_idx
|
||||
|
||||
return pt_local
|
||||
|
||||
def local2global(self, placement_local: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
L, G, E_local = placement_local.shape
|
||||
device = placement_local.device
|
||||
|
||||
max_id = torch.max(placement_local)
|
||||
E_global = (max_id + 1).item() if max_id >= 0 else 0
|
||||
|
||||
if E_global == 0:
|
||||
return torch.empty((L, G, 0), dtype=torch.long, device=device)
|
||||
|
||||
placement_global = torch.full((L, G, E_global),
|
||||
fill_value=-1,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
valid = placement_local >= 0
|
||||
l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True)
|
||||
gid_idx = placement_local[l_idx, g_idx, slot_idx]
|
||||
|
||||
placement_global[l_idx, g_idx, gid_idx] = slot_idx
|
||||
|
||||
return placement_global
|
||||
|
||||
def determine_expert_map_all(self):
|
||||
if self.world_size == 1:
|
||||
local_ids = torch.arange(self.global_expert_num, dtype=torch.int32)
|
||||
return local_ids.view(1, 1, -1).expand(self.num_moe_layers, 1, -1)
|
||||
|
||||
local_num_experts = self.global_expert_num // self.world_size
|
||||
|
||||
expert_map_all = torch.full(
|
||||
(self.num_moe_layers, self.world_size, self.global_expert_num),
|
||||
-1,
|
||||
dtype=torch.int32)
|
||||
|
||||
for r in range(self.world_size):
|
||||
if r < self.world_size - 1:
|
||||
start = r * local_num_experts
|
||||
end = (r + 1) * local_num_experts
|
||||
local_count = local_num_experts
|
||||
else:
|
||||
start = r * local_num_experts
|
||||
end = self.global_expert_num
|
||||
local_count = self.global_expert_num - r * local_num_experts
|
||||
|
||||
if r < self.init_redundancy_expert:
|
||||
local_count += 1
|
||||
if end < self.global_expert_num:
|
||||
end += 1
|
||||
else:
|
||||
start -= 1
|
||||
|
||||
local_ids = torch.arange(local_count, dtype=torch.int32)
|
||||
expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(
|
||||
self.num_moe_layers, -1)
|
||||
|
||||
return expert_map_all
|
||||
0
vllm_ascend/eplb/core/__init__.py
Normal file
0
vllm_ascend/eplb/core/__init__.py
Normal file
137
vllm_ascend/eplb/core/eplb_device_transfer_loader.py
Normal file
137
vllm_ascend/eplb/core/eplb_device_transfer_loader.py
Normal file
@@ -0,0 +1,137 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from enum import Enum
|
||||
|
||||
import torch.distributed as dist
|
||||
from vllm.logger import logger
|
||||
|
||||
|
||||
class ExpertWeightUpdateState(Enum):
|
||||
WAITING = 0 # waiting for updated expert_map by EplbWorker
|
||||
READY = 1 # ready for d2d expert weights updating
|
||||
TRANSFERRING = 2 # d2d finished and waiting for updating expert_map into model
|
||||
|
||||
|
||||
class D2DExpertWeightLoader:
|
||||
|
||||
def __init__(self):
|
||||
self.comm_op_list = None
|
||||
self.updated_expert_map = None
|
||||
self.updated_log2phy_map = None
|
||||
self.layer_id = -1 # layer id to be updated
|
||||
self.state = ExpertWeightUpdateState.WAITING
|
||||
self.recv_expert_list = []
|
||||
self.mock_flag = True
|
||||
|
||||
def set_adator(self, eplb_adaptor):
|
||||
self.eplb_adaptor = eplb_adaptor
|
||||
|
||||
def generate_expert_d2d_transfer_task(self, expert_send_info,
|
||||
expert_recv_info, updated_expert_map,
|
||||
layer_id):
|
||||
# When current send/recv and weight.expert_map update tasks are not finished, cannot accept new d2d task
|
||||
if self.state != ExpertWeightUpdateState.WAITING:
|
||||
logger.error(
|
||||
"current d2d weight update tasks are on-going, cannot accept new weight update task"
|
||||
)
|
||||
return
|
||||
|
||||
# If neither send nor receive task is needed for this layer on this rank, return
|
||||
if not (expert_send_info or expert_recv_info):
|
||||
return
|
||||
|
||||
self.updated_expert_map = updated_expert_map
|
||||
|
||||
self.layer_id = layer_id
|
||||
self.comm_op_list = []
|
||||
for send_info in expert_send_info:
|
||||
dst_rank, global_expert_id_to_send = send_info
|
||||
local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[
|
||||
layer_id][global_expert_id_to_send].item()
|
||||
for src_tensor in self.eplb_adaptor.expert_param_per_layer[
|
||||
layer_id][local_expert_id]:
|
||||
self.comm_op_list.append(
|
||||
dist.P2POp(dist.isend, src_tensor, dst_rank))
|
||||
|
||||
buffer_tensor_id = 0
|
||||
for recv_info in expert_recv_info:
|
||||
recv_rank, global_expert_id_to_recv = recv_info
|
||||
for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[
|
||||
buffer_tensor_id]:
|
||||
self.comm_op_list.append(
|
||||
dist.P2POp(dist.irecv, buffer_tensor, recv_rank))
|
||||
local_expert_to_replace = self.updated_expert_map[
|
||||
global_expert_id_to_recv].item()
|
||||
self.recv_expert_list.append(
|
||||
(local_expert_to_replace, buffer_tensor_id))
|
||||
buffer_tensor_id += 1
|
||||
|
||||
self.state = ExpertWeightUpdateState.READY
|
||||
|
||||
def set_log2phy_map(self, log2phy_map):
|
||||
self.updated_log2phy_map = log2phy_map
|
||||
|
||||
def asyn_expert_weight_transfer(self, reqs):
|
||||
# Only when send/recv tasks are parsed into self.comm_op_list, d2d send/recv tasks can be luanched
|
||||
if self.state != ExpertWeightUpdateState.READY:
|
||||
return
|
||||
|
||||
# set asynchronous stream for d2d expert weight transfer
|
||||
if self.comm_op_list:
|
||||
ret_list = dist.batch_isend_irecv(self.comm_op_list)
|
||||
reqs.extend(ret_list)
|
||||
|
||||
self.state = ExpertWeightUpdateState.TRANSFERRING
|
||||
|
||||
def update_expert_map_and_weight(self, reqs):
|
||||
# Only after send/recv tasks have been luanched, expert_map and weight can be updated
|
||||
if self.state != ExpertWeightUpdateState.TRANSFERRING:
|
||||
return
|
||||
|
||||
# Waiting for send/recv tasks finish
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
if self.comm_op_list is not None:
|
||||
self.comm_op_list = None
|
||||
|
||||
# update expert_map
|
||||
self.eplb_adaptor.do_update_expert_map(self.layer_id,
|
||||
self.updated_expert_map)
|
||||
|
||||
# update log2phy_map
|
||||
self.eplb_adaptor.do_update_log2phy_map(self.layer_id,
|
||||
self.updated_log2phy_map)
|
||||
|
||||
# update expert weight
|
||||
buffer_tensor_id = 0
|
||||
for recv_expert_info in self.recv_expert_list:
|
||||
local_expert_to_replace, buffer_tensor_id = recv_expert_info
|
||||
self.eplb_adaptor.do_update_expert_weight(self.layer_id,
|
||||
local_expert_to_replace,
|
||||
buffer_tensor_id)
|
||||
|
||||
logger.info(
|
||||
f"[EPLB] finished update expert weight for layer: {self.layer_id}")
|
||||
|
||||
self.recv_expert_list = []
|
||||
self.updated_expert_map = None
|
||||
self.layer_id = -1
|
||||
self.state = ExpertWeightUpdateState.WAITING
|
||||
|
||||
def load_impl(self, old_expert_table, new_expert_table):
|
||||
raise NotImplementedError
|
||||
135
vllm_ascend/eplb/core/eplb_utils.py
Normal file
135
vllm_ascend/eplb/core/eplb_utils.py
Normal file
@@ -0,0 +1,135 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove eplb utils.
|
||||
import random
|
||||
|
||||
import torch
|
||||
from vllm.logger import logger
|
||||
|
||||
|
||||
def determine_default_expert_map(global_expert_num, world_size, rank_id,
|
||||
global_redundant_expert_num):
|
||||
if world_size == 1:
|
||||
local_ids = torch.arange(global_expert_num, dtype=torch.int32)
|
||||
return (global_expert_num, local_ids)
|
||||
|
||||
local_num_experts = global_expert_num // world_size
|
||||
|
||||
expert_map = torch.full((global_expert_num, ), -1, dtype=torch.int32)
|
||||
|
||||
if rank_id < world_size - 1:
|
||||
start = rank_id * local_num_experts
|
||||
end = (rank_id + 1) * local_num_experts
|
||||
local_count = local_num_experts
|
||||
else:
|
||||
start = rank_id * local_num_experts
|
||||
end = global_expert_num
|
||||
local_count = global_expert_num - rank_id * local_num_experts
|
||||
|
||||
if isinstance(global_redundant_expert_num,
|
||||
int) and rank_id < global_redundant_expert_num:
|
||||
local_count += 1
|
||||
if end < global_expert_num:
|
||||
end += 1
|
||||
else:
|
||||
start -= 1
|
||||
|
||||
if isinstance(local_count, int):
|
||||
local_ids = torch.arange(local_count, dtype=torch.int32)
|
||||
expert_map[start:end] = local_ids
|
||||
|
||||
return (local_count, expert_map)
|
||||
|
||||
|
||||
def generate_log2phy_map(expert_map):
|
||||
num_local_experts = expert_map.max() + 1
|
||||
log2phy_map = expert_map.clone()
|
||||
num_ranks, num_global_expert = log2phy_map.shape
|
||||
|
||||
row_indices = torch.arange(num_ranks).view(-1, 1).expand(num_ranks, \
|
||||
num_global_expert) * num_local_experts
|
||||
log2phy_map[log2phy_map != -1] += row_indices[log2phy_map != -1]
|
||||
|
||||
for idx in range(num_global_expert):
|
||||
positive_rank_idx = torch.where(log2phy_map[:, idx] != -1)[0]
|
||||
negative_rank_idx = torch.where(log2phy_map[:, idx] == -1)[0]
|
||||
num_rank_holding_expert = positive_rank_idx.size(0)
|
||||
|
||||
if num_rank_holding_expert == 0:
|
||||
log2phy_map[:, idx] = torch.full((num_ranks, ),
|
||||
0,
|
||||
dtype=log2phy_map.dtype)
|
||||
|
||||
if num_rank_holding_expert == 1:
|
||||
log2phy_map[negative_rank_idx, idx] = torch.full(
|
||||
(num_ranks - 1, ),
|
||||
log2phy_map[positive_rank_idx, idx].item(),
|
||||
dtype=log2phy_map.dtype)
|
||||
else:
|
||||
try:
|
||||
random_list = [
|
||||
random.choice(log2phy_map[positive_rank_idx, idx])
|
||||
for _ in range(num_ranks - num_rank_holding_expert)
|
||||
]
|
||||
log2phy_map[negative_rank_idx,
|
||||
idx] = torch.tensor(random_list,
|
||||
dtype=log2phy_map.dtype)
|
||||
except Exception as e:
|
||||
logger.error(f"Fail to get log2phy_map: {str(e)}")
|
||||
|
||||
return log2phy_map
|
||||
|
||||
|
||||
def determine_default_log2phy_map(global_expert_num, world_size, rank_id,
|
||||
global_redundant_expert_num):
|
||||
if world_size == 1:
|
||||
local_ids = torch.arange(global_expert_num, dtype=torch.int32)
|
||||
expert_map_all = local_ids.unsqueeze(0).expand(world_size, -1)
|
||||
log2phy_map_all = generate_log2phy_map(expert_map_all)
|
||||
return log2phy_map_all[rank_id]
|
||||
|
||||
local_num_experts = global_expert_num // world_size
|
||||
|
||||
expert_map_all = torch.full((world_size, global_expert_num),
|
||||
-1,
|
||||
dtype=torch.int32)
|
||||
|
||||
for r in range(world_size):
|
||||
if r < world_size - 1:
|
||||
start = r * local_num_experts
|
||||
end = (r + 1) * local_num_experts
|
||||
local_count = local_num_experts
|
||||
else:
|
||||
start = r * local_num_experts
|
||||
end = global_expert_num
|
||||
local_count = global_expert_num - r * local_num_experts
|
||||
|
||||
if isinstance(global_redundant_expert_num,
|
||||
int) and rank_id < global_redundant_expert_num:
|
||||
local_count += 1
|
||||
if end < global_expert_num:
|
||||
end += 1
|
||||
else:
|
||||
start -= 1
|
||||
|
||||
if isinstance(local_count, int):
|
||||
local_ids = torch.arange(local_count, dtype=torch.int32)
|
||||
expert_map_all[r, start:end] = local_ids
|
||||
|
||||
log2phy_map_all = generate_log2phy_map(expert_map_all)
|
||||
|
||||
return log2phy_map_all[rank_id]
|
||||
436
vllm_ascend/eplb/core/eplb_worker.py
Normal file
436
vllm_ascend/eplb/core/eplb_worker.py
Normal file
@@ -0,0 +1,436 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from multiprocessing import Process, Queue
|
||||
from typing import Any
|
||||
|
||||
import networkx as nx # type: ignore
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.eplb.core.eplb_utils import generate_log2phy_map
|
||||
from vllm_ascend.eplb.core.policy.policy_factory import (DynamicConfig,
|
||||
PolicyFactory)
|
||||
|
||||
|
||||
class EplbWorker:
|
||||
|
||||
def __init__(self, shared_dict, policy_type, enable_d2d: bool = True):
|
||||
self.policy_type = policy_type
|
||||
self.policy = PolicyFactory.generate_policy(policy_type,
|
||||
DynamicConfig())
|
||||
self.shared_dict = shared_dict
|
||||
self.old_expert_maps = None
|
||||
self.enable_d2d = enable_d2d
|
||||
self.rank_id = dist.get_rank()
|
||||
|
||||
def do_update(self):
|
||||
# put data in to queue
|
||||
# in process self.policy.generate_policy()
|
||||
# get epxert table && tensor
|
||||
|
||||
# async stream
|
||||
# D2D
|
||||
# H2D
|
||||
# Get initial expert_map
|
||||
torch.set_num_threads(1)
|
||||
if self.old_expert_maps is None:
|
||||
self.old_expert_maps = self.get_init_expert_maps()
|
||||
if self.old_expert_maps is not None:
|
||||
self.num_local_experts = self.old_expert_maps.max() + 1
|
||||
else:
|
||||
raise ValueError("Failed to get expert_maps from shared_dict.")
|
||||
|
||||
# Get MOE load information
|
||||
load_info = self.fetch_and_sum_load_info()
|
||||
if load_info is None:
|
||||
return
|
||||
|
||||
# Get the updated expert table based on the workload information
|
||||
old_placement = self.global2local(self.old_expert_maps,
|
||||
self.num_local_experts)
|
||||
_, _, new_placement = self.calculate_rebalance_experts(
|
||||
load_info, old_placement)
|
||||
|
||||
if not torch.is_tensor(new_placement):
|
||||
new_placement = torch.tensor(new_placement)
|
||||
self.check_expert_placement(old_placement, new_placement)
|
||||
new_expert_maps = self.local2global(new_placement)
|
||||
self.update_expert_map(new_expert_maps)
|
||||
|
||||
update_info = self.compose_expert_update_info_greedy(
|
||||
new_expert_maps, self.old_expert_maps)
|
||||
self.old_expert_maps = new_expert_maps
|
||||
logger.info("EPLB Process compute complete")
|
||||
|
||||
packed_update_info = self.pack_update_info(update_info)
|
||||
|
||||
return packed_update_info
|
||||
|
||||
def check_expert_placement(self, old_placement, new_placement):
|
||||
num_layers = old_placement.shape[0]
|
||||
num_ranks = old_placement.shape[1]
|
||||
|
||||
for layer_id in range(num_layers):
|
||||
# check if any logical expert is not placed on any rank
|
||||
if torch.unique(new_placement[layer_id]).numel() < torch.unique(
|
||||
old_placement[layer_id]).numel():
|
||||
logger.error(
|
||||
f"There exists expert not placed on any rank in layer {layer_id}"
|
||||
)
|
||||
new_placement[layer_id] = old_placement[layer_id]
|
||||
continue
|
||||
|
||||
for rank_id in range(num_ranks):
|
||||
new_placement_check = new_placement[layer_id][rank_id]
|
||||
old_placement_check = old_placement[layer_id][rank_id]
|
||||
|
||||
# check if same logical experts are placed on the same NPU
|
||||
if new_placement_check.numel() != torch.unique(
|
||||
new_placement_check).numel():
|
||||
logger.error(
|
||||
f"Replicated experts are placed on the same NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid"
|
||||
)
|
||||
new_placement[layer_id] = old_placement[layer_id]
|
||||
break
|
||||
|
||||
# check if there is any experts movement inside one NPU
|
||||
expert_not_move = torch.isin(new_placement_check,
|
||||
old_placement_check)
|
||||
if not torch.equal(new_placement_check[expert_not_move],
|
||||
old_placement_check[expert_not_move]):
|
||||
logger.error(
|
||||
f"There exists expert movement inside NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid"
|
||||
)
|
||||
new_placement[layer_id] = old_placement[layer_id]
|
||||
break
|
||||
|
||||
def compose_expert_update_info_bipartite(self, updated_expert_maps_org,
|
||||
current_expert_maps_org):
|
||||
# transform numpy array to torch tensor
|
||||
updated_expert_maps = updated_expert_maps_org.clone()
|
||||
current_expert_maps = current_expert_maps_org.clone()
|
||||
updated_expert_maps = np.array(updated_expert_maps)
|
||||
current_expert_maps = np.array(current_expert_maps)
|
||||
|
||||
num_layers = current_expert_maps.shape[0]
|
||||
|
||||
for layer_id in range(num_layers):
|
||||
updated_expert_maps_this_layer = updated_expert_maps[layer_id]
|
||||
current_expert_maps_this_layer = current_expert_maps[layer_id]
|
||||
updated_expert_maps_this_layer_org = updated_expert_maps_org[
|
||||
layer_id]
|
||||
|
||||
from typing import Any
|
||||
|
||||
expert_send_info_this_layer: dict[Any, Any] = {}
|
||||
expert_recv_info_this_layer: dict[Any, Any] = {}
|
||||
|
||||
# Guard Clause: if there is no expert weight update, avoid subsequent processing
|
||||
if (np.equal(updated_expert_maps_this_layer,
|
||||
current_expert_maps_this_layer)).all():
|
||||
yield (expert_send_info_this_layer,
|
||||
expert_recv_info_this_layer,
|
||||
updated_expert_maps_this_layer_org, layer_id)
|
||||
|
||||
# Parse expert_ids each rank needs to receive from other ranks
|
||||
dst_rank_indices, experts_to_recv = np.where(
|
||||
(current_expert_maps_this_layer == -1)
|
||||
& (updated_expert_maps_this_layer != -1))
|
||||
|
||||
# record src ranks for potential transfer
|
||||
src_ranks_set = dict()
|
||||
for idx in range(len(dst_rank_indices)):
|
||||
expert_id = experts_to_recv[idx].item()
|
||||
if expert_id not in src_ranks_set:
|
||||
src_ranks_set[expert_id] = np.where(
|
||||
current_expert_maps_this_layer[:, expert_id] != -1)[0]
|
||||
|
||||
# loop until all experts are scheduled
|
||||
while len(dst_rank_indices) > 0:
|
||||
# construct bipartite graph
|
||||
graph_expert_update: nx.Graph = nx.Graph()
|
||||
for idx in range(len(dst_rank_indices)):
|
||||
dst_rank_id = dst_rank_indices[idx].item()
|
||||
expert_id = experts_to_recv[idx].item()
|
||||
# add src ranks
|
||||
src_rank_ids = src_ranks_set[expert_id]
|
||||
graph_expert_update.add_nodes_from(src_rank_ids,
|
||||
bipartite=0)
|
||||
# add dest rank
|
||||
graph_expert_update.add_node(str(dst_rank_id), bipartite=1)
|
||||
# add edges
|
||||
for src_rank_id in src_rank_ids:
|
||||
graph_expert_update.add_edge(src_rank_id,
|
||||
str(dst_rank_id))
|
||||
|
||||
# graph may not be connected
|
||||
connected_components = list(
|
||||
nx.connected_components(graph_expert_update))
|
||||
all_matches = {}
|
||||
# matching in this loop
|
||||
for i, component in enumerate(connected_components):
|
||||
subgraph = graph_expert_update.subgraph(component)
|
||||
component_matching = nx.bipartite.maximum_matching(
|
||||
subgraph)
|
||||
all_matches.update(component_matching)
|
||||
|
||||
for src_rank, dst_rank in all_matches.items():
|
||||
dst_rank = int(dst_rank)
|
||||
assert src_rank != dst_rank
|
||||
if graph_expert_update.nodes[src_rank]['bipartite'] == 0:
|
||||
# currently not scheduled experts in rank dst_rank
|
||||
experts_v = experts_to_recv[np.where(
|
||||
dst_rank_indices == dst_rank)]
|
||||
# src: src_rank, dest: dst_rank, expert: expert_id
|
||||
expert_id = np.intersect1d(
|
||||
experts_v,
|
||||
np.where(current_expert_maps_this_layer[src_rank]
|
||||
!= -1))[0]
|
||||
|
||||
# record send/rcv pairs
|
||||
if src_rank not in expert_send_info_this_layer:
|
||||
expert_send_info_this_layer[src_rank] = []
|
||||
if dst_rank not in expert_recv_info_this_layer:
|
||||
expert_recv_info_this_layer[dst_rank] = []
|
||||
expert_send_info_this_layer[src_rank].append(
|
||||
(dst_rank, expert_id))
|
||||
expert_recv_info_this_layer[dst_rank].append(
|
||||
(src_rank, expert_id))
|
||||
|
||||
remove_index = np.where(
|
||||
np.logical_and(dst_rank_indices == dst_rank,
|
||||
experts_to_recv == expert_id))
|
||||
|
||||
# update
|
||||
dst_rank_indices = np.delete(dst_rank_indices,
|
||||
remove_index)
|
||||
experts_to_recv = np.delete(experts_to_recv,
|
||||
remove_index)
|
||||
|
||||
yield (expert_send_info_this_layer, expert_recv_info_this_layer,
|
||||
updated_expert_maps_this_layer_org, layer_id)
|
||||
|
||||
# TODO: Here only expert weight exchange is considered, need to be extended to cover other weight update cases
|
||||
def compose_expert_update_info_greedy(self, updated_expert_maps,
|
||||
current_expert_maps):
|
||||
num_layers = current_expert_maps.shape[0]
|
||||
for layer_id in range(num_layers):
|
||||
updated_expert_maps_this_layer = updated_expert_maps[layer_id]
|
||||
current_expert_maps_this_layer = current_expert_maps[layer_id]
|
||||
|
||||
expert_send_info_this_layer: dict[Any, Any] = {}
|
||||
expert_recv_info_this_layer: dict[Any, Any] = {}
|
||||
|
||||
# Guard Clause: if there is no expert weight update, avoid subsequent processing
|
||||
if torch.equal(updated_expert_maps_this_layer,
|
||||
current_expert_maps_this_layer):
|
||||
yield (expert_send_info_this_layer,
|
||||
expert_recv_info_this_layer,
|
||||
updated_expert_maps_this_layer, layer_id)
|
||||
|
||||
# Parse expert_ids each rank needs to receive from other ranks
|
||||
dst_rank_indices, experts_to_recv = torch.where((current_expert_maps_this_layer == -1) \
|
||||
& (updated_expert_maps_this_layer != -1))
|
||||
|
||||
# Parse expert_ids each rank needs to send to other ranks
|
||||
src_rank_indices, experts_to_send = torch.where((current_expert_maps_this_layer != -1) \
|
||||
& (updated_expert_maps_this_layer == -1))
|
||||
|
||||
for idx in range(len(dst_rank_indices)):
|
||||
dst_rank_id = dst_rank_indices[idx].item()
|
||||
expert_id = experts_to_recv[idx].item()
|
||||
if dst_rank_id not in expert_recv_info_this_layer:
|
||||
expert_recv_info_this_layer[dst_rank_id] = []
|
||||
|
||||
if not torch.isin(torch.tensor(expert_id),
|
||||
experts_to_send).any():
|
||||
# if expert_id are not sent out from any npu, it will be copied from one npu holding this expert
|
||||
candidate_src_rank_indices = torch.where(
|
||||
current_expert_maps_this_layer[:, expert_id] != -1)[0]
|
||||
else:
|
||||
candidate_src_rank_indices = src_rank_indices[
|
||||
experts_to_send == expert_id]
|
||||
|
||||
# TODO: improve selection criterion of npu sending expert_id considering such as intra-node or inter-node...
|
||||
src_rank_id = candidate_src_rank_indices[0].item()
|
||||
if src_rank_id not in expert_send_info_this_layer:
|
||||
expert_send_info_this_layer[src_rank_id] = []
|
||||
|
||||
expert_send_info_this_layer[src_rank_id].append(
|
||||
(dst_rank_id, expert_id))
|
||||
expert_recv_info_this_layer[dst_rank_id].append(
|
||||
(src_rank_id, expert_id))
|
||||
|
||||
yield (expert_send_info_this_layer, expert_recv_info_this_layer,
|
||||
updated_expert_maps_this_layer, layer_id)
|
||||
|
||||
def calculate_rebalance_experts(self, load_info, old_placement):
|
||||
"""
|
||||
Compute `new_map` by calling the `rebalance_experts` method of the policy instance.
|
||||
"""
|
||||
if self.old_expert_maps is None:
|
||||
return False, None, None
|
||||
|
||||
changed, priority, new_map = self.policy.rebalance_experts(
|
||||
old_placement, load_info)
|
||||
return changed, priority, new_map
|
||||
|
||||
def get_init_expert_maps(self):
|
||||
"""
|
||||
Read the initial expert_map from shared_dict.
|
||||
"""
|
||||
return self.shared_dict.get("expert_maps", None)
|
||||
|
||||
def fetch_and_sum_load_info(self):
|
||||
"""
|
||||
Each time the subprocess is awakened, read the latest moe_load
|
||||
(shape: [num_moe_layers, num_experts_per_layer]) from shared_dict.
|
||||
"""
|
||||
return self.shared_dict.get("moe_load", None)
|
||||
|
||||
def update_expert_map(self, expert_maps):
|
||||
|
||||
self.shared_dict["expert_maps"] = expert_maps
|
||||
|
||||
def global2local(self, placement: torch.Tensor,
|
||||
E_local: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
L, G, _ = placement.shape
|
||||
device = placement.device
|
||||
|
||||
pt_local = torch.full((L, G, E_local),
|
||||
fill_value=-1,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
valid = placement >= 0
|
||||
l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
|
||||
|
||||
slot_idx = placement[l_idx, g_idx, k_idx]
|
||||
|
||||
pt_local[l_idx, g_idx, slot_idx] = k_idx
|
||||
|
||||
return pt_local
|
||||
|
||||
def local2global(self, placement_local: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
L, G, E_local = placement_local.shape
|
||||
device = placement_local.device
|
||||
|
||||
max_id = torch.max(placement_local)
|
||||
E_global = (max_id + 1).item() if max_id >= 0 else 0
|
||||
|
||||
if E_global == 0:
|
||||
return torch.empty((L, G, 0), dtype=torch.long, device=device)
|
||||
|
||||
placement_global = torch.full((L, G, E_global),
|
||||
fill_value=-1,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
valid = placement_local >= 0
|
||||
l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True)
|
||||
gid_idx = placement_local[l_idx, g_idx, slot_idx]
|
||||
|
||||
placement_global[l_idx, g_idx, gid_idx] = slot_idx
|
||||
|
||||
return placement_global
|
||||
|
||||
def pack_update_info(self, update_info_generator):
|
||||
"""
|
||||
Pack a list of update info tuples for efficient IPC.
|
||||
"""
|
||||
send_all = []
|
||||
recv_all = []
|
||||
maps = []
|
||||
log2phy_all = []
|
||||
layer_ids = []
|
||||
|
||||
for send_info, recv_info, new_expert_map, layer_id in update_info_generator:
|
||||
|
||||
send_info_this_rank = send_info[
|
||||
self.rank_id] if self.rank_id in send_info else []
|
||||
recv_info_this_rank = recv_info[
|
||||
self.rank_id] if self.rank_id in recv_info else []
|
||||
send_all.append(send_info_this_rank)
|
||||
recv_all.append(recv_info_this_rank)
|
||||
|
||||
maps.append(new_expert_map[self.rank_id].numpy().tolist())
|
||||
|
||||
log2phy_map = generate_log2phy_map(new_expert_map)
|
||||
log2phy_all.append(log2phy_map[self.rank_id].numpy().tolist())
|
||||
|
||||
layer_ids.append(layer_id)
|
||||
|
||||
return list(zip(send_all, recv_all, maps, log2phy_all, layer_ids))
|
||||
|
||||
|
||||
class EplbProcess:
|
||||
|
||||
def __init__(self,
|
||||
shared_dict,
|
||||
policy_type: int = 0,
|
||||
enable_d2d: bool = True):
|
||||
"""
|
||||
Args:
|
||||
shared_dict: Cross-process shared dict returned by Manager().dict()
|
||||
policy_type: Integer passed to PolicyFactory.generate_policy
|
||||
enable_d2d: Whether to enable D2D loading
|
||||
"""
|
||||
self.shared_dict = shared_dict
|
||||
self.policy_type = policy_type
|
||||
self.enable_d2d = enable_d2d
|
||||
self.planner_q: Queue[Any] = Queue()
|
||||
self.block_update_q: Queue[Any] = Queue(maxsize=1)
|
||||
|
||||
# Create EplbWorker instance
|
||||
self.worker = EplbWorker(self.shared_dict, self.policy_type,
|
||||
self.enable_d2d)
|
||||
|
||||
def worker_process(self, planner_q, block_update_q):
|
||||
"""
|
||||
Subprocess entry: bind to specified NPU, loop waiting for planner_q to wake up, call do_update, then notify main process update is complete.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
planner_q.get()
|
||||
|
||||
packed_update_info = self.worker.do_update()
|
||||
|
||||
while True:
|
||||
if not block_update_q.empty():
|
||||
continue
|
||||
block_update_q.put(packed_update_info)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[EPLB subprocess Exiting due to error: {e}",
|
||||
exc_info=True)
|
||||
break
|
||||
|
||||
def _launch_process(self):
|
||||
"""
|
||||
Use spawn method to launch subprocess and return (planner_q, block_update_q, proc).
|
||||
"""
|
||||
proc = Process(target=self.worker_process,
|
||||
args=(self.planner_q, self.block_update_q),
|
||||
daemon=True)
|
||||
|
||||
proc.start()
|
||||
return proc
|
||||
0
vllm_ascend/eplb/core/policy/__init__.py
Normal file
0
vllm_ascend/eplb/core/policy/__init__.py
Normal file
42
vllm_ascend/eplb/core/policy/policy_abstract.py
Normal file
42
vllm_ascend/eplb/core/policy/policy_abstract.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
|
||||
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy.
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
class DynamicConfig:
|
||||
placement_policy = None
|
||||
|
||||
max_transferred_expert_per_layer = 100 # Maximum number of experts that can be migrated per layer on a single host
|
||||
ep_worldsize = 64 # Total number of dies across the entire cluster where experts are distributed
|
||||
num_die_per_host = 8 # Number of dies on each host machine
|
||||
|
||||
|
||||
class EplbPolicy:
|
||||
|
||||
def __init__(self, config: DynamicConfig):
|
||||
self.config = config
|
||||
|
||||
@abstractmethod
|
||||
def rebalance_experts(self, current_expert_table, expert_workload):
|
||||
"""
|
||||
Pass in the weights and return expert replication and placement under relevant constraints.
|
||||
INPUT:
|
||||
current_expert_table: [layerId, rankId, expert_num_i]
|
||||
expert_workload = expert_table[layer0][rankId][expert_num_i]
|
||||
|
||||
RETURNED: (res, expert_table)
|
||||
res:
|
||||
1 -- table_changed
|
||||
0 -- not_changed
|
||||
|
||||
expert_table: [layerId, rankId, expert_num_i]
|
||||
expert_num_i --- [0, MaxExpertPerRank]
|
||||
expertID = expert_table[layer0][rankId][expert_num_i]
|
||||
array_values:
|
||||
[0, 1, 2, 3, 248]
|
||||
[4, 5, 6, 7, 254]
|
||||
[8, 9, 10, 11, 71]
|
||||
...
|
||||
[252, 253, 254, 255, 0]
|
||||
"""
|
||||
pass
|
||||
389
vllm_ascend/eplb/core/policy/policy_dynamic_ep.py
Normal file
389
vllm_ascend/eplb/core/policy/policy_dynamic_ep.py
Normal file
@@ -0,0 +1,389 @@
|
||||
# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
|
||||
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy.
|
||||
from collections import defaultdict
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .policy_abstract import DynamicConfig, EplbPolicy
|
||||
|
||||
|
||||
class DynamicTable:
|
||||
# workload_table:
|
||||
# 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: workload (heat) at the corresponding position
|
||||
# Size: number of layers * number of GPUs * number of experts per GPU per layer
|
||||
# The element at (i, j, k) represents the workload (heat) of the k-th expert on the j-th GPU in the i-th layer
|
||||
# For experts that are not available or collected, the value is set to -1
|
||||
workload_table = None
|
||||
|
||||
# placement_table:
|
||||
# 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: physical expert ID at the corresponding position
|
||||
# Size: number of layers * number of GPUs * number of experts per GPU per layer
|
||||
# The element at (i, j, k) represents the physical expert ID of the k-th expert on the j-th GPU in the i-th layer
|
||||
# For experts that are not available or collected, the value is set to -1
|
||||
placement_table = None
|
||||
|
||||
|
||||
class DynamicEplb(EplbPolicy):
|
||||
|
||||
def __init__(self, config: DynamicConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@staticmethod
|
||||
def add_redundant(current_expert_table, expert_workload,
|
||||
num_original_expert):
|
||||
layer_num, npu_num, experts_per_npu = expert_workload.shape
|
||||
workload_new = np.zeros((layer_num, num_original_expert))
|
||||
for layer_idx in range(layer_num):
|
||||
workload_dict: dict[int, int] = defaultdict(int)
|
||||
placement_layer = current_expert_table[layer_idx].copy()
|
||||
workload_layer = expert_workload[layer_idx].copy()
|
||||
for npu_idx in range(npu_num):
|
||||
for expert_idx in range(experts_per_npu):
|
||||
workload_dict[placement_layer[npu_idx][
|
||||
expert_idx]] += workload_layer[npu_idx][expert_idx]
|
||||
for expert_idx in range(num_original_expert):
|
||||
workload_new[layer_idx][expert_idx] = workload_dict[expert_idx]
|
||||
return workload_new
|
||||
|
||||
@staticmethod
|
||||
# Split hot (high-load) experts into redundant experts
|
||||
def original_compute_balanced_pack_redundancy(origin_weights, card_num,
|
||||
num_redundancy_expert):
|
||||
# Step 1: Sort the items by weight in descending order (we are sorting by weight now)
|
||||
# Sort based on the second element (the second value of each tuple)
|
||||
route_expert_num = len(origin_weights)
|
||||
route_expert_redundancy: list[list[int]] = [
|
||||
[] for _ in range(route_expert_num)
|
||||
]
|
||||
for i in range(num_redundancy_expert):
|
||||
sorted_indices = np.argsort([t[1] for t in origin_weights],
|
||||
kind='stable')[::-1]
|
||||
weights = [origin_weights[idx] for idx in sorted_indices]
|
||||
tmp_raw_weight = weights[0][1] * (
|
||||
len(route_expert_redundancy[weights[0][0]]) + 1)
|
||||
route_expert_redundancy[weights[0][0]].append(route_expert_num + i)
|
||||
avg_weight = tmp_raw_weight / (
|
||||
len(route_expert_redundancy[weights[0][0]]) + 1)
|
||||
weights[0] = (weights[0][0], avg_weight)
|
||||
origin_weights = weights
|
||||
|
||||
# Step 2: Calculate the number of items per box
|
||||
expert_num = route_expert_num + num_redundancy_expert
|
||||
items_per_box = expert_num // card_num # Number of items per box
|
||||
remaining_items = expert_num % card_num # Number of items per box
|
||||
|
||||
# Step 3: Initialize card_num boxes with empty lists to store item IDs
|
||||
boxes: list[list[int]] = [[] for _ in range(card_num)]
|
||||
boxes_weights: list[list[float]] = [[] for _ in range(card_num)]
|
||||
box_weights = [0] * card_num # To store the total weight of each box
|
||||
box_counts = [0] * card_num # To store the number of items in each box
|
||||
index = 0
|
||||
for i in range(route_expert_num):
|
||||
redundancy_num = len(route_expert_redundancy[i])
|
||||
for _ in range(redundancy_num):
|
||||
cur_weight = 0
|
||||
for item, weight in origin_weights:
|
||||
if item == i:
|
||||
cur_weight = weight
|
||||
|
||||
boxes[index].append(i)
|
||||
boxes_weights[index].append(cur_weight)
|
||||
box_weights[index] += cur_weight
|
||||
box_counts[index] += 1
|
||||
index += 1
|
||||
|
||||
sorted_indices = np.argsort([t[1] for t in origin_weights],
|
||||
kind='stable')[::-1]
|
||||
origin_weights = [origin_weights[idx] for idx in sorted_indices]
|
||||
# Step 4: Distribute items into boxes based on weight
|
||||
for item_id, weight in origin_weights:
|
||||
# Find the box with the least items but not full
|
||||
min_box_index = -1
|
||||
for i in range(card_num):
|
||||
if item_id in boxes[i]:
|
||||
continue
|
||||
# Only choose boxes that still have space (box_counts[i] < items_per_box)
|
||||
if box_counts[i] < items_per_box or (box_counts[i]
|
||||
== items_per_box
|
||||
and remaining_items > 0):
|
||||
if min_box_index == -1 or box_weights[i] < box_weights[
|
||||
min_box_index]:
|
||||
min_box_index = i
|
||||
|
||||
# Place the item (id) into the selected box
|
||||
boxes[min_box_index].append(item_id)
|
||||
boxes_weights[min_box_index].append(weight)
|
||||
box_weights[min_box_index] += weight
|
||||
box_counts[min_box_index] += 1
|
||||
|
||||
# If there's an imbalance in the remaining items, reduce the "remaining_items" counter
|
||||
if box_counts[min_box_index] == (items_per_box +
|
||||
1) and remaining_items > 0:
|
||||
remaining_items -= 1
|
||||
|
||||
# Step 5: Output each box's contents and total weight
|
||||
result = []
|
||||
for i in range(card_num):
|
||||
result.append({
|
||||
"box_index": i + 1,
|
||||
"items": boxes[i], # List of item IDs in the box
|
||||
"weight": boxes_weights[i],
|
||||
"total_weight": box_weights[i], # Total weight in this box
|
||||
"item_count": box_counts[i] # Number of items in the box
|
||||
})
|
||||
|
||||
return result, boxes
|
||||
|
||||
# Split hot (high-load) experts into redundant experts
|
||||
@staticmethod
|
||||
def compute_balanced_pack_redundancy(origin_weights, card_num,
|
||||
num_redundancy_expert):
|
||||
route_expert_num = len(origin_weights)
|
||||
route_expert_redundancy: list[list[int]] = [
|
||||
[] for _ in range(route_expert_num)
|
||||
]
|
||||
for i in range(num_redundancy_expert):
|
||||
sorted_indices = np.argsort([t[1] for t in origin_weights],
|
||||
kind='stable')[::-1]
|
||||
weights = [origin_weights[idx] for idx in sorted_indices]
|
||||
tmp_raw_weight = weights[0][1] * (
|
||||
len(route_expert_redundancy[weights[0][0]]) + 1)
|
||||
route_expert_redundancy[weights[0][0]].append(route_expert_num + i)
|
||||
avg_weight = tmp_raw_weight / (
|
||||
len(route_expert_redundancy[weights[0][0]]) + 1)
|
||||
weights[0] = (weights[0][0], avg_weight)
|
||||
origin_weights = weights
|
||||
|
||||
expert_num = route_expert_num + num_redundancy_expert
|
||||
if card_num == 0:
|
||||
raise RuntimeError("card_num can not be 0.")
|
||||
items_per_box = expert_num // card_num
|
||||
remaining_items = expert_num % card_num
|
||||
|
||||
boxes: list[list[int]] = [[] for _ in range(card_num)]
|
||||
boxes_weights: list[list[float]] = [[] for _ in range(card_num)]
|
||||
box_weights = [0] * card_num
|
||||
box_counts = [0] * card_num
|
||||
|
||||
all_weights = np.zeros((expert_num, ), dtype='object')
|
||||
all_weights[:route_expert_num] = origin_weights
|
||||
|
||||
index = route_expert_num
|
||||
for i in range(route_expert_num):
|
||||
redundancy_num = len(route_expert_redundancy[i])
|
||||
for _ in range(redundancy_num):
|
||||
for item, weight in origin_weights:
|
||||
if item == i:
|
||||
all_weights[index] = (item, weight)
|
||||
index += 1
|
||||
|
||||
sorted_indices = np.argsort([t[1] for t in all_weights],
|
||||
kind='stable')[::-1]
|
||||
all_weights = [all_weights[idx] for idx in sorted_indices]
|
||||
for item_id, weight in all_weights:
|
||||
min_box_index = -1
|
||||
for i in range(card_num):
|
||||
if box_counts[i] < items_per_box or (box_counts[i]
|
||||
== items_per_box
|
||||
and remaining_items > 0):
|
||||
if min_box_index == -1 or box_weights[i] < box_weights[
|
||||
min_box_index]:
|
||||
if item_id not in boxes[i]:
|
||||
min_box_index = i
|
||||
|
||||
boxes[min_box_index].append(item_id)
|
||||
boxes_weights[min_box_index].append(weight)
|
||||
box_weights[min_box_index] += weight
|
||||
box_counts[min_box_index] += 1
|
||||
|
||||
if box_counts[min_box_index] == (items_per_box +
|
||||
1) and remaining_items > 0:
|
||||
remaining_items -= 1
|
||||
|
||||
result = []
|
||||
for i in range(card_num):
|
||||
result.append({
|
||||
"box_index": i + 1,
|
||||
"items": boxes[i],
|
||||
"weight": boxes_weights[i],
|
||||
"total_weight": box_weights[i],
|
||||
"item_count": box_counts[i]
|
||||
})
|
||||
|
||||
return result, boxes
|
||||
|
||||
# Scheme without redundant experts
|
||||
@staticmethod
|
||||
def compute_balanced_pack(origin_weights, card_num):
|
||||
sorted_indices = np.argsort([t[1] for t in origin_weights])[::-1]
|
||||
weights = origin_weights[sorted_indices]
|
||||
expert_num = len(weights)
|
||||
if card_num == 0:
|
||||
raise RuntimeError("card_num can not be 0.")
|
||||
items_per_box = expert_num // card_num
|
||||
remaining_items = expert_num % card_num
|
||||
|
||||
boxes: list[list[int]] = [[] for _ in range(card_num)]
|
||||
boxes_weights: list[list[float]] = [[] for _ in range(card_num)]
|
||||
box_weights = [0] * card_num
|
||||
box_counts = [0] * card_num
|
||||
|
||||
for item_id, weight in weights:
|
||||
min_box_index = -1
|
||||
for i in range(card_num):
|
||||
if box_counts[i] < items_per_box or (box_counts[i]
|
||||
== items_per_box
|
||||
and remaining_items > 0):
|
||||
if min_box_index == -1 or box_weights[i] < box_weights[
|
||||
min_box_index]:
|
||||
min_box_index = i
|
||||
|
||||
boxes[min_box_index].append(item_id)
|
||||
boxes_weights[min_box_index].append(weight)
|
||||
box_weights[min_box_index] += weight
|
||||
box_counts[min_box_index] += 1
|
||||
|
||||
if box_counts[min_box_index] == (items_per_box +
|
||||
1) and remaining_items > 0:
|
||||
remaining_items -= 1
|
||||
|
||||
result = []
|
||||
for i in range(card_num):
|
||||
result.append({
|
||||
"box_index": i + 1,
|
||||
"items": boxes[i],
|
||||
"weight": boxes_weights[i],
|
||||
"total_weight": box_weights[i],
|
||||
"item_count": box_counts[i]
|
||||
})
|
||||
|
||||
return result, boxes
|
||||
|
||||
@staticmethod
|
||||
def get_redundant_num(npu_num, counts):
|
||||
redundant_num_each_npu: int = np.sum(counts - 1)
|
||||
return redundant_num_each_npu
|
||||
|
||||
@staticmethod
|
||||
def calculate_max_heat_per_layer(workload_table, layer_num):
|
||||
max_heat_per_layer: list[float] = []
|
||||
for layer_idx in range(layer_num):
|
||||
npu_heats_now = np.sum(workload_table[layer_idx], axis=1)
|
||||
max_heat_per_layer.append(np.max(npu_heats_now))
|
||||
return max_heat_per_layer
|
||||
|
||||
@staticmethod
|
||||
def constraint_expert_local_exchange(current_expert_table,
|
||||
global_deployment):
|
||||
for layer_id in range(len(global_deployment)):
|
||||
for card_id in range(len(global_deployment[layer_id])):
|
||||
current_list = [
|
||||
int(x) for x in current_expert_table[layer_id][card_id]
|
||||
]
|
||||
new_list = [
|
||||
int(x) for x in global_deployment[layer_id][card_id]
|
||||
]
|
||||
num = len(new_list)
|
||||
|
||||
new_index = [-1] * num
|
||||
new_result = [-1] * num
|
||||
remaining_elements = []
|
||||
|
||||
for i in range(num):
|
||||
flag = True
|
||||
for j in range(num):
|
||||
if new_list[i] == current_list[j] and new_index[
|
||||
j] == -1:
|
||||
new_index[j] = 0
|
||||
new_result[j] = current_list[j]
|
||||
flag = False
|
||||
break
|
||||
if flag:
|
||||
remaining_elements.append(new_list[i])
|
||||
|
||||
index = 0
|
||||
for k in range(num):
|
||||
if new_result[k] == -1:
|
||||
new_result[k] = remaining_elements[index]
|
||||
index += 1
|
||||
|
||||
global_deployment[layer_id][card_id] = new_result
|
||||
|
||||
return global_deployment
|
||||
|
||||
def rebalance_experts(self, current_expert_table, expert_workload):
|
||||
|
||||
info = DynamicTable()
|
||||
info.workload_table = np.array(expert_workload)
|
||||
info.placement_table = np.array(current_expert_table)
|
||||
assert info.workload_table is not None
|
||||
layer_num, num_npus, experts_per_npu = info.workload_table.shape
|
||||
assert info.placement_table is not None
|
||||
row = cast(np.ndarray, info.placement_table[0])
|
||||
expert_ids, counts = np.unique(row, return_counts=True)
|
||||
num_redundancy_expert = self.get_redundant_num(num_npus, counts)
|
||||
num_original_expert = len(expert_ids)
|
||||
layer_workloads = self.add_redundant(info.placement_table,
|
||||
info.workload_table,
|
||||
num_original_expert)
|
||||
max_heat_per_layer_before = self.calculate_max_heat_per_layer(
|
||||
info.workload_table, layer_num)
|
||||
npu_heat_all_origin = sum(max_heat_per_layer_before)
|
||||
|
||||
# Perform load balancing and deploy redundant experts
|
||||
layer_num = layer_workloads.shape[0]
|
||||
expert_num = layer_workloads.shape[1]
|
||||
# Validate that the number of experts, number of cards, and number of redundant experts do not exceed the number of cards
|
||||
if num_original_expert != expert_num:
|
||||
raise ValueError(
|
||||
f"the number of original experts {num_original_expert} must be equal to expert_num {expert_num}"
|
||||
)
|
||||
|
||||
if num_npus <= 0:
|
||||
raise ValueError("the number of NPUs must be greater than 0")
|
||||
|
||||
if num_npus < num_redundancy_expert:
|
||||
raise ValueError(
|
||||
f"the number of NPUs {num_npus} must be greater than or equal to the number of redundant experts {num_redundancy_expert}"
|
||||
)
|
||||
|
||||
# Number of experts deployed on each card includes one redundant expert
|
||||
global_deployment: list[list[list[int]]] = [[[]
|
||||
for _ in range(num_npus)]
|
||||
for _ in range(layer_num)]
|
||||
# Iterate to obtain the placement strategy for each layer, taking computational balance into account
|
||||
max_heat_per_layer_after = np.zeros([layer_num])
|
||||
for layer in range(layer_num):
|
||||
# Get the expert IDs and their corresponding workloads for the current layer;
|
||||
# workloads need to be normalized, and one redundant expert is added per card
|
||||
weights = np.zeros((expert_num, ), dtype='object')
|
||||
for expert_id, workload_weight in enumerate(
|
||||
layer_workloads[layer]):
|
||||
weights[expert_id] = (expert_id, workload_weight)
|
||||
|
||||
# Obtain the globally balanced placement strategy for each layer
|
||||
result, layer_deployment = self.original_compute_balanced_pack_redundancy(
|
||||
weights, num_npus, num_redundancy_expert)
|
||||
|
||||
global_deployment[layer] = layer_deployment
|
||||
max_heat_per_layer_after[layer] = max(
|
||||
result, key=lambda x: x['total_weight'])['total_weight']
|
||||
|
||||
new_global_deployment = self.constraint_expert_local_exchange(
|
||||
current_expert_table, global_deployment)
|
||||
# Obtain the priority of each layer
|
||||
layer_changed_ratio = []
|
||||
for layer_idx in range(layer_num):
|
||||
layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] /
|
||||
max_heat_per_layer_before[layer_idx])
|
||||
|
||||
per_layer_priority = np.argsort(layer_changed_ratio)
|
||||
npu_heat_all_after = sum(max_heat_per_layer_after)
|
||||
|
||||
change = 0
|
||||
if npu_heat_all_after < 0.95 * npu_heat_all_origin:
|
||||
change = 1
|
||||
|
||||
return change, per_layer_priority, np.array(
|
||||
new_global_deployment).tolist()
|
||||
771
vllm_ascend/eplb/core/policy/policy_dynamic_ep_v2.py
Normal file
771
vllm_ascend/eplb/core/policy/policy_dynamic_ep_v2.py
Normal file
@@ -0,0 +1,771 @@
|
||||
# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
|
||||
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy.
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DynamicConfig:
|
||||
placement_policy = None
|
||||
|
||||
max_transferred_expert_per_layer = 100 # Maximum number of experts that can be migrated per layer on a single host
|
||||
ep_worldsize = 64 # Total number of dies across the entire cluster where experts are distributed
|
||||
num_die_per_host = 8 # Number of dies on each host machine
|
||||
|
||||
|
||||
class EplbPolicy:
|
||||
|
||||
def __init__(self, config: DynamicConfig):
|
||||
self.config = config
|
||||
|
||||
@abstractmethod
|
||||
def rebalance_experts(self, current_expert_table, expert_workload):
|
||||
"""
|
||||
Pass in the weights and return expert replication and placement under relevant constraints.
|
||||
INPUT:
|
||||
current_expert_table: [layerId, rankId, expert_num_i]
|
||||
expert_workload = expert_table[layer0][rankId][expert_num_i]
|
||||
|
||||
RETURNED: (res, expert_table)
|
||||
res:
|
||||
1 -- table_changed
|
||||
0 -- not_changed
|
||||
|
||||
expert_table: [layerId, rankId, expert_num_i]
|
||||
expert_num_i --- [0, MaxExpertPerRank]
|
||||
expertID = expert_table[layer0][rankId][expert_num_i]
|
||||
array_values:
|
||||
[0, 1, 2, 3, 248]
|
||||
[4, 5, 6, 7, 254]
|
||||
[8, 9, 10, 11, 71]
|
||||
...
|
||||
[252, 253, 254, 255, 0]
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DynamicTable:
|
||||
# workload_table:
|
||||
# 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: workload (heat) at the corresponding position
|
||||
# Size: number of layers * number of GPUs * number of experts per GPU per layer
|
||||
# The element at (i, j, k) represents the workload (heat) of the k-th expert on the j-th GPU in the i-th layer
|
||||
# For experts that are not available or collected, the value is set to -1
|
||||
workload_table = None
|
||||
|
||||
# placement_table:
|
||||
# 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: physical expert ID at the corresponding position
|
||||
# Size: number of layers * number of GPUs * number of experts per GPU per layer
|
||||
# The element at (i, j, k) represents the physical expert ID of the k-th expert on the j-th GPU in the i-th layer
|
||||
# For experts that are not available or collected, the value is set to -1
|
||||
placement_table = None
|
||||
|
||||
|
||||
class DynamicEplbV2(EplbPolicy):
|
||||
|
||||
def __init__(self, config: DynamicConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@staticmethod
|
||||
def safe_divide(a, b):
|
||||
if b == 0:
|
||||
print("Division by zero is not allowed")
|
||||
return 0
|
||||
return a / b
|
||||
|
||||
@staticmethod
|
||||
def safe_exact_divide(a, b):
|
||||
if b == 0:
|
||||
print("Division by zero is not allowed")
|
||||
return 0
|
||||
return a // b
|
||||
|
||||
@staticmethod
|
||||
def safe_mod(a, b):
|
||||
if b == 0:
|
||||
print("Division by zero is not allowed")
|
||||
return 0
|
||||
return a % b
|
||||
|
||||
@staticmethod
|
||||
def add_redundant(current_expert_table, expert_workload,
|
||||
num_original_expert):
|
||||
layer_num, npu_num, experts_per_npu = expert_workload.shape
|
||||
workload_new = np.zeros((layer_num, num_original_expert))
|
||||
for layer_idx in range(layer_num):
|
||||
workload_dict: dict[int, int] = defaultdict(int)
|
||||
placement_layer = current_expert_table[layer_idx].copy()
|
||||
workload_layer = expert_workload[layer_idx].copy()
|
||||
for npu_idx in range(npu_num):
|
||||
for expert_idx in range(experts_per_npu):
|
||||
workload_dict[placement_layer[npu_idx][
|
||||
expert_idx]] += workload_layer[npu_idx][expert_idx]
|
||||
for expert_idx in range(num_original_expert):
|
||||
workload_new[layer_idx][expert_idx] = workload_dict[expert_idx]
|
||||
return workload_new
|
||||
|
||||
@staticmethod
|
||||
def get_redundant_num(npu_num, counts):
|
||||
redundant_num_each_npu: int = int(np.sum(counts - 1))
|
||||
return redundant_num_each_npu
|
||||
|
||||
@staticmethod
|
||||
def calculate_max_heat_per_layer(workload_table, layer_num):
|
||||
max_heat_per_layer: list[float] = []
|
||||
for layer_idx in range(layer_num):
|
||||
npu_heats_now = np.sum(workload_table[layer_idx], axis=1)
|
||||
max_heat_per_layer.append(np.max(npu_heats_now))
|
||||
return max_heat_per_layer
|
||||
|
||||
def calculate_initial_imbalance(self, global_deployment,
|
||||
new_layer_workloads):
|
||||
|
||||
device_num = global_deployment.shape[1]
|
||||
layer_imbalance = []
|
||||
expert_num = np.zeros_like(new_layer_workloads)
|
||||
for layer_id, layer in enumerate(global_deployment):
|
||||
for device in layer:
|
||||
for expert_id in device:
|
||||
expert_num[layer_id][expert_id] += 1
|
||||
|
||||
for layer_id, layer in enumerate(global_deployment):
|
||||
cur_layer_max_workload = 0
|
||||
total_workload = 0
|
||||
for box in layer:
|
||||
box_workload = 0
|
||||
for expert_id in box:
|
||||
update_workload = self.safe_divide(
|
||||
new_layer_workloads[layer_id][expert_id],
|
||||
expert_num[layer_id][expert_id])
|
||||
box_workload += update_workload
|
||||
total_workload += update_workload
|
||||
if cur_layer_max_workload < box_workload:
|
||||
cur_layer_max_workload = box_workload
|
||||
|
||||
cur_layer_imbalance = self.safe_divide(
|
||||
cur_layer_max_workload,
|
||||
(self.safe_divide(total_workload, device_num)))
|
||||
layer_imbalance.append(cur_layer_imbalance)
|
||||
|
||||
return layer_imbalance
|
||||
|
||||
def compute_redundant_assignments(self, base_experts,
|
||||
num_redundant_experts, num_experts):
|
||||
|
||||
redundant_assignments: list[list[int]] = [[]
|
||||
for _ in range(num_experts)]
|
||||
current_weights = base_experts.copy()
|
||||
|
||||
for i in range(num_redundant_experts):
|
||||
sorted_indices = np.argsort([w for _, w in current_weights],
|
||||
kind='stable')[::-1]
|
||||
sorted_weights = [current_weights[i] for i in sorted_indices]
|
||||
|
||||
target_expert = sorted_weights[0]
|
||||
expert_id, original_weight = target_expert
|
||||
|
||||
current_redundancy = len(redundant_assignments[expert_id])
|
||||
new_avg_weight = self.safe_divide(
|
||||
original_weight * (current_redundancy + 1),
|
||||
(current_redundancy + 2))
|
||||
|
||||
redundant_assignments[expert_id].append(num_experts + i)
|
||||
current_weights[sorted_indices[0]] = (expert_id, new_avg_weight)
|
||||
|
||||
sorted_indices = np.argsort([w for _, w in current_weights],
|
||||
kind='stable')[::-1]
|
||||
sorted_weights = [current_weights[i] for i in sorted_indices]
|
||||
|
||||
return redundant_assignments, sorted_weights
|
||||
|
||||
def repeat_compute_redundant_assignments(self, layer_workloads, rendun_pos,
|
||||
num_experts, num_exist_expert,
|
||||
device_assignments, device_counts,
|
||||
expert_from_device,
|
||||
com_between_devices):
|
||||
|
||||
current_weights = np.zeros((num_experts, ), dtype='object')
|
||||
for expert_id, workload_weight in enumerate(layer_workloads):
|
||||
current_weights[expert_id] = (expert_id, workload_weight)
|
||||
|
||||
devices_with_slots = []
|
||||
for device_id, device_rendun_pos in enumerate(rendun_pos):
|
||||
if len(device_rendun_pos) != 0:
|
||||
devices_with_slots.append(device_id)
|
||||
|
||||
while devices_with_slots:
|
||||
sorted_indices = np.argsort([w for _, w in current_weights],
|
||||
kind='stable')[::-1]
|
||||
sorted_weights = [current_weights[i] for i in sorted_indices]
|
||||
|
||||
for index, target_weight in enumerate(sorted_weights):
|
||||
expert_id, original_weight = target_weight
|
||||
if original_weight == -1:
|
||||
print("Error:Redundant expert failure re-occurred")
|
||||
redundancy_successful = True
|
||||
break
|
||||
redundancy_successful = False
|
||||
for cur_device_id in devices_with_slots:
|
||||
if expert_id not in device_assignments[cur_device_id]:
|
||||
pos = rendun_pos[cur_device_id].pop()
|
||||
if len(rendun_pos[cur_device_id]) == 0:
|
||||
devices_with_slots = [
|
||||
device_id for device_id in devices_with_slots
|
||||
if device_id != cur_device_id
|
||||
]
|
||||
device_assignments[cur_device_id][pos] = expert_id
|
||||
device_counts[cur_device_id] += 1
|
||||
communication_box_index = expert_from_device[expert_id]
|
||||
com_between_devices[cur_device_id][
|
||||
communication_box_index] = expert_id
|
||||
new_weight = self.safe_divide(
|
||||
(original_weight * num_exist_expert[expert_id]),
|
||||
(num_exist_expert[expert_id] + 1))
|
||||
sorted_weights[index] = (expert_id, new_weight)
|
||||
num_exist_expert[expert_id] += 1
|
||||
redundancy_successful = True
|
||||
break
|
||||
if redundancy_successful:
|
||||
break
|
||||
|
||||
sorted_indices = np.argsort([id for id, _ in sorted_weights],
|
||||
kind='stable')
|
||||
sorted_weights = [sorted_weights[i][1] for i in sorted_indices]
|
||||
|
||||
return sorted_weights, device_assignments, device_counts, com_between_devices
|
||||
|
||||
@staticmethod
|
||||
def prepare_expert_list(base_experts, redundant_assignments,
|
||||
num_redundant_experts):
|
||||
redundant_expert_list = np.empty(num_redundant_experts, dtype=object)
|
||||
|
||||
index = 0
|
||||
num_experts = len(redundant_assignments)
|
||||
for expert_id in range(num_experts):
|
||||
for _ in redundant_assignments[expert_id]:
|
||||
redundant_expert_list[index] = (expert_id,
|
||||
next(w
|
||||
for eid, w in base_experts
|
||||
if eid == expert_id))
|
||||
index += 1
|
||||
|
||||
sorted_indices = np.argsort([w for _, w in redundant_expert_list],
|
||||
kind='stable')[::-1]
|
||||
return [redundant_expert_list[i] for i in sorted_indices]
|
||||
|
||||
@staticmethod
|
||||
def non_redundant_expert_information(origin_deployment, updated_weights,
|
||||
rendun_pos):
|
||||
|
||||
device_num = len(origin_deployment)
|
||||
num_experts_per_device = origin_deployment.shape[1]
|
||||
device_assignments = [[-1 for _ in range(num_experts_per_device)]
|
||||
for _ in range(device_num)]
|
||||
device_weights = [[0 for _ in range(num_experts_per_device)]
|
||||
for _ in range(device_num)]
|
||||
device_loads = [0] * device_num
|
||||
device_counts = [0] * device_num
|
||||
|
||||
for device_id, device in enumerate(origin_deployment):
|
||||
for index, expert_id in enumerate(device):
|
||||
if index in rendun_pos[device_id]:
|
||||
continue
|
||||
device_assignments[device_id][index] = expert_id
|
||||
cur_weight = next(
|
||||
weight for expert_id_of_weight, weight in updated_weights
|
||||
if expert_id_of_weight == expert_id)
|
||||
device_weights[device_id][index] = cur_weight
|
||||
device_loads[device_id] += cur_weight
|
||||
device_counts[device_id] += 1
|
||||
|
||||
return device_assignments, device_weights, device_loads, device_counts
|
||||
|
||||
def recomputing_initial_weight(self, layer_workloads, device_assignments):
|
||||
num_all_experts = [0] * len(layer_workloads)
|
||||
for device in device_assignments:
|
||||
for expert_id in device:
|
||||
if expert_id != -1:
|
||||
num_all_experts[expert_id] += 1
|
||||
|
||||
cur_layer_workload = []
|
||||
for expert_id, weight in enumerate(layer_workloads):
|
||||
if num_all_experts[expert_id] == 0:
|
||||
cur_layer_workload.append(-1)
|
||||
else:
|
||||
cur_layer_workload.append(
|
||||
self.safe_divide(weight, num_all_experts[expert_id]))
|
||||
|
||||
return cur_layer_workload, num_all_experts
|
||||
|
||||
def distribute_redun_experts(self, layer_workloads, device_assignments,
|
||||
device_weights, device_loads, device_counts,
|
||||
redundant_expert_list, expert_from_device,
|
||||
num_experts, rendun_pos):
|
||||
|
||||
num_devices = len(device_assignments)
|
||||
com_between_devices: list[dict[int,
|
||||
int]] = [{} for _ in range(num_devices)]
|
||||
|
||||
for expert_id, weight in redundant_expert_list:
|
||||
candidate = -1
|
||||
for dev_id in range(num_devices):
|
||||
if len(rendun_pos[dev_id]) == 0:
|
||||
continue
|
||||
if expert_id in device_assignments[dev_id]:
|
||||
continue
|
||||
if candidate == -1 or device_loads[dev_id] < device_loads[
|
||||
candidate]:
|
||||
candidate = dev_id
|
||||
if candidate != -1:
|
||||
pos = rendun_pos[candidate].pop()
|
||||
device_assignments[candidate][pos] = expert_id
|
||||
device_weights[candidate][pos] = weight
|
||||
device_loads[candidate] += weight
|
||||
device_counts[candidate] += 1
|
||||
|
||||
communication_box_index = expert_from_device[expert_id]
|
||||
com_between_devices[candidate][
|
||||
communication_box_index] = expert_id
|
||||
|
||||
if any(sublist for sublist in rendun_pos):
|
||||
cur_layer_workload, num_exist_expert = self.recomputing_initial_weight(
|
||||
layer_workloads, device_assignments)
|
||||
|
||||
update_workload, device_assignments, device_counts, com_between_devices = self.repeat_compute_redundant_assignments(
|
||||
cur_layer_workload, rendun_pos, num_experts, num_exist_expert,
|
||||
device_assignments, device_loads, expert_from_device,
|
||||
com_between_devices)
|
||||
|
||||
device_loads = [0] * len(device_counts)
|
||||
for device_id, device in enumerate(device_assignments):
|
||||
for index, expert_id in enumerate(device):
|
||||
device_weights[device_id][index] = update_workload[
|
||||
expert_id]
|
||||
device_loads[device_id] += update_workload[expert_id]
|
||||
|
||||
return device_assignments, device_weights, device_loads, device_counts, com_between_devices
|
||||
|
||||
def redundancy_again(self, layer_workloads, origin_weights,
|
||||
origin_deployment, expert_from_device, num_node,
|
||||
is_node_redundant, rendun_pos):
|
||||
|
||||
num_experts = len(origin_weights)
|
||||
if is_node_redundant:
|
||||
num_experts = num_experts * num_node
|
||||
|
||||
num_redundant_experts = 0
|
||||
for rank_empty_pos in rendun_pos:
|
||||
num_redundant_experts += len(rank_empty_pos)
|
||||
|
||||
redundant_assignments, updated_weights = self.compute_redundant_assignments(
|
||||
origin_weights, num_redundant_experts, num_experts)
|
||||
|
||||
redundant_expert_list = self.prepare_expert_list(
|
||||
updated_weights, redundant_assignments, num_redundant_experts)
|
||||
|
||||
device_assignments, device_weights, device_loads, device_counts = self.non_redundant_expert_information(
|
||||
origin_deployment, updated_weights, rendun_pos)
|
||||
|
||||
device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.distribute_redun_experts(
|
||||
layer_workloads, device_assignments, device_weights, device_loads,
|
||||
device_counts, redundant_expert_list, expert_from_device,
|
||||
num_experts, rendun_pos)
|
||||
|
||||
return device_assignments, device_weights, device_loads, device_counts, com_between_devices
|
||||
|
||||
@staticmethod
|
||||
def generate_allocation_report(device_assignments, device_weights,
|
||||
device_loads, device_counts):
|
||||
|
||||
report = []
|
||||
max_load = 0.0
|
||||
|
||||
for dev_id in range(len(device_assignments)):
|
||||
current_load = device_loads[dev_id]
|
||||
max_load = max(max_load, current_load)
|
||||
|
||||
report.append({
|
||||
"device_id": dev_id + 1,
|
||||
"assigned_experts": device_assignments[dev_id],
|
||||
"expert_weights": device_weights[dev_id],
|
||||
"total_load": current_load,
|
||||
"expert_count": device_counts[dev_id]
|
||||
})
|
||||
|
||||
return report, max_load
|
||||
|
||||
@staticmethod
|
||||
def exchange_expert(cur_exchange_index, next_exchange_index, cur_device_id,
|
||||
next_device_id, cur_layer_result, com_between_devices):
|
||||
|
||||
cur_device_deployment = cur_layer_result[cur_device_id][
|
||||
'assigned_experts']
|
||||
next_device_deployment = cur_layer_result[next_device_id][
|
||||
'assigned_experts']
|
||||
|
||||
cur_device_weight = cur_layer_result[cur_device_id]['expert_weights']
|
||||
next_device_weight = cur_layer_result[next_device_id]['expert_weights']
|
||||
|
||||
cur_expert_id = cur_device_deployment[cur_exchange_index]
|
||||
next_expert_id = next_device_deployment[next_exchange_index]
|
||||
cur_device_deployment[cur_exchange_index] = next_expert_id
|
||||
next_device_deployment[next_exchange_index] = cur_expert_id
|
||||
|
||||
cur_expert_weight = cur_device_weight[cur_exchange_index]
|
||||
next_expert_weight = next_device_weight[next_exchange_index]
|
||||
cur_device_weight[cur_exchange_index] = next_expert_weight
|
||||
next_device_weight[next_exchange_index] = cur_expert_weight
|
||||
|
||||
cur_layer_result[cur_device_id][
|
||||
'total_load'] += next_expert_weight - cur_expert_weight
|
||||
cur_layer_result[next_device_id][
|
||||
'total_load'] += cur_expert_weight - next_expert_weight
|
||||
|
||||
com_between_devices[cur_device_id][next_device_id] = next_expert_id
|
||||
com_between_devices[next_device_id][cur_device_id] = cur_expert_id
|
||||
|
||||
def redundant_expert_deployment(self, layer_workloads, original_deployment,
|
||||
expert_from_device, node_num,
|
||||
is_node_redundant, rendun_pos):
|
||||
device_num, per_device_expert_num = original_deployment.shape
|
||||
route_expert_num = layer_workloads.shape[0]
|
||||
per_node_device_num = self.safe_exact_divide(device_num, node_num)
|
||||
per_node_route_expert_num = per_node_device_num * (
|
||||
per_device_expert_num - 1)
|
||||
|
||||
weights = np.zeros((route_expert_num, ), dtype='object')
|
||||
for expert_id, workload_weight in enumerate(layer_workloads):
|
||||
weights[expert_id] = (expert_id, workload_weight)
|
||||
|
||||
if is_node_redundant:
|
||||
|
||||
device_assignments = []
|
||||
device_weights = []
|
||||
device_loads = []
|
||||
device_counts = []
|
||||
com_between_devices = []
|
||||
|
||||
for node_id in range(node_num):
|
||||
cur_node_weights = weights[node_id *
|
||||
per_node_route_expert_num:(node_id +
|
||||
1) *
|
||||
per_node_route_expert_num]
|
||||
cur_original_deployment = original_deployment[
|
||||
node_id * per_node_device_num:(node_id + 1) *
|
||||
per_node_device_num]
|
||||
|
||||
cur_node_rendun_pos = rendun_pos[node_id *
|
||||
per_node_device_num:(node_id +
|
||||
1) *
|
||||
per_node_device_num]
|
||||
|
||||
cur_device_assignments, cur_device_weights, cur_device_loads, cur_device_counts, cur_com_between_devices = self.redundancy_again(
|
||||
layer_workloads, cur_node_weights, cur_original_deployment,
|
||||
expert_from_device, node_num, is_node_redundant,
|
||||
cur_node_rendun_pos)
|
||||
device_assignments += cur_device_assignments
|
||||
device_weights += cur_device_weights
|
||||
device_loads += cur_device_loads
|
||||
device_counts += cur_device_counts
|
||||
com_between_devices += cur_com_between_devices
|
||||
|
||||
else:
|
||||
device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.redundancy_again(
|
||||
layer_workloads, weights, original_deployment,
|
||||
expert_from_device, node_num, is_node_redundant, rendun_pos)
|
||||
report, max_load = self.generate_allocation_report(
|
||||
device_assignments, device_weights, device_loads, device_counts)
|
||||
|
||||
return report, max_load, com_between_devices
|
||||
|
||||
@staticmethod
|
||||
def two_device_exchange_experts(cur_device_result, exchange_device_result,
|
||||
cur_exchanged_expert_id,
|
||||
next_exchanged_expert_id, ave_workload,
|
||||
increment, num_redundancy_expert):
|
||||
|
||||
cur_device_weight = cur_device_result['expert_weights']
|
||||
next_device_weight = exchange_device_result['expert_weights']
|
||||
|
||||
cur_device_expert_id = cur_device_result['assigned_experts']
|
||||
next_device_expert_id = exchange_device_result['assigned_experts']
|
||||
|
||||
cur_device_total_weight = cur_device_result['total_load']
|
||||
next_device_total_weight = exchange_device_result['total_load']
|
||||
max_weight = max(cur_device_total_weight, next_device_total_weight)
|
||||
|
||||
cur_exchange_index = -1
|
||||
next_exchange_index = -1
|
||||
|
||||
for index, weight in enumerate(cur_device_weight):
|
||||
for next_index, next_weight in enumerate(next_device_weight):
|
||||
change_flag = True
|
||||
if (cur_device_expert_id[index] in next_device_expert_id
|
||||
or next_device_expert_id[next_index]
|
||||
in cur_device_expert_id):
|
||||
change_flag = False
|
||||
if (cur_device_expert_id[index] not in cur_exchanged_expert_id
|
||||
) and (next_device_expert_id[next_index]
|
||||
not in next_exchanged_expert_id) and change_flag:
|
||||
|
||||
cur_total_weight_after_exchange = cur_device_total_weight - weight + next_weight
|
||||
next_total_weight_after_exchange = next_device_total_weight - next_weight + weight
|
||||
exchange_max_weight = max(
|
||||
cur_total_weight_after_exchange,
|
||||
next_total_weight_after_exchange)
|
||||
if exchange_max_weight < max_weight and (
|
||||
max_weight -
|
||||
exchange_max_weight) >= (ave_workload * increment):
|
||||
max_weight = exchange_max_weight
|
||||
cur_exchange_index = index
|
||||
next_exchange_index = next_index
|
||||
|
||||
return cur_exchange_index, next_exchange_index
|
||||
|
||||
def expert_exchange_between_devices(self,
|
||||
ave_workload,
|
||||
increment,
|
||||
cur_layer_result,
|
||||
com_between_devices,
|
||||
num_redundancy_expert,
|
||||
node_idx=0,
|
||||
per_node_device_num=0,
|
||||
is_node_redundant=False):
|
||||
|
||||
if is_node_redundant:
|
||||
cur_devices_result = cur_layer_result[node_idx *
|
||||
per_node_device_num:
|
||||
(node_idx + 1) *
|
||||
per_node_device_num]
|
||||
else:
|
||||
cur_devices_result = cur_layer_result
|
||||
|
||||
devices_total_weight = []
|
||||
for device in cur_devices_result:
|
||||
devices_total_weight.append(
|
||||
(device['total_load'], device['device_id'] - 1))
|
||||
|
||||
exchange_frequency = 100
|
||||
while exchange_frequency > 0:
|
||||
exchange_frequency -= 1
|
||||
devices_total_weight.sort(key=lambda x: x[0])
|
||||
max_weight_device_id = devices_total_weight[-1][1]
|
||||
exchange = False
|
||||
for index in range(0, len(devices_total_weight) - 1):
|
||||
min_weight_device_id = devices_total_weight[index][1]
|
||||
if min_weight_device_id not in com_between_devices[
|
||||
max_weight_device_id]:
|
||||
cur_exchanged_expert_id = list(
|
||||
com_between_devices[max_weight_device_id].values())
|
||||
next_exchanged_expert_id = list(
|
||||
com_between_devices[min_weight_device_id].values())
|
||||
|
||||
cur_exchange_index, next_exchange_index = self.two_device_exchange_experts(
|
||||
cur_layer_result[max_weight_device_id],
|
||||
cur_layer_result[min_weight_device_id],
|
||||
cur_exchanged_expert_id, next_exchanged_expert_id,
|
||||
ave_workload, increment, num_redundancy_expert)
|
||||
|
||||
if cur_exchange_index != -1:
|
||||
self.exchange_expert(cur_exchange_index,
|
||||
next_exchange_index,
|
||||
max_weight_device_id,
|
||||
min_weight_device_id,
|
||||
cur_layer_result,
|
||||
com_between_devices)
|
||||
|
||||
devices_total_weight[-1] = (
|
||||
cur_layer_result[max_weight_device_id]
|
||||
['total_load'], max_weight_device_id)
|
||||
devices_total_weight[index] = (
|
||||
cur_layer_result[min_weight_device_id]
|
||||
['total_load'], min_weight_device_id)
|
||||
exchange = True
|
||||
break
|
||||
|
||||
if not exchange:
|
||||
break
|
||||
|
||||
def exchange_experts(self, layer_result, layer_com_between_devices,
|
||||
num_nodes, device_num, is_node_redundant,
|
||||
ave_workload, increment, num_redundancy_expert,
|
||||
org_deployment):
|
||||
|
||||
global_deployment = []
|
||||
|
||||
if is_node_redundant:
|
||||
per_node_device_num = self.safe_exact_divide(device_num, num_nodes)
|
||||
for node_idx in range(num_nodes):
|
||||
self.expert_exchange_between_devices(
|
||||
ave_workload, increment, layer_result,
|
||||
layer_com_between_devices, num_redundancy_expert, node_idx,
|
||||
per_node_device_num, is_node_redundant)
|
||||
else:
|
||||
self.expert_exchange_between_devices(ave_workload, increment,
|
||||
layer_result,
|
||||
layer_com_between_devices,
|
||||
num_redundancy_expert)
|
||||
|
||||
max_workload = 0
|
||||
for box in layer_result:
|
||||
global_deployment.append(box['assigned_experts'])
|
||||
if max_workload < box['total_load']:
|
||||
max_workload = box['total_load']
|
||||
|
||||
global_deployment = np.array(global_deployment)
|
||||
|
||||
return global_deployment, max_workload
|
||||
|
||||
def count_elements(self, lst):
|
||||
count = 0
|
||||
for item in lst:
|
||||
if isinstance(item, list):
|
||||
count += self.count_elements(item)
|
||||
else:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
def constraint_expert_local_exchange(current_expert_table,
|
||||
global_deployment):
|
||||
for layer_id in range(len(global_deployment)):
|
||||
for card_id in range(len(global_deployment[layer_id])):
|
||||
current_list = [
|
||||
int(x) for x in current_expert_table[layer_id][card_id]
|
||||
]
|
||||
new_list = [
|
||||
int(x) for x in global_deployment[layer_id][card_id]
|
||||
]
|
||||
num = len(new_list)
|
||||
|
||||
new_index = [-1] * num
|
||||
new_result = [-1] * num
|
||||
remaining_elements = []
|
||||
|
||||
for i in range(num):
|
||||
flag = True
|
||||
for j in range(num):
|
||||
if new_list[i] == current_list[j] and new_index[
|
||||
j] == -1:
|
||||
new_index[j] = 0
|
||||
new_result[j] = current_list[j]
|
||||
flag = False
|
||||
break
|
||||
if flag:
|
||||
remaining_elements.append(new_list[i])
|
||||
|
||||
index = 0
|
||||
for k in range(num):
|
||||
if new_result[k] == -1:
|
||||
new_result[k] = remaining_elements[index]
|
||||
index += 1
|
||||
|
||||
global_deployment[layer_id][card_id] = new_result
|
||||
|
||||
return global_deployment
|
||||
|
||||
def rebalance_experts(self,
|
||||
current_expert_table,
|
||||
expert_workload,
|
||||
is_node_redundant=False,
|
||||
increment=0.01):
|
||||
info = DynamicTable()
|
||||
info.workload_table = expert_workload.numpy()
|
||||
info.placement_table = current_expert_table.numpy()
|
||||
assert info.workload_table is not None
|
||||
layer_num, num_npus, experts_per_npu = info.workload_table.shape
|
||||
expert_ids, counts = np.unique(info.placement_table[0],
|
||||
return_counts=True)
|
||||
num_redundancy_expert = self.get_redundant_num(num_npus, counts)
|
||||
num_original_expert = len(expert_ids)
|
||||
layer_workloads = self.add_redundant(info.placement_table,
|
||||
info.workload_table,
|
||||
num_original_expert)
|
||||
max_heat_per_layer_before = self.calculate_max_heat_per_layer(
|
||||
info.workload_table, layer_num)
|
||||
npu_heat_all_origin = sum(max_heat_per_layer_before)
|
||||
|
||||
num_node = self.safe_exact_divide(num_npus, 8)
|
||||
layer_num = layer_workloads.shape[0]
|
||||
expert_num = layer_workloads.shape[1]
|
||||
expert_from_device = np.zeros((layer_num, num_original_expert))
|
||||
|
||||
if num_original_expert != expert_num:
|
||||
raise ValueError(
|
||||
f"The number of original experts ({num_original_expert}) must match expert_num ({expert_num})"
|
||||
)
|
||||
|
||||
if num_npus <= 0:
|
||||
raise ValueError("The number of NPUs must be greater than 0")
|
||||
|
||||
if num_npus < num_redundancy_expert:
|
||||
raise ValueError(
|
||||
f"The number of NPUs ({num_npus}) must be greater than or equal to the number of redundant experts ({num_redundancy_expert})"
|
||||
)
|
||||
|
||||
global_deployment: list[list[list[int]]] = [[[]
|
||||
for _ in range(num_npus)]
|
||||
for _ in range(layer_num)]
|
||||
layer_initial_imbalance = self.calculate_initial_imbalance(
|
||||
info.placement_table, layer_workloads)
|
||||
max_heat_per_layer_after = np.zeros([layer_num])
|
||||
sum_num = 0
|
||||
for layer in range(layer_num):
|
||||
# print(f"Load imbalance ratio of layer {layer} under the new workload", layer_initial_imbalance[layer])
|
||||
if layer_initial_imbalance[layer] < 1.01:
|
||||
global_deployment[layer] = info.placement_table[layer]
|
||||
continue
|
||||
|
||||
ave_workload = self.safe_divide(np.sum(layer_workloads[layer]),
|
||||
num_npus)
|
||||
|
||||
rendun_pos: list[list[int]] = [[] for _ in range(num_npus)]
|
||||
existing_experts = set()
|
||||
for device_id, device in enumerate(info.placement_table[layer]):
|
||||
for index, expert_id in enumerate(device):
|
||||
if expert_id not in existing_experts:
|
||||
existing_experts.add(expert_id)
|
||||
expert_from_device[layer][expert_id] = device_id
|
||||
else:
|
||||
rendun_pos[device_id].append(index)
|
||||
|
||||
result, max_workload, com_between_devices = self.redundant_expert_deployment(
|
||||
layer_workloads[layer], info.placement_table[layer],
|
||||
expert_from_device[layer], num_node, is_node_redundant,
|
||||
rendun_pos)
|
||||
# print(layer, f"Imbalance Ratio after Redundancy Adjustment:", self.safe_divide(max_workload, ave_workload))
|
||||
|
||||
global_deployment[layer], new_max_workload = self.exchange_experts(
|
||||
result, com_between_devices, num_node, num_npus,
|
||||
is_node_redundant, ave_workload, increment,
|
||||
num_redundancy_expert, info.placement_table[layer])
|
||||
# print(layer, f"Imbalance Ratio after Swap Adjustment:", self.safe_divide(new_max_workload, ave_workload))
|
||||
|
||||
for device_id in range(num_npus):
|
||||
com_between_devices[device_id] = {
|
||||
key: value
|
||||
for key, value in com_between_devices[device_id].items()
|
||||
}
|
||||
sum_num += self.count_elements(com_between_devices[device_id])
|
||||
|
||||
max_heat_per_layer_after[layer] = max(
|
||||
result, key=lambda x: x['total_load'])['total_load']
|
||||
|
||||
layer_changed_ratio = []
|
||||
for layer_idx in range(layer_num):
|
||||
layer_changed_ratio.append(
|
||||
self.safe_divide(max_heat_per_layer_after[layer_idx],
|
||||
max_heat_per_layer_before[layer_idx]))
|
||||
|
||||
per_layer_priority = np.argsort(layer_changed_ratio)
|
||||
npu_heat_all_after = sum(max_heat_per_layer_after)
|
||||
|
||||
change = 0
|
||||
if npu_heat_all_after < 0.95 * npu_heat_all_origin:
|
||||
change = 1
|
||||
|
||||
new_global_deployment = self.constraint_expert_local_exchange(
|
||||
current_expert_table, global_deployment)
|
||||
|
||||
return change, per_layer_priority, np.array(
|
||||
new_global_deployment).tolist()
|
||||
33
vllm_ascend/eplb/core/policy/policy_factory.py
Normal file
33
vllm_ascend/eplb/core/policy/policy_factory.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
|
||||
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this factory.
|
||||
from .policy_abstract import DynamicConfig, EplbPolicy
|
||||
from .policy_dynamic_ep import DynamicEplb
|
||||
from .policy_dynamic_ep_v2 import DynamicEplbV2
|
||||
from .policy_flashlb import FlashLB
|
||||
from .policy_random import RandomLoadBalance
|
||||
|
||||
|
||||
class PolicyFactory:
|
||||
|
||||
@staticmethod
|
||||
def generate_policy(policy_type: int, config: DynamicConfig) -> EplbPolicy:
|
||||
policy = {
|
||||
# Constraint applying Dynamic EPLB policy V2:
|
||||
# If there exists redundant expert:
|
||||
# only one redundant expert can be placed in one NPU and its physical expert index must be 0
|
||||
|
||||
# Applying greedy d2d expert weight update composing
|
||||
0:
|
||||
RandomLoadBalance, # RandomLoadBalance: shuffle last physical expert on NPU 1 and 3
|
||||
1:
|
||||
DynamicEplb, # Dynamic EPLB policy: overall expert replacement based on current moe load
|
||||
2:
|
||||
DynamicEplbV2, # Dynamic EPLB policy V2: expert replacement with constrained number of expert shuffle
|
||||
3:
|
||||
FlashLB, # FlashLB EPLB policy: expert replacement based on Joint Optimization, Multi-Shot Enhancement and Incremental Adjustment
|
||||
}
|
||||
policy_class = policy.get(policy_type, RandomLoadBalance)
|
||||
policy_instance = policy_class(config)
|
||||
if policy_type == 3:
|
||||
policy_instance.warm_up()
|
||||
return policy_instance
|
||||
651
vllm_ascend/eplb/core/policy/policy_flashlb.py
Normal file
651
vllm_ascend/eplb/core/policy/policy_flashlb.py
Normal file
@@ -0,0 +1,651 @@
|
||||
# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
|
||||
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy.
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from numba import njit # type: ignore
|
||||
|
||||
from .policy_abstract import DynamicConfig, EplbPolicy
|
||||
|
||||
numba_logger = logging.getLogger("numba")
|
||||
numba_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
@njit
|
||||
def compute_piece_counts(X, P, stage_weights):
|
||||
n_stage, N = X.shape
|
||||
S = P - N
|
||||
pieces = np.ones(N, dtype=np.int32)
|
||||
unit = X / pieces # unit[i, j] = X[i, j] / pieces[j]
|
||||
|
||||
for _ in range(S):
|
||||
deltas = np.zeros(N, dtype=np.float32)
|
||||
for i in range(n_stage):
|
||||
# Find top1 and top2
|
||||
idx1 = -1
|
||||
idx2 = -1
|
||||
val1 = -1.0
|
||||
val2 = -1.0
|
||||
for j in range(N):
|
||||
v = unit[i, j]
|
||||
if v > val1:
|
||||
val2 = val1
|
||||
idx2 = idx1
|
||||
val1 = v
|
||||
idx1 = j
|
||||
elif v > val2:
|
||||
val2 = v
|
||||
idx2 = j
|
||||
|
||||
origin = unit[i, idx1]
|
||||
secv = unit[i, idx2]
|
||||
alt = X[i, idx1] / (pieces[idx1] + 1)
|
||||
delta = origin - (alt if alt > secv else secv)
|
||||
deltas[idx1] += delta * stage_weights[i] if np.any(
|
||||
delta) != 0 else stage_weights[i]
|
||||
|
||||
max_idx = np.argmax(deltas)
|
||||
pieces[max_idx] += 1
|
||||
for i in range(n_stage):
|
||||
unit[i, max_idx] = X[i, max_idx] / pieces[max_idx]
|
||||
|
||||
# Compute max load
|
||||
max_load = 0.0
|
||||
for j in range(N):
|
||||
total = 0.0
|
||||
for i in range(n_stage):
|
||||
total += unit[i, j]
|
||||
if total > max_load:
|
||||
max_load = total
|
||||
|
||||
return pieces
|
||||
|
||||
|
||||
@njit
|
||||
def jsq_placement(X, pieces, M, stage_weights):
|
||||
n_stage, N = X.shape
|
||||
total_piece = pieces.sum()
|
||||
num_per_group = total_piece // M
|
||||
|
||||
# 1. Compute unit_hotness
|
||||
unit_hotness = np.empty((n_stage, N), dtype=np.float32)
|
||||
for i in range(N):
|
||||
if pieces[i] > 0:
|
||||
for s in range(n_stage):
|
||||
unit_hotness[s, i] = X[s, i] / pieces[i]
|
||||
else:
|
||||
for s in range(n_stage):
|
||||
unit_hotness[s, i] = 0.0
|
||||
|
||||
# 2. Sort by total hotness
|
||||
scores = np.zeros(N, dtype=np.float32)
|
||||
for i in range(N):
|
||||
for s in range(n_stage):
|
||||
scores[i] += unit_hotness[s, i]
|
||||
idx = np.argsort(-scores)
|
||||
|
||||
# 3. Initialization
|
||||
loads = np.zeros((n_stage, M), dtype=np.float32)
|
||||
dev_phy_exp_n = np.zeros(M, dtype=np.int32)
|
||||
deployment = -np.ones((M, num_per_group), dtype=np.int32)
|
||||
dep_ptr = np.zeros(M, dtype=np.int32)
|
||||
|
||||
# 4. Main loop
|
||||
for t in range(N):
|
||||
i = idx[t]
|
||||
used_device = list()
|
||||
for _ in range(pieces[i]):
|
||||
# 4.1 Construct w vector
|
||||
w = np.empty(n_stage, dtype=np.float32)
|
||||
for s in range(n_stage):
|
||||
w[s] = unit_hotness[s, i]
|
||||
|
||||
# 4.2 Compute stage-level maximum load
|
||||
stage_max = np.empty(n_stage, dtype=np.float32)
|
||||
for s in range(n_stage):
|
||||
max_val = loads[s, 0]
|
||||
for k in range(1, M):
|
||||
if loads[s, k] > max_val:
|
||||
max_val = loads[s, k]
|
||||
stage_max[s] = max_val
|
||||
|
||||
# 4.3 Compute denominator
|
||||
denom = np.empty(n_stage, dtype=np.float32)
|
||||
for s in range(n_stage):
|
||||
sum_tmp = 0.0
|
||||
for j in range(M):
|
||||
sum_tmp += loads[s, j] + w[s]
|
||||
denom[s] = sum_tmp / M + 1e-2
|
||||
|
||||
# 4.4 Find best device j
|
||||
best_j = -1
|
||||
best_val = 1e30
|
||||
for j in range(M):
|
||||
if dev_phy_exp_n[j] >= num_per_group:
|
||||
continue
|
||||
if j in used_device:
|
||||
continue
|
||||
score = 0.0
|
||||
for s in range(n_stage):
|
||||
tmp_sj = loads[s, j] + w[s]
|
||||
numer_sj = tmp_sj if tmp_sj > stage_max[s] else stage_max[s]
|
||||
score += stage_weights[s] * (numer_sj / denom[s])
|
||||
if score < best_val:
|
||||
best_val = score
|
||||
best_j = j
|
||||
if best_j == -1:
|
||||
continue
|
||||
|
||||
used_device.append(best_j)
|
||||
|
||||
# 4.5 Update status
|
||||
for s in range(n_stage):
|
||||
loads[s, best_j] += w[s]
|
||||
ptr = dep_ptr[best_j]
|
||||
deployment[best_j, ptr] = i
|
||||
dep_ptr[best_j] += 1
|
||||
dev_phy_exp_n[best_j] += 1
|
||||
|
||||
# Handle remaining -1 values: fill with random elements from range(N) not in current column
|
||||
for rank in range(M):
|
||||
for col in range(num_per_group):
|
||||
if deployment[rank, col] == -1:
|
||||
# Get elements already in current column
|
||||
current_rank_elements = set(deployment[rank, :])
|
||||
# Filter elements from range(N) not in current column
|
||||
available = [
|
||||
x for x in range(N) if x not in current_rank_elements
|
||||
]
|
||||
# Randomly select an available element to fill
|
||||
if len(available) > 0:
|
||||
rand_idx = np.random.randint(0, len(available))
|
||||
deployment[rank, col] = available[rand_idx]
|
||||
elif N > 0:
|
||||
# All unique experts are already in this rank's column, so we can pick any expert randomly.
|
||||
deployment[rank, col] = np.random.randint(0, N)
|
||||
|
||||
return deployment
|
||||
|
||||
|
||||
@njit
|
||||
def slice_values(X, pieces):
|
||||
total_len = 0
|
||||
for i in range(X.shape[0]):
|
||||
total_len += pieces[i]
|
||||
result = np.empty(total_len, dtype=np.float32)
|
||||
idx = 0
|
||||
for i in range(X.shape[0]):
|
||||
val = X[i] / pieces[i]
|
||||
for _ in range(pieces[i]):
|
||||
result[idx] = val
|
||||
idx += 1
|
||||
return result
|
||||
|
||||
|
||||
@njit
|
||||
def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces,
|
||||
simulated_deployment, stage_weights):
|
||||
n_stage, N = X.shape
|
||||
num_group = P // M
|
||||
|
||||
X_all = np.zeros(N, dtype=np.float32)
|
||||
for i in range(n_stage):
|
||||
for j in range(N):
|
||||
X_all[j] += X[i, j]
|
||||
|
||||
sort_idx = np.argsort(np.negative(X_all))
|
||||
X_sorted = X[:, sort_idx]
|
||||
|
||||
unit_load = np.empty(N, dtype=np.float32)
|
||||
for j in range(N):
|
||||
unit_load[j] = X_all[j] / simulated_pieces[j]
|
||||
|
||||
flat_deployment = simulated_deployment.reshape(-1)
|
||||
simulated_load = np.zeros(M, dtype=np.float32)
|
||||
for i in range(flat_deployment.shape[0]):
|
||||
simulated_load[i // (flat_deployment.shape[0] //
|
||||
M)] += unit_load[flat_deployment[i]]
|
||||
|
||||
slice_vals = slice_values(X_all, simulated_pieces)
|
||||
sorted_slices = np.sort(slice_vals)[::-1]
|
||||
simulated_slopes = (sorted_slices[:-M + 1] - sorted_slices[M - 1:]) / M
|
||||
|
||||
cumulative_slices_used = np.zeros(N, dtype=np.int32)
|
||||
acc = 0
|
||||
for i in range(N):
|
||||
acc += simulated_pieces[sort_idx[i]]
|
||||
cumulative_slices_used[i] = acc
|
||||
|
||||
group_boundary_indices = np.zeros(num_group, dtype=np.int32)
|
||||
for i in range(1, num_group + 1):
|
||||
for j in range(N):
|
||||
if cumulative_slices_used[j] >= i * M:
|
||||
group_boundary_indices[i - 1] = j
|
||||
break
|
||||
|
||||
slices_used_per_group = np.zeros(num_group, dtype=np.int32)
|
||||
slices_used_per_group[0] = group_boundary_indices[0]
|
||||
for i in range(1, num_group):
|
||||
slices_used_per_group[
|
||||
i] = group_boundary_indices[i] - group_boundary_indices[i - 1]
|
||||
slices_used_per_group = M - slices_used_per_group
|
||||
|
||||
loads = np.zeros(M, dtype=np.float32)
|
||||
pieces = np.zeros(N, dtype=np.int32)
|
||||
num_remain_slice = P - N
|
||||
current_idx = 0
|
||||
|
||||
for g in range(num_group):
|
||||
window = X_sorted[:, current_idx:current_idx + 2 * M]
|
||||
low = max(0, current_idx + M - N)
|
||||
high = min(num_remain_slice, M - 1)
|
||||
|
||||
while (high - low) > 1:
|
||||
mid = int((high + low) // 2)
|
||||
keep = M - mid
|
||||
current_group = window[:, :keep]
|
||||
current_pieces = compute_piece_counts(current_group, M,
|
||||
stage_weights)
|
||||
current_pieces = np.maximum(current_pieces, 1)
|
||||
current_slice = slice_values(current_group.sum(0), current_pieces)
|
||||
current_slice_sorted = np.sort(current_slice)
|
||||
current_loads = loads + current_slice_sorted
|
||||
current_max: np.float32 = np.max(current_loads)
|
||||
current_min: np.float32 = np.min(current_loads)
|
||||
current_slope = (current_max - current_min) / M
|
||||
next_slope: np.float32 = np.max(simulated_slopes[current_idx +
|
||||
keep:])
|
||||
|
||||
if abs(current_slope) > abs(next_slope):
|
||||
low = mid
|
||||
else:
|
||||
high = mid
|
||||
|
||||
S = high
|
||||
keep = M - S
|
||||
current_group = window[:, :keep]
|
||||
current_pieces = compute_piece_counts(current_group, M, stage_weights)
|
||||
|
||||
for i in range(keep):
|
||||
pieces[sort_idx[current_idx + i]] = current_pieces[i]
|
||||
|
||||
current_slice = slice_values(current_group.sum(0), current_pieces)
|
||||
current_slice_sorted = np.sort(current_slice)
|
||||
loads += current_slice_sorted
|
||||
loads = np.sort(loads)[::-1]
|
||||
|
||||
current_idx += keep
|
||||
num_remain_slice -= S
|
||||
|
||||
return pieces
|
||||
|
||||
|
||||
@njit
|
||||
def compute_objective(deployment, X, pieces):
|
||||
M, P = deployment.shape
|
||||
loads = np.zeros(M)
|
||||
|
||||
for i in range(M):
|
||||
for j in range(P):
|
||||
expert = deployment[i, j]
|
||||
if pieces[expert] == 0:
|
||||
continue
|
||||
loads[i] += X[expert] / pieces[expert]
|
||||
|
||||
mean_load = np.mean(loads)
|
||||
max_load: np.float32 = np.max(loads)
|
||||
obj = max_load / mean_load
|
||||
return obj, loads
|
||||
|
||||
|
||||
@njit
|
||||
def auto_fix_new_placement(old_placement, new_placement):
|
||||
"""
|
||||
Adjust the new_placement matrix to ensure elements (including duplicates) that exist in both
|
||||
old_placement and new_placement remain in their original positions from old_placement.
|
||||
New elements (unique to new_placement) will fill the remaining empty positions.
|
||||
|
||||
Args:
|
||||
old_placement: Old deployment matrix with shape (num_ranks, num_experts)
|
||||
new_placement: New deployment matrix to be fixed, must have the same shape as old_placement
|
||||
|
||||
Returns:
|
||||
fixed_new: adjusted version of the new_placement matrix
|
||||
"""
|
||||
num_ranks, num_experts = old_placement.shape
|
||||
fixed_new = np.empty_like(new_placement)
|
||||
|
||||
max_expert_old = old_placement.max() if num_experts > 0 else 0
|
||||
max_expert_new = new_placement.max() if num_experts > 0 else 0
|
||||
max_expert = max(max_expert_old, max_expert_new)
|
||||
|
||||
for rank_id in range(num_ranks):
|
||||
old_row = old_placement[rank_id]
|
||||
new_row = new_placement[rank_id]
|
||||
|
||||
index_array = np.full((max_expert + 1, num_experts),
|
||||
-1,
|
||||
dtype=np.int32)
|
||||
count_array = np.zeros(max_expert + 1, dtype=np.int32)
|
||||
|
||||
for idx in range(num_experts):
|
||||
val = old_row[idx]
|
||||
if val >= 0 and val <= max_expert:
|
||||
pos = count_array[val]
|
||||
index_array[val, pos] = idx
|
||||
count_array[val] += 1
|
||||
|
||||
old_counter = np.zeros(max_expert + 1, dtype=np.int32)
|
||||
for idx in range(num_experts):
|
||||
val = old_row[idx]
|
||||
if val >= 0 and val <= max_expert:
|
||||
old_counter[val] += 1
|
||||
|
||||
retain_elements = np.empty(num_experts, dtype=new_placement.dtype)
|
||||
new_elements = np.empty(num_experts, dtype=new_placement.dtype)
|
||||
retain_ptr = 0
|
||||
new_ptr = 0
|
||||
|
||||
for val in new_row:
|
||||
if val >= 0 and val <= max_expert and old_counter[val] > 0:
|
||||
retain_elements[retain_ptr] = val
|
||||
retain_ptr += 1
|
||||
old_counter[val] -= 1
|
||||
else:
|
||||
new_elements[new_ptr] = val
|
||||
new_ptr += 1
|
||||
|
||||
current_fixed = np.full(num_experts, -1, dtype=new_placement.dtype)
|
||||
|
||||
for i in range(retain_ptr):
|
||||
val = retain_elements[i]
|
||||
if val >= 0 and val <= max_expert:
|
||||
pos = count_array[val] - 1
|
||||
if pos >= 0:
|
||||
idx = index_array[val, pos]
|
||||
current_fixed[idx] = val
|
||||
count_array[val] -= 1
|
||||
|
||||
empty_indices = np.empty(num_experts, dtype=np.int32)
|
||||
empty_ptr = 0
|
||||
for idx in range(num_experts):
|
||||
if current_fixed[idx] == -1:
|
||||
empty_indices[empty_ptr] = idx
|
||||
empty_ptr += 1
|
||||
|
||||
for i in range(new_ptr):
|
||||
if i < empty_ptr:
|
||||
current_fixed[empty_indices[i]] = new_elements[i]
|
||||
|
||||
fixed_new[rank_id] = current_fixed
|
||||
|
||||
return fixed_new
|
||||
|
||||
|
||||
class FlashLB(EplbPolicy):
|
||||
|
||||
def __init__(self, config: DynamicConfig):
|
||||
super().__init__(config)
|
||||
self.par_history: Dict[int, float] = {}
|
||||
self.hotness_window: Dict[int, deque[float]] = {}
|
||||
self.max_stage_window = (config.max_stage_window if hasattr(
|
||||
config, "max_stage_window") else 1)
|
||||
self.buffer_expert_layer_num = (
|
||||
config.buffer_expert_layer_num if hasattr(
|
||||
config, "buffer_expert_layer_num") else 58)
|
||||
self.threshold_ratio = (config.threshold_ratio if hasattr(
|
||||
config, "threshold_ratio") else 0)
|
||||
|
||||
def compute_expert_hotness(self, num_of_expert: int,
|
||||
deployment: np.ndarray, rank_load: np.ndarray):
|
||||
hotness = np.zeros(num_of_expert, dtype=rank_load.dtype)
|
||||
deployment_flat = deployment.ravel()
|
||||
rank_load_flat = rank_load.ravel()
|
||||
np.add.at(hotness, deployment_flat, rank_load_flat)
|
||||
return hotness
|
||||
|
||||
def compute_rank_load(self, deployment: np.ndarray, hotness: np.ndarray):
|
||||
n_stage, N = hotness.shape
|
||||
if np.any(deployment < 0):
|
||||
print(f"Invalid deployment with negative values: {deployment}")
|
||||
raise ValueError("Deployment table contains negative values.")
|
||||
counts = np.bincount(deployment.reshape(-1), minlength=N)
|
||||
unit_hotness = np.divide(hotness,
|
||||
counts,
|
||||
out=np.zeros_like(hotness, dtype=float),
|
||||
where=counts != 0)
|
||||
stage_par = np.zeros(n_stage)
|
||||
for i in range(n_stage):
|
||||
stage_load = unit_hotness[i][deployment].sum(-1)
|
||||
stage_par[i] = stage_load.max() / stage_load.mean()
|
||||
return stage_par.mean()
|
||||
|
||||
def group_based_adaptive_bloating(self,
|
||||
X,
|
||||
P,
|
||||
M,
|
||||
stage_weights=None,
|
||||
recorsive=False):
|
||||
n_stage, N = X.shape
|
||||
if stage_weights is None:
|
||||
stage_weights = np.ones(n_stage, dtype=np.float32)
|
||||
|
||||
if recorsive:
|
||||
(
|
||||
simulated_deployment,
|
||||
simulated_pieces,
|
||||
) = self.group_based_adaptive_bloating(X,
|
||||
P,
|
||||
M,
|
||||
stage_weights,
|
||||
recorsive=False)
|
||||
else:
|
||||
simulated_pieces = compute_piece_counts(X, P, stage_weights)
|
||||
simulated_deployment = jsq_placement(X, simulated_pieces, M,
|
||||
stage_weights)
|
||||
|
||||
pieces = group_based_adaptive_bloating_kernel(
|
||||
X.astype(np.float32),
|
||||
P,
|
||||
M,
|
||||
simulated_pieces.astype(np.int32),
|
||||
simulated_deployment.astype(np.int32),
|
||||
stage_weights.astype(np.float32),
|
||||
)
|
||||
|
||||
deployment = jsq_placement(X, pieces, M, stage_weights)
|
||||
|
||||
X_all = X.sum(0)
|
||||
unit_load = np.divide(X_all,
|
||||
pieces,
|
||||
out=np.zeros_like(X_all, dtype=float),
|
||||
where=pieces != 0)
|
||||
load = unit_load[deployment].sum(-1)
|
||||
|
||||
sim_unit_load = X_all / simulated_pieces
|
||||
sim_load = sim_unit_load[simulated_deployment].sum(-1)
|
||||
|
||||
if load.max() > sim_load.max():
|
||||
return simulated_deployment, simulated_pieces
|
||||
return deployment, pieces
|
||||
|
||||
def need_update(self, current_par, layer_id=0):
|
||||
threshold = self.par_history.get(layer_id, 0.0)
|
||||
return current_par >= self.threshold_ratio * threshold
|
||||
|
||||
def compute_stage_weight(self, hotness):
|
||||
n_stage = hotness.shape[0]
|
||||
stage_weights = np.zeros(n_stage)
|
||||
for i in range(n_stage):
|
||||
stage_weights[i] = hotness[i].sum()
|
||||
|
||||
stage_weights = stage_weights / stage_weights.max()
|
||||
return stage_weights
|
||||
|
||||
def rebalance_layer(self, deployment, hotness, layer_id=0):
|
||||
num_rank, expert_per_rank = deployment.shape
|
||||
num_expert = np.unique(deployment.reshape(-1)).shape[0]
|
||||
num_of_redundant_expert = num_rank * expert_per_rank - num_expert
|
||||
|
||||
current_par = self.compute_rank_load(deployment, hotness)
|
||||
|
||||
if not self.need_update(current_par, layer_id):
|
||||
return deployment, current_par, current_par
|
||||
|
||||
stage_weights = self.compute_stage_weight(hotness)
|
||||
new_deployment, _ = self.group_based_adaptive_bloating(
|
||||
hotness,
|
||||
num_expert + num_of_redundant_expert,
|
||||
num_rank,
|
||||
stage_weights,
|
||||
recorsive=False,
|
||||
)
|
||||
if np.any(new_deployment < 0):
|
||||
print(f"{new_deployment=}")
|
||||
new_par = self.compute_rank_load(new_deployment, hotness)
|
||||
|
||||
return new_deployment, new_par, current_par
|
||||
|
||||
def register_hotness(self, deployment, rank_load, num_layer, num_expert):
|
||||
for layer in range(num_layer):
|
||||
if layer not in self.hotness_window:
|
||||
self.hotness_window[layer] = deque(
|
||||
maxlen=self.max_stage_window)
|
||||
hotness = self.compute_expert_hotness(num_expert,
|
||||
deployment[layer],
|
||||
rank_load[layer])
|
||||
self.hotness_window[layer].append(hotness)
|
||||
|
||||
def compress_by_avg_pooling_fast_nd(self, arr, m):
|
||||
n, d = arr.shape
|
||||
idx = (np.arange(n) * m // n)
|
||||
result = np.zeros((m, d))
|
||||
counts = np.zeros((m, 1))
|
||||
np.add.at(result, idx, arr)
|
||||
np.add.at(counts, idx, 1)
|
||||
return result / counts
|
||||
|
||||
def rebalance_experts(self, current_expert_table, expert_workload):
|
||||
current_deployment = np.array(current_expert_table)
|
||||
expert_workload = np.array(expert_workload)
|
||||
expert_workload += 1
|
||||
num_layer = expert_workload.shape[0]
|
||||
num_expert = np.unique(current_expert_table[0].reshape(-1)).shape[0]
|
||||
self.register_hotness(current_deployment, expert_workload, num_layer,
|
||||
num_expert)
|
||||
|
||||
new_deployment = current_deployment.copy()
|
||||
|
||||
layers_need_update = np.arange(num_layer)
|
||||
|
||||
new_par = np.zeros(layers_need_update.shape[0])
|
||||
current_par = np.zeros(layers_need_update.shape[0])
|
||||
for i, layer in enumerate(layers_need_update):
|
||||
hotness = np.array(self.hotness_window[layer])
|
||||
if hotness.shape[0] > self.max_stage_window:
|
||||
hotness = self.compress_by_avg_pooling_fast_nd(
|
||||
hotness, self.max_stage_window)
|
||||
|
||||
(
|
||||
new_deployment[layer],
|
||||
new_par[i],
|
||||
current_par[i],
|
||||
) = self.rebalance_layer(current_deployment[layer],
|
||||
hotness,
|
||||
layer_id=layer)
|
||||
|
||||
priority = new_par / current_par
|
||||
priority_idx = np.argsort(priority)
|
||||
priority_idx = priority_idx[priority[priority_idx] <
|
||||
1][:self.buffer_expert_layer_num]
|
||||
|
||||
if np.all(expert_workload == 1):
|
||||
for _, layer in enumerate(layers_need_update):
|
||||
self.hotness_window[layer].pop()
|
||||
return False, np.array([], dtype=int), current_deployment
|
||||
change = len(priority_idx) > 0
|
||||
if change:
|
||||
for idx in priority_idx:
|
||||
self.par_history[layers_need_update[idx]] = new_par[idx]
|
||||
|
||||
layers_need_update = priority_idx
|
||||
deployment = current_deployment
|
||||
for layer in layers_need_update:
|
||||
deployment[layer] = auto_fix_new_placement(
|
||||
current_deployment[layer], new_deployment[layer])
|
||||
|
||||
return change, layers_need_update, deployment
|
||||
|
||||
|
||||
def generate_layered_experts(num_layers=58,
|
||||
layer_shape=(32, 9),
|
||||
expert_min=0,
|
||||
expert_max=255):
|
||||
"""
|
||||
Generate expert deployment matrix meeting the following conditions:
|
||||
- Total of num_layers layers
|
||||
- Each layer has shape layer_shape (32,9)
|
||||
- Each expert from expert_min to expert_max (0 to 255) appears at least once in each layer
|
||||
|
||||
Args:
|
||||
num_layers: Number of layers, default 58
|
||||
layer_shape: Shape of a single layer, default (32,9)
|
||||
expert_min: Minimum expert ID, default 0
|
||||
expert_max: Maximum expert ID, default 255
|
||||
Returns:
|
||||
torch.Tensor: Tensor with shape (num_layers, layer_shape[0], layer_shape[1])
|
||||
"""
|
||||
# 1. Basic parameter calculation
|
||||
expert_num = expert_max - expert_min + 1 # Total number of experts: 256 (0~255)
|
||||
layer_total = layer_shape[0] * layer_shape[
|
||||
1] # Total elements in a single layer: 32*9=288
|
||||
extra_slots = layer_total - expert_num # Number of random positions to fill per layer: 288-256=32
|
||||
|
||||
# 2. Verify feasibility (total elements must be ≥ number of experts to cover all experts)
|
||||
assert layer_total >= expert_num, (
|
||||
f"Number of elements in a single layer {layer_total} < number of experts {expert_num}, "
|
||||
"cannot cover all experts")
|
||||
|
||||
# 3. Generate layers one by one
|
||||
layers = []
|
||||
for _ in range(num_layers):
|
||||
# 3.1 Generate "complete expert sequence" (ensure each expert from 0 to 255 is included)
|
||||
full_experts = torch.arange(expert_min,
|
||||
expert_max + 1,
|
||||
dtype=torch.int64) # shape (256,)
|
||||
|
||||
# 3.2 Generate "supplementary random experts" (fill remaining 32 positions, randomly selected from 0~255)
|
||||
extra_experts = torch.randint(expert_min,
|
||||
expert_max + 1,
|
||||
size=(extra_slots, ),
|
||||
dtype=torch.int64) # shape (32,)
|
||||
|
||||
# 3.3 Concatenate and shuffle (ensure random distribution of experts in each layer)
|
||||
layer_flat = torch.cat([full_experts, extra_experts],
|
||||
dim=0) # shape (288,)
|
||||
# Shuffle order (use randperm to generate random indices to avoid repeated shuffling issues)
|
||||
shuffle_idx = torch.randperm(layer_flat.shape[0])
|
||||
layer_shuffled = layer_flat[shuffle_idx]
|
||||
|
||||
# 3.4 Reshape to layer_shape (32,9)
|
||||
layer = layer_shuffled.reshape(layer_shape)
|
||||
layers.append(layer)
|
||||
|
||||
# 4. Stack all layers to get the final tensor
|
||||
return torch.stack(layers, dim=0) # shape (58,32,9)
|
||||
|
||||
|
||||
def warm_up():
|
||||
exam_config = DynamicConfig()
|
||||
exam_config.ep_worldsize = 32
|
||||
exam_config.num_die_per_host = 16
|
||||
algo = FlashLB(exam_config)
|
||||
# Generate target tensor
|
||||
expert_tensor = generate_layered_experts(num_layers=58,
|
||||
layer_shape=(32, 9))
|
||||
|
||||
algo.rebalance_experts(expert_tensor, torch.randint(1, 1000, (58, 32, 9)))
|
||||
30
vllm_ascend/eplb/core/policy/policy_random.py
Normal file
30
vllm_ascend/eplb/core/policy/policy_random.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# Copyright # Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
|
||||
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy.
|
||||
import copy
|
||||
import random
|
||||
|
||||
from .policy_abstract import DynamicConfig, EplbPolicy
|
||||
|
||||
random.seed(42)
|
||||
|
||||
|
||||
class RandomLoadBalance(EplbPolicy):
|
||||
|
||||
def __init__(self, config: DynamicConfig):
|
||||
super().__init__(config)
|
||||
|
||||
def rebalance_experts(self, current_expert_table, expert_workload):
|
||||
new_table = copy.deepcopy(current_expert_table)
|
||||
num_layers = len(current_expert_table)
|
||||
|
||||
for i in range(num_layers):
|
||||
# randomly choose two card
|
||||
# indices = random.sample(range(num_card), 2)
|
||||
indices = [3, 1]
|
||||
|
||||
# swap redundant experts
|
||||
expert_id_to_exchange = new_table[i][indices[0]][-1].clone()
|
||||
new_table[i][indices[0]][-1] = new_table[i][indices[1]][-1]
|
||||
new_table[i][indices[1]][-1] = expert_id_to_exchange
|
||||
|
||||
return 1, [-i for i in range(num_layers)], new_table
|
||||
205
vllm_ascend/eplb/eplb_updator.py
Normal file
205
vllm_ascend/eplb/eplb_updator.py
Normal file
@@ -0,0 +1,205 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this updator.
|
||||
import numpy
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
|
||||
|
||||
|
||||
class EplbUpdator:
|
||||
|
||||
def __init__(self, ascend_config, loader, eplb_process: EplbProcess,
|
||||
process):
|
||||
self.ascend_config = ascend_config
|
||||
self.init_eplb(self.ascend_config.expert_map_path, process)
|
||||
self.eplb_loader = loader
|
||||
self.eplb_process = eplb_process
|
||||
self.shared_dict = self.eplb_process.shared_dict
|
||||
|
||||
def set_adaptor(self, adaptor):
|
||||
self.adaptor = adaptor
|
||||
self.num_moe_layers = self.adaptor.num_moe_layers
|
||||
self.global_expert_num = self.adaptor.global_expert_num
|
||||
|
||||
def init_eplb(self, expert_map_path, process):
|
||||
self.rank_id = dist.get_rank()
|
||||
self.num_expert_load_gather = 10
|
||||
self.periodic_load_gather = True
|
||||
self.num_iterations_eplb_update: torch.int64 = self.ascend_config.num_iterations_eplb_update
|
||||
self.expert_map_path = expert_map_path
|
||||
self.expert_map_record_path = self.ascend_config.expert_map_record_path
|
||||
|
||||
try:
|
||||
if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING:
|
||||
self.num_expert_load_gather = self.num_iterations_eplb_update
|
||||
self.periodic_load_gather = False
|
||||
except Exception:
|
||||
self.num_expert_load_gather = self.num_iterations_eplb_update
|
||||
self.periodic_load_gather = False
|
||||
|
||||
self.expert_map_initialized = False
|
||||
self.gate_eplb = self.ascend_config.gate_eplb
|
||||
|
||||
self.reqs = []
|
||||
self.update_info_all = []
|
||||
|
||||
self.cur_iterations: torch.int64 = 0
|
||||
|
||||
self.num_wait_worker_iterations: torch.int64 = self.ascend_config.num_wait_worker_iterations
|
||||
|
||||
self.process = process
|
||||
|
||||
logger.info(
|
||||
f"[ModelRunner] Launched EPLB process (pid={self.process.pid})")
|
||||
|
||||
def update_iteration(self):
|
||||
self.cur_iterations += 1
|
||||
if self.cur_iterations == (self.num_iterations_eplb_update + \
|
||||
self.num_wait_worker_iterations + self.num_moe_layers):
|
||||
if self.expert_map_record_path is not None:
|
||||
self.adaptor._export_tensor_to_file(
|
||||
self.shared_dict["expert_maps"],
|
||||
self.expert_map_record_path)
|
||||
|
||||
self.adaptor.model.clear_all_moe_loads()
|
||||
if not self.gate_eplb:
|
||||
self.cur_iterations = 0
|
||||
|
||||
def get_update_info_flag(self):
|
||||
return self.cur_iterations == (self.num_iterations_eplb_update +
|
||||
self.num_wait_worker_iterations - 1)
|
||||
|
||||
def wakeup_eplb_worker_flag(self):
|
||||
return self.cur_iterations == (self.num_iterations_eplb_update - 1)
|
||||
|
||||
def update_expert_weight_flag(self):
|
||||
weight_update_counter = self.cur_iterations - (
|
||||
self.num_iterations_eplb_update + self.num_wait_worker_iterations)
|
||||
return (weight_update_counter >= 0
|
||||
and weight_update_counter < self.num_moe_layers)
|
||||
|
||||
def get_init_expert_map(self):
|
||||
try:
|
||||
if not self.expert_map_initialized:
|
||||
self.shared_dict[
|
||||
"expert_maps"] = self.adaptor.get_init_expert_map_from_file(
|
||||
self.num_moe_layers, self.expert_map_path)
|
||||
self.expert_map_initialized = True
|
||||
except Exception as e:
|
||||
logger.warning(f"[ModelRunner] Failed to wake EPLB process: {e}",
|
||||
exc_info=True)
|
||||
|
||||
def wakeup_eplb_worker(self):
|
||||
self.eplb_process.planner_q.put(1)
|
||||
|
||||
def forward_before(self):
|
||||
if self.update_expert_weight_flag():
|
||||
(expert_send_info, expert_recv_info, updated_expert_map,
|
||||
log2phy_map, layer_id) = self.update_info_all.pop(0)
|
||||
log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map))
|
||||
self.eplb_loader.set_log2phy_map(log2phy_map_this_rank)
|
||||
updated_expert_map_this_rank = torch.from_numpy(
|
||||
numpy.array(updated_expert_map))
|
||||
self.eplb_loader.generate_expert_d2d_transfer_task(
|
||||
expert_send_info, expert_recv_info,
|
||||
updated_expert_map_this_rank,
|
||||
layer_id + self.adaptor.num_dense_layers)
|
||||
|
||||
# set asynchronous stream for d2d expert weight update
|
||||
self.reqs = []
|
||||
self.eplb_loader.asyn_expert_weight_transfer(self.reqs)
|
||||
|
||||
def take_update_info_from_eplb_process(self):
|
||||
# Batch after eplb process being triggered, get update info provided by eplb process
|
||||
if self.get_update_info_flag():
|
||||
self.update_info_all = self.eplb_process.block_update_q.get()
|
||||
|
||||
def forward_end(self):
|
||||
if self.wakeup_eplb_worker_flag():
|
||||
self.compute_and_set_moe_load(is_clear=True)
|
||||
self.wakeup_eplb_worker()
|
||||
|
||||
if self.update_expert_weight_flag():
|
||||
self.eplb_loader.update_expert_map_and_weight(self.reqs)
|
||||
|
||||
self.update_iteration()
|
||||
|
||||
def compute_and_set_moe_load(self, is_clear=False):
|
||||
local_load = self.adaptor.get_rank_expert_workload()
|
||||
|
||||
self._gather_buffer = None
|
||||
if dist.is_initialized():
|
||||
self.world_size = dist.get_world_size()
|
||||
self.device = local_load.device
|
||||
if self._gather_buffer is None:
|
||||
shape = (self.world_size, *local_load.shape)
|
||||
self._gather_buffer = torch.empty(shape,
|
||||
dtype=local_load.dtype,
|
||||
device=self.device)
|
||||
|
||||
dist.all_gather_into_tensor(self._gather_buffer, local_load)
|
||||
|
||||
moe_load = self._gather_buffer.permute(1, 0, 2)
|
||||
self.shared_dict["moe_load"] = moe_load.cpu()
|
||||
logger.debug(
|
||||
f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
|
||||
)
|
||||
else:
|
||||
moe_load = local_load.unsqueeze(1)
|
||||
self.shared_dict["moe_load"] = moe_load.cpu()
|
||||
logger.debug(
|
||||
f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
|
||||
)
|
||||
return moe_load
|
||||
|
||||
def warm_up_eplb(self):
|
||||
|
||||
self.get_init_expert_map()
|
||||
self.compute_and_set_moe_load()
|
||||
|
||||
src_tensor = torch.empty((1, ), device=self.device)
|
||||
self_rank = dist.get_rank()
|
||||
|
||||
comm_op_list = []
|
||||
|
||||
for dst_rank in range(self.world_size):
|
||||
if dst_rank == self_rank:
|
||||
continue
|
||||
comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank))
|
||||
|
||||
for src_rank in range(self.world_size):
|
||||
if src_rank == self_rank:
|
||||
continue
|
||||
comm_op_list.append(dist.P2POp(dist.irecv, src_tensor, src_rank))
|
||||
if comm_op_list:
|
||||
reqs = dist.batch_isend_irecv(comm_op_list)
|
||||
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
def shutdown(self):
|
||||
"""
|
||||
Clean up the EPLB process.
|
||||
"""
|
||||
if self.process.is_alive():
|
||||
self.process.terminate()
|
||||
self.process.join()
|
||||
logger.info("[ModelRunner] EPLB process terminated")
|
||||
77
vllm_ascend/eplb/utils.py
Normal file
77
vllm_ascend/eplb/utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Todo: Once https://github.com/vllm-project/vllm/pull/23553 is merged in vllm. Remove this model register.
|
||||
import types
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_expert_map(self, layer_id):
|
||||
return self.model.layers[layer_id].mlp.experts.get_map()
|
||||
|
||||
|
||||
def get_log2phy_map(self, layer_id):
|
||||
return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
|
||||
|
||||
|
||||
def get_all_expert_map(self, num_moe_layers):
|
||||
all_loads = []
|
||||
num_dense_layers = self.num_dense_layers if hasattr(
|
||||
self, "num_dense_layers") else 0
|
||||
for layer_id in range(num_moe_layers):
|
||||
load_tensor = self.get_expert_map(
|
||||
layer_id + num_dense_layers) # (num_experts_per_layer,)
|
||||
all_loads.append(load_tensor)
|
||||
|
||||
return torch.stack(all_loads, dim=0)
|
||||
|
||||
|
||||
def get_all_moe_loads(self):
|
||||
num_dense_layers = self.num_dense_layers if hasattr(
|
||||
self, "num_dense_layers") else 0
|
||||
all_moe_loads = torch.stack(
|
||||
[self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \
|
||||
for layer_id in range(self.num_moe_layers)],
|
||||
dim=0
|
||||
)
|
||||
return all_moe_loads
|
||||
|
||||
|
||||
def clear_all_moe_loads(self):
|
||||
num_dense_layers = self.num_dense_layers if hasattr(
|
||||
self, "num_dense_layers") else 0
|
||||
for layer_id in range(self.num_moe_layers):
|
||||
self.model.layers[layer_id +
|
||||
num_dense_layers].mlp.experts.clear_moe_load()
|
||||
|
||||
|
||||
def model_register(model, model_config):
|
||||
model.get_expert_map = types.MethodType(get_expert_map, model)
|
||||
model.get_log2phy_map = types.MethodType(get_log2phy_map, model)
|
||||
model.get_all_expert_map = types.MethodType(get_all_expert_map, model)
|
||||
model.get_all_moe_loads = types.MethodType(get_all_moe_loads, model)
|
||||
model.clear_all_moe_loads = types.MethodType(clear_all_moe_loads, model)
|
||||
|
||||
config = model_config.hf_config
|
||||
|
||||
if config.model_type == "qwen3_moe":
|
||||
model.num_moe_layers = config.num_hidden_layers
|
||||
elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3":
|
||||
num_dense_layers = config.first_k_dense_replace
|
||||
model.num_moe_layers = config.num_hidden_layers - num_dense_layers
|
||||
else:
|
||||
raise NotImplementedError("EPLB is not supported.")
|
||||
@@ -21,7 +21,7 @@ def bgmv_shrink(inputs: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0):
|
||||
return torch.ops._C.bgmv_shrink(
|
||||
return torch.ops._C_ascend.bgmv_shrink(
|
||||
inputs,
|
||||
lora_a_weights,
|
||||
lora_indices_tensor,
|
||||
@@ -35,7 +35,7 @@ def bgmv_expand(inputs: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True):
|
||||
return torch.ops._C.bgmv_expand(
|
||||
return torch.ops._C_ascend.bgmv_expand(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
lora_indices_tensor,
|
||||
@@ -52,9 +52,9 @@ def bgmv_expand_slice(inputs: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True):
|
||||
return torch.ops._C.bgmv_expand(inputs, lora_b_weights,
|
||||
lora_indices_tensor, output_tensor,
|
||||
slice_offset, slice_size)
|
||||
return torch.ops._C_ascend.bgmv_expand(inputs, lora_b_weights,
|
||||
lora_indices_tensor, output_tensor,
|
||||
slice_offset, slice_size)
|
||||
|
||||
|
||||
def sgmv_shrink(
|
||||
@@ -69,9 +69,9 @@ def sgmv_shrink(
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
):
|
||||
return torch.ops._C.sgmv_shrink(inputs, lora_a_weights,
|
||||
lora_indices_tensor, seq_len_tensor,
|
||||
output_tensor, scaling)
|
||||
return torch.ops._C_ascend.sgmv_shrink(inputs, lora_a_weights,
|
||||
lora_indices_tensor, seq_len_tensor,
|
||||
output_tensor, scaling)
|
||||
|
||||
|
||||
def sgmv_expand(inputs: torch.Tensor,
|
||||
@@ -84,7 +84,7 @@ def sgmv_expand(inputs: torch.Tensor,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
add_inputs: bool = False):
|
||||
return torch.ops._C.sgmv_expand(
|
||||
return torch.ops._C_ascend.sgmv_expand(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
lora_indices_tensor,
|
||||
@@ -107,6 +107,7 @@ def sgmv_expand_slice(inputs: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False):
|
||||
return torch.ops._C.sgmv_expand(inputs, lora_b_weights,
|
||||
lora_indices_tensor, seq_len_tensor,
|
||||
output_tensor, slice_offset, slice_size)
|
||||
return torch.ops._C_ascend.sgmv_expand(inputs, lora_b_weights,
|
||||
lora_indices_tensor, seq_len_tensor,
|
||||
output_tensor, slice_offset,
|
||||
slice_size)
|
||||
@@ -11,12 +11,14 @@ if is_310p():
|
||||
bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice, sgmv_shrink)
|
||||
else:
|
||||
from vllm_ascend.lora.punica_wrapper.lora_ops import (
|
||||
bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice, sgmv_shrink)
|
||||
from vllm_ascend.lora.lora_ops import (bgmv_expand, bgmv_expand_slice,
|
||||
bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice, sgmv_shrink)
|
||||
|
||||
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
|
||||
|
||||
from vllm_ascend.lora.utils import refresh_all_lora_classes
|
||||
|
||||
|
||||
# The platforms that are compatible with the PyTorch-native implementation can
|
||||
# inherit this class
|
||||
@@ -31,6 +33,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
device: Union[torch.device, str], **kwargs):
|
||||
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
|
||||
device)
|
||||
refresh_all_lora_classes()
|
||||
|
||||
def _shrink_prefill(
|
||||
self,
|
||||
@@ -338,13 +341,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
if lora_a_stacked.dim() == 2:
|
||||
lora_a_stacked = lora_a_stacked.unsqueeze(0)
|
||||
if lora_b_stacked.dim() == 2:
|
||||
lora_b_stacked = lora_b_stacked.unsqueeze(0)
|
||||
|
||||
r = lora_a_stacked.size(-1)
|
||||
r = lora_b_stacked.size(-1)
|
||||
|
||||
if buffer is None:
|
||||
buffer = torch.zeros((x.size(0), r),
|
||||
@@ -352,13 +349,8 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
device=x.device)
|
||||
|
||||
indices = self.sampler_indices
|
||||
if indices.max() >= lora_a_stacked.size(0):
|
||||
indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1)
|
||||
|
||||
lora_a_reshaped = lora_a_stacked.transpose(1, 2)
|
||||
lora_b_reshaped = lora_b_stacked.transpose(1, 2)
|
||||
|
||||
bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale)
|
||||
bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True)
|
||||
bgmv_shrink(x, lora_a_stacked, buffer, indices, scale)
|
||||
bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True)
|
||||
|
||||
y = y.view_as(y_org)
|
||||
110
vllm_ascend/lora/utils.py
Normal file
110
vllm_ascend/lora/utils.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from typing import Optional
|
||||
|
||||
import vllm
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
|
||||
MergedColumnParallelLinearWithLoRA,
|
||||
MergedQKVParallelLinearWithLoRA,
|
||||
QKVParallelLinearWithLoRA,
|
||||
RowParallelLinearWithLoRA,
|
||||
VocabParallelEmbeddingWithLoRA)
|
||||
from vllm.lora.layers.utils import _not_fully_sharded_can_replace
|
||||
|
||||
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
||||
AscendMergedColumnParallelLinear,
|
||||
AscendQKVParallelLinear,
|
||||
AscendRowParallelLinear)
|
||||
from vllm_ascend.ops.vocab_parallel_embedding import \
|
||||
AscendVocabParallelEmbedding
|
||||
|
||||
|
||||
class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendColumnParallelLinear
|
||||
|
||||
|
||||
class AscendMergedColumnParallelLinearWithLoRA(
|
||||
MergedColumnParallelLinearWithLoRA):
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendMergedColumnParallelLinear
|
||||
|
||||
|
||||
class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendRowParallelLinear
|
||||
|
||||
|
||||
class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendVocabParallelEmbedding
|
||||
|
||||
|
||||
class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA):
|
||||
|
||||
@classmethod
|
||||
@_not_fully_sharded_can_replace
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
lora_config: LoRAConfig, packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig]) -> bool:
|
||||
return type(source_layer) is AscendQKVParallelLinear and len(
|
||||
packed_modules_list) == 1
|
||||
|
||||
|
||||
class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
|
||||
|
||||
@classmethod
|
||||
@_not_fully_sharded_can_replace
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return (type(source_layer) is AscendQKVParallelLinear
|
||||
and len(packed_modules_list) == 3)
|
||||
|
||||
|
||||
def refresh_all_lora_classes():
|
||||
vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes.add(
|
||||
AscendMergedColumnParallelLinearWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes.add(
|
||||
AscendMergedQKVParallelLinearWithLoRA)
|
||||
@@ -23,7 +23,7 @@ from torch.library import Library
|
||||
# Do NOT perform any real computation or allocate device memory.
|
||||
#
|
||||
# 2. Register your meta function using `register_meta_if_necessary`, providing:
|
||||
# - The namespace (usually "_C" for custom ops)
|
||||
# - The namespace (usually "_C_ascend" for custom ops)
|
||||
# - The operator name (as registered in C++)
|
||||
# - The Python meta function
|
||||
# - (Optional) The overload name, if your op has overloads
|
||||
@@ -39,7 +39,7 @@ from torch.library import Library
|
||||
#
|
||||
# For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors
|
||||
|
||||
lib = Library("_C", "IMPL")
|
||||
lib = Library("_C_ascend", "IMPL")
|
||||
|
||||
|
||||
def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
|
||||
@@ -97,8 +97,9 @@ def sgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor,
|
||||
return y_out
|
||||
|
||||
|
||||
register_meta_if_necessary("_C", "rotary_embedding", rotary_embedding_meta)
|
||||
register_meta_if_necessary("_C", "get_masked_input_and_mask",
|
||||
register_meta_if_necessary("_C_ascend", "rotary_embedding",
|
||||
rotary_embedding_meta)
|
||||
register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask",
|
||||
get_masked_input_and_mask_meta)
|
||||
register_meta_if_necessary("_C", "bgmv_expand", bgmv_expand_meta)
|
||||
register_meta_if_necessary("_C", "sgmv_expand", sgmv_expand_meta)
|
||||
register_meta_if_necessary("_C_ascend", "bgmv_expand", bgmv_expand_meta)
|
||||
register_meta_if_necessary("_C_ascend", "sgmv_expand", sgmv_expand_meta)
|
||||
|
||||
@@ -4,23 +4,20 @@ import vllm_ascend.envs as envs_ascend
|
||||
|
||||
|
||||
def register_model():
|
||||
from .deepseek_dbo import CustomDeepseekDBOForCausalLM # noqa: F401
|
||||
from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401
|
||||
from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401
|
||||
from .deepseek_v3 import CustomDeepseekV3ForCausalLM # noqa: F401
|
||||
from .qwen2_5_vl import \
|
||||
AscendQwen2_5_VLForConditionalGeneration # noqa: F401
|
||||
from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401
|
||||
from .qwen3 import CustomQwen3ForCausalLM # noqa: F401
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepSeekMTPModel",
|
||||
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen2VLForConditionalGeneration",
|
||||
"vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3VLMoeForConditionalGeneration",
|
||||
"vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen3VLMoeForConditionalGeneration"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3VLForConditionalGeneration",
|
||||
"vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen3VLForConditionalGeneration"
|
||||
)
|
||||
|
||||
if envs_ascend.USE_OPTIMIZED_MODEL:
|
||||
ModelRegistry.register_model(
|
||||
"Qwen2_5_VLForConditionalGeneration",
|
||||
@@ -32,30 +29,32 @@ def register_model():
|
||||
"vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding"
|
||||
)
|
||||
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_DBO:
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV2ForCausalLM",
|
||||
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV2ForCausalLM",
|
||||
"vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV3ForCausalLM",
|
||||
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
|
||||
else:
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV2ForCausalLM",
|
||||
"vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV3ForCausalLM",
|
||||
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV3ForCausalLM",
|
||||
"vllm_ascend.models.deepseek_v3:CustomDeepseekV3ForCausalLM")
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV32ForCausalLM",
|
||||
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepSeekMTPModel",
|
||||
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3MoeForCausalLM",
|
||||
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
|
||||
|
||||
# There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
|
||||
# to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
|
||||
ModelRegistry.register_model(
|
||||
"PanguProMoEForCausalLM",
|
||||
"vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")
|
||||
"vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
|
||||
)
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3NextForCausalLM",
|
||||
"vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,22 +23,20 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
||||
get_current_vllm_config)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.deepseek_mtp import (
|
||||
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
|
||||
SharedHead)
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .deepseek_v2 import CustomDeepseekV2DecoderLayer
|
||||
|
||||
|
||||
class CustomDeepSeekShareHead(SharedHead):
|
||||
|
||||
@@ -65,6 +63,7 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@@ -75,10 +74,8 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "shared_head"))
|
||||
self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix,
|
||||
model_config,
|
||||
cache_config,
|
||||
quant_config)
|
||||
self.mtp_block = DeepseekV2DecoderLayer(vllm_config=vllm_config,
|
||||
prefix=prefix)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -103,8 +100,6 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
||||
|
||||
hidden_states, residual = self.mtp_block(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=None)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
@@ -171,7 +166,7 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
sampling_metadata=None, # type: ignore
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
@@ -183,14 +178,6 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
|
||||
|
||||
|
||||
class CustomDeepSeekMTP(DeepSeekMTP):
|
||||
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
|
||||
# NOTE 2.The description file generated by the current msmodelslim tool does not have
|
||||
# MTP layer info. Please manually add it and set the value to FLOAT.
|
||||
packed_modules_mapping = {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
@@ -199,8 +186,6 @@ class CustomDeepSeekMTP(DeepSeekMTP):
|
||||
prefix=maybe_prefix(
|
||||
prefix, "model"))
|
||||
|
||||
self.sampler = get_sampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -215,4 +200,4 @@ class CustomDeepSeekMTP(DeepSeekMTP):
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, previous_hidden_states,
|
||||
inputs_embeds, spec_step_idx)
|
||||
return hidden_states
|
||||
return hidden_states
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,27 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2ForCausalLM
|
||||
|
||||
|
||||
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
|
||||
pass
|
||||
|
||||
0
vllm_ascend/models/layers/__init__.py
Normal file
0
vllm_ascend/models/layers/__init__.py
Normal file
180
vllm_ascend/models/layers/mla.py
Normal file
180
vllm_ascend/models/layers/mla.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMLAModules:
|
||||
q_a_proj: Optional[torch.nn.Module]
|
||||
q_a_layernorm: Optional[torch.nn.Module]
|
||||
q_proj: Optional[torch.nn.Module]
|
||||
kv_a_proj_with_mqa: torch.nn.Module
|
||||
kv_a_layernorm: torch.nn.Module
|
||||
kv_b_proj: torch.nn.Module
|
||||
o_proj: torch.nn.Module
|
||||
rotary_emb: torch.nn.Module
|
||||
|
||||
|
||||
class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
enable_shared_expert_dp: bool,
|
||||
debug_layer_idx: int,
|
||||
first_k_dense_replace: int,
|
||||
tp_size: int,
|
||||
mla_modules: AscendMLAModules,
|
||||
num_local_heads: int,
|
||||
scaling: float,
|
||||
layers: int,
|
||||
kv_lora_rank: int,
|
||||
qk_rope_head_dim: int,
|
||||
q_lora_rank: Optional[int],
|
||||
qk_nope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = hidden_size
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
self.debug_layer_idx = debug_layer_idx
|
||||
self.first_k_dense_replace = first_k_dense_replace
|
||||
self.tp_size = tp_size
|
||||
self.num_local_heads = num_local_heads
|
||||
self.layers = layers
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.prefix = prefix
|
||||
|
||||
self.mla_attn = Attention(
|
||||
num_heads=self.num_local_heads,
|
||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
scale=scaling,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_mla=True,
|
||||
# MLA Args
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
rotary_emb=mla_modules.rotary_emb,
|
||||
q_a_proj=mla_modules.q_a_proj,
|
||||
q_a_layernorm=mla_modules.q_a_layernorm,
|
||||
q_proj=mla_modules.q_proj,
|
||||
kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=mla_modules.kv_a_layernorm,
|
||||
kv_b_proj=mla_modules.kv_b_proj,
|
||||
o_proj=mla_modules.o_proj,
|
||||
)
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
num_tokens = hidden_states.shape[0]
|
||||
need_gather_q_kv = False
|
||||
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
|
||||
# Simulate all gather to calculate output shape
|
||||
num_tokens = num_tokens * self.tp_size
|
||||
need_gather_q_kv = True
|
||||
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
|
||||
output_shape = hidden_states.shape
|
||||
else:
|
||||
rows = num_tokens // self.tp_size
|
||||
if num_tokens % self.tp_size:
|
||||
rows += 1
|
||||
output_shape = (rows, hidden_states.shape[1])
|
||||
# FIXME: This does not seem right, should make sure the buffer is fixed
|
||||
output = torch.empty(output_shape,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output,
|
||||
self.prefix)
|
||||
output = output.view(-1, output_shape[-1])
|
||||
return output
|
||||
|
||||
|
||||
def mla_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
need_gather_q_kv: bool,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
if forward_context.attn_metadata:
|
||||
attn_metadata = forward_context.attn_metadata[self.mla_attn.layer_name]
|
||||
else:
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
|
||||
self.mla_attn.impl.forward(self.mla_attn.layer_name, hidden_states,
|
||||
kv_cache, attn_metadata, need_gather_q_kv,
|
||||
output)
|
||||
return
|
||||
|
||||
|
||||
def mla_forward_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
need_gather_q_kv: bool,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="mla_forward",
|
||||
op_func=mla_forward,
|
||||
mutates_args=["output"],
|
||||
fake_impl=mla_forward_fake,
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
233
vllm_ascend/models/layers/sfa.py
Normal file
233
vllm_ascend/models/layers/sfa.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendSFAModules:
|
||||
q_a_proj: Optional[torch.nn.Module]
|
||||
q_a_layernorm: Optional[torch.nn.Module]
|
||||
q_proj: Optional[torch.nn.Module]
|
||||
kv_a_proj_with_mqa: torch.nn.Module
|
||||
kv_a_layernorm: torch.nn.Module
|
||||
kv_b_proj: torch.nn.Module
|
||||
o_proj: torch.nn.Module
|
||||
rotary_emb: torch.nn.Module
|
||||
indexer: torch.nn.Module
|
||||
|
||||
|
||||
class AscendSparseFlashAttention(MultiHeadLatentAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
enable_shared_expert_dp: bool,
|
||||
debug_layer_idx: int,
|
||||
first_k_dense_replace: int,
|
||||
tp_size: int,
|
||||
sfa_modules: AscendSFAModules,
|
||||
num_local_heads: int,
|
||||
scaling: float,
|
||||
layers: int,
|
||||
kv_lora_rank: int,
|
||||
qk_rope_head_dim: int,
|
||||
q_lora_rank: Optional[int],
|
||||
qk_nope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = hidden_size
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
self.debug_layer_idx = debug_layer_idx
|
||||
self.first_k_dense_replace = first_k_dense_replace
|
||||
self.tp_size = tp_size
|
||||
self.num_local_heads = num_local_heads
|
||||
self.layers = layers
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.prefix = prefix
|
||||
|
||||
self.sfa_attn = Attention(
|
||||
num_heads=self.num_local_heads,
|
||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
scale=scaling,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_mla=True,
|
||||
use_sfa=True,
|
||||
# SFA Args
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
rotary_emb=sfa_modules.rotary_emb,
|
||||
q_a_proj=sfa_modules.q_a_proj,
|
||||
q_a_layernorm=sfa_modules.q_a_layernorm,
|
||||
q_proj=sfa_modules.q_proj,
|
||||
kv_a_proj_with_mqa=sfa_modules.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=sfa_modules.kv_a_layernorm,
|
||||
kv_b_proj=sfa_modules.kv_b_proj,
|
||||
o_proj=sfa_modules.o_proj,
|
||||
indexer=sfa_modules.indexer)
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
num_tokens = hidden_states.shape[0]
|
||||
need_gather_q_kv = False
|
||||
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
|
||||
# Simulate all gather to calculate output shape
|
||||
num_tokens = num_tokens * self.tp_size
|
||||
need_gather_q_kv = True
|
||||
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
|
||||
output_shape = hidden_states.shape
|
||||
else:
|
||||
rows = num_tokens // self.tp_size
|
||||
if num_tokens % self.tp_size:
|
||||
rows += 1
|
||||
output_shape = (rows, hidden_states.shape[1])
|
||||
# FIXME: This does not seem right, should make sure the buffer is fixed
|
||||
output = torch.empty(output_shape,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
torch.ops.vllm.sfa_forward(hidden_states, need_gather_q_kv, output,
|
||||
self.prefix)
|
||||
output = output.view(-1, output_shape[-1])
|
||||
return output
|
||||
|
||||
|
||||
def sfa_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
need_gather_q_kv: bool,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
if forward_context.attn_metadata:
|
||||
attn_metadata = forward_context.attn_metadata[self.sfa_attn.layer_name]
|
||||
else:
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
kv_cache = self.sfa_attn.kv_cache[forward_context.virtual_engine]
|
||||
self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
|
||||
need_gather_q_kv, output)
|
||||
return
|
||||
|
||||
|
||||
class Indexer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
dim: int = 7168,
|
||||
n_heads: int = 64,
|
||||
head_dim: int = 128,
|
||||
index_topk: int = 2048,
|
||||
q_lora_rank: int = 1536,
|
||||
rope_head_dim: int = 64,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: Optional[str] = ""):
|
||||
super().__init__()
|
||||
|
||||
self.dim: int = dim # 7168
|
||||
self.n_heads: int = n_heads # 64
|
||||
self.head_dim: int = head_dim # 128
|
||||
self.rope_head_dim: int = rope_head_dim # 64
|
||||
self.index_topk: int = index_topk # 2048
|
||||
self.q_lora_rank: int = q_lora_rank # 1536
|
||||
self.wq_b = ReplicatedLinear(
|
||||
self.q_lora_rank,
|
||||
self.n_heads * self.head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wq_b",
|
||||
return_bias=False,
|
||||
)
|
||||
self.wk = ReplicatedLinear(
|
||||
self.dim,
|
||||
self.head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wk",
|
||||
return_bias=False,
|
||||
)
|
||||
self.weights_proj = ReplicatedLinear(
|
||||
self.dim,
|
||||
self.n_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.weights_proj",
|
||||
return_bias=False,
|
||||
)
|
||||
self.k_norm = nn.LayerNorm(self.head_dim)
|
||||
self.softmax_scale = self.head_dim**-0.5
|
||||
|
||||
def forward(self):
|
||||
return
|
||||
|
||||
|
||||
def sfa_forward_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
need_gather_q_kv: bool,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="sfa_forward",
|
||||
op_func=sfa_forward,
|
||||
mutates_args=["output"],
|
||||
fake_impl=sfa_forward_fake,
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -42,6 +42,8 @@ from vllm.model_executor.models.qwen2_5_vl import (
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
MIN_PAD_SIZE = 64 # min_size to pad weight
|
||||
MAX_PAD_SIZE = 128 # max_size to pad weight
|
||||
|
||||
@@ -291,6 +293,40 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
|
||||
self.hidden_size, -1)
|
||||
return out_weight
|
||||
|
||||
def pad_qkv_weight_scale_offset(self, data):
|
||||
reshaped_data = data.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head, 1)
|
||||
data1 = reshaped_data[:, :, :self.
|
||||
half_origin_hidden_size_per_attention_head, :]
|
||||
data2 = reshaped_data[:, :, self.
|
||||
half_origin_hidden_size_per_attention_head:, :]
|
||||
data1_paded = torch.nn.functional.pad(
|
||||
data1, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
|
||||
0, 0, 0))
|
||||
data2_paded = torch.nn.functional.pad(
|
||||
data2, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
|
||||
0, 0, 0))
|
||||
res = torch.cat([data1_paded, data2_paded], dim=2)
|
||||
res = res.reshape(-1, 1)
|
||||
return res
|
||||
|
||||
def pad_qkv_deq_scale_quant_bias(self, data):
|
||||
reshaped_data = data.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head)
|
||||
data1 = reshaped_data[:, :, :self.
|
||||
half_origin_hidden_size_per_attention_head]
|
||||
data2 = reshaped_data[:, :,
|
||||
self.half_origin_hidden_size_per_attention_head:]
|
||||
|
||||
data1_paded = torch.nn.functional.pad(
|
||||
data1, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
data2_paded = torch.nn.functional.pad(
|
||||
data2, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
|
||||
res = torch.cat([data1_paded, data2_paded], dim=2)
|
||||
res = res.reshape(-1)
|
||||
return res
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [
|
||||
@@ -318,11 +354,23 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
if ("attn.proj.weight" in name) and self.enable_pad:
|
||||
if ("attn.proj.weight_scale" in name or
|
||||
"attn.proj.weight_offset" in name) and self.enable_pad:
|
||||
continue
|
||||
elif ("attn.proj.deq_scale" in name
|
||||
or "attn.proj.quant_bias" in name) and self.enable_pad:
|
||||
continue
|
||||
elif ("attn.qkv.weight_scale" in name
|
||||
or "attn.qkv.weight_offset" in name) and self.enable_pad:
|
||||
param.data = self.pad_qkv_weight_scale_offset(param.data)
|
||||
elif ("attn.qkv.deq_scale" in name
|
||||
or "attn.qkv.quant_bias" in name) and self.enable_pad:
|
||||
param.data = self.pad_qkv_deq_scale_quant_bias(param.data)
|
||||
elif ("attn.proj.weight" in name) and self.enable_pad:
|
||||
param.data = self.pad_proj_weight(param.data)
|
||||
if ("attn.qkv.weight" in name) and self.enable_pad:
|
||||
elif ("attn.qkv.weight" in name) and self.enable_pad:
|
||||
param.data = self.pad_qkv_weight(param.data)
|
||||
if ("attn.qkv.bias" in name) and self.enable_pad:
|
||||
elif ("attn.qkv.bias" in name) and self.enable_pad:
|
||||
param.data = self.pad_qkv_bias(param.data)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
@@ -450,12 +498,20 @@ class AscendQwen2_5_VLForConditionalGeneration(
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.visual = AscendQwen2_5_VisionTransformer(
|
||||
vision_config=config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=self._maybe_ignore_quant_config(quant_config),
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
if vllm_version_is("0.10.2"):
|
||||
self.visual = AscendQwen2_5_VisionTransformer(
|
||||
vision_config=config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=self._maybe_ignore_quant_config(quant_config),
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
else:
|
||||
self.visual = AscendQwen2_5_VisionTransformer(
|
||||
vision_config=config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
|
||||
def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/qwen2_5_vl.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-ascend project.
|
||||
@@ -27,10 +26,19 @@ import torch_npu
|
||||
from einops import rearrange
|
||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
|
||||
|
||||
try:
|
||||
from transformers.models.qwen3_vl.configuration_qwen3_vl import \
|
||||
Qwen3VLConfig
|
||||
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import \
|
||||
Qwen3VLMoeConfig
|
||||
except ImportError:
|
||||
pass
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.model_executor.layers.activation import get_act_and_mul_fn
|
||||
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
|
||||
get_act_and_mul_fn)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.models.qwen2_5_vl import (
|
||||
@@ -38,10 +46,29 @@ from vllm.model_executor.models.qwen2_5_vl import (
|
||||
Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
|
||||
Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
|
||||
Qwen2_5_VLProcessingInfo)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
|
||||
try:
|
||||
from vllm.model_executor.models.qwen3_vl import (
|
||||
Qwen3_VisionBlock, Qwen3_VisionPatchEmbed, Qwen3_VisionTransformer,
|
||||
Qwen3VLDummyInputsBuilder, Qwen3VLForConditionalGeneration,
|
||||
Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo)
|
||||
from vllm.model_executor.models.qwen3_vl_moe import (
|
||||
Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeProcessingInfo)
|
||||
except ImportError:
|
||||
Qwen3_VisionBlock = object
|
||||
Qwen3_VisionPatchEmbed = object
|
||||
Qwen3_VisionTransformer = object
|
||||
Qwen3VLDummyInputsBuilder = object
|
||||
Qwen3VLForConditionalGeneration = object
|
||||
Qwen3VLMultiModalProcessor = object
|
||||
Qwen3VLProcessingInfo = object
|
||||
Qwen3VLMoeForConditionalGeneration = object
|
||||
Qwen3VLMoeProcessingInfo = object
|
||||
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention):
|
||||
@@ -112,16 +139,14 @@ class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention):
|
||||
|
||||
class AscendQwen2_5_VisionBlock_Without_Padding(Qwen2_5_VisionBlock):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_hidden_dim: int,
|
||||
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
|
||||
norm_layer: Optional[Callable[[int], nn.Module]] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_hidden_dim: int,
|
||||
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
|
||||
norm_layer: Optional[Callable[[int], nn.Module]] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
|
||||
quant_config, prefix)
|
||||
self.attn = AscendQwen2_5_VisionAttention_Without_Padding(
|
||||
@@ -321,6 +346,133 @@ class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen3_VisionPatchEmbed(Qwen3_VisionPatchEmbed):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.matmul(
|
||||
self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
|
||||
x = x + self.proj.bias
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen3_VisionBlock(Qwen3_VisionBlock):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_hidden_dim: int,
|
||||
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
|
||||
norm_layer: Optional[Callable[[int], nn.Module]] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
|
||||
quant_config, prefix, use_data_parallel)
|
||||
self.attn = AscendQwen2_5_VisionAttention_Without_Padding(
|
||||
embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
projection_size=dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
|
||||
cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.attn(
|
||||
self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)
|
||||
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen3_VisionTransformer(Qwen3_VisionTransformer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config,
|
||||
norm_eps: float = 1e-6,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__(vision_config, norm_eps, quant_config, prefix,
|
||||
use_data_parallel)
|
||||
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
|
||||
self.patch_embed = AscendQwen3_VisionPatchEmbed(
|
||||
patch_size=self.patch_size,
|
||||
temporal_patch_size=self.temporal_patch_size,
|
||||
in_channels=vision_config.in_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
)
|
||||
self.blocks = nn.ModuleList([
|
||||
AscendQwen3_VisionBlock(
|
||||
dim=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_hidden_dim=vision_config.intermediate_size,
|
||||
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}")
|
||||
for layer_idx in range(vision_config.depth)
|
||||
])
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
self.hidden_size, self.num_heads)
|
||||
|
||||
def cal_cos_sin(self, rotary_pos_emb):
|
||||
cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2]
|
||||
sin = rotary_pos_emb.sin()
|
||||
cos_new = torch.cat((cos, cos), dim=-1)
|
||||
sin_new = torch.cat((sin, sin), dim=-1)
|
||||
cos_new = cos_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
sin_new = sin_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
return cos_new, sin_new
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = x.to(device=self.device, dtype=self.dtype)
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
|
||||
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
grid_thw_tensor = torch.tensor(grid_thw,
|
||||
device=self.device,
|
||||
dtype=torch.int32)
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
|
||||
grid_thw_tensor[:, 0]).cpu().to(torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
|
||||
hidden_states = hidden_states.unsqueeze(1)
|
||||
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
|
||||
|
||||
cos, sin = self.cal_cos_sin(rotary_pos_emb)
|
||||
|
||||
deepstack_feature_lists = []
|
||||
for layer_num, blk in enumerate(self.blocks):
|
||||
hidden_states = blk(hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
cos=cos,
|
||||
sin=sin)
|
||||
if layer_num in self.deepstack_visual_indexes:
|
||||
deepstack_merger_idx = self.deepstack_visual_indexes.index(
|
||||
layer_num)
|
||||
deepstack_feature = self.deepstack_merger_list[
|
||||
deepstack_merger_idx](hidden_states)
|
||||
deepstack_feature_lists.append(deepstack_feature)
|
||||
hidden_states = self.merger(hidden_states)
|
||||
hidden_states = torch.cat(
|
||||
[hidden_states] + deepstack_feature_lists,
|
||||
dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
|
||||
return hidden_states
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Qwen2_5_VLMultiModalProcessor,
|
||||
info=Qwen2_5_VLProcessingInfo,
|
||||
@@ -332,12 +484,20 @@ class AscendQwen2_5_VLForConditionalGeneration_Without_Padding(
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.visual = AscendQwen2_5_VisionTransformer_Without_Padding(
|
||||
vision_config=config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=self._maybe_ignore_quant_config(quant_config),
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
if vllm_version_is("0.10.2"):
|
||||
self.visual = AscendQwen2_5_VisionTransformer_Without_Padding(
|
||||
vision_config=config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=self._maybe_ignore_quant_config(quant_config),
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
else:
|
||||
self.visual = AscendQwen2_5_VisionTransformer_Without_Padding(
|
||||
vision_config=config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
|
||||
def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
|
||||
|
||||
@@ -371,3 +531,101 @@ class AscendQwen2_5_VLForConditionalGeneration_Without_Padding(
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||
return video_embeds.split(sizes.tolist())
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
|
||||
info=Qwen3VLProcessingInfo,
|
||||
dummy_inputs=Qwen3VLDummyInputsBuilder)
|
||||
class AscendQwen3VLForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
supports_encoder_tp_data = True
|
||||
|
||||
# To ensure correct weight loading and mapping.
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"model.visual.": "visual.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
"model.language_model.": "language_model.model.",
|
||||
})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
config: Qwen3VLConfig = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
if vllm_version_is("0.10.2"):
|
||||
self.visual = AscendQwen3_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=self._maybe_ignore_quant_config(quant_config),
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
use_data_parallel=self.use_data_parallel)
|
||||
else:
|
||||
self.visual = AscendQwen3_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
use_data_parallel=self.use_data_parallel)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
|
||||
info=Qwen3VLMoeProcessingInfo,
|
||||
dummy_inputs=Qwen3VLDummyInputsBuilder)
|
||||
class AscendQwen3VLMoeForConditionalGeneration(
|
||||
Qwen3VLMoeForConditionalGeneration):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
supports_encoder_tp_data = True
|
||||
|
||||
# To ensure correct weight loading and mapping.
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"model.visual.": "visual.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
"model.language_model.": "language_model.model.",
|
||||
})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
if vllm_version_is("0.10.2"):
|
||||
self.visual = AscendQwen3_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=self._maybe_ignore_quant_config(quant_config),
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
)
|
||||
else:
|
||||
self.visual = AscendQwen3_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
)
|
||||
|
||||
@@ -40,6 +40,8 @@ from vllm.model_executor.models.qwen2_vl import (
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
MIN_PAD_SIZE = 64 # min_size to pad weight
|
||||
MAX_PAD_SIZE = 128 # max_size to pad weight
|
||||
|
||||
@@ -343,10 +345,18 @@ class AscendQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
self.visual = AscendQwen2VisionTransformer(
|
||||
self.config.vision_config,
|
||||
norm_eps=getattr(self.config, "rms_norm_eps", 1e-6),
|
||||
quant_config=self._maybe_ignore_quant_config(
|
||||
vllm_config.quant_config),
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
if vllm_version_is("0.10.2"):
|
||||
self.visual = AscendQwen2VisionTransformer(
|
||||
self.config.vision_config,
|
||||
norm_eps=getattr(self.config, "rms_norm_eps", 1e-6),
|
||||
quant_config=self._maybe_ignore_quant_config(
|
||||
vllm_config.quant_config),
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
else:
|
||||
self.visual = AscendQwen2VisionTransformer(
|
||||
self.config.vision_config,
|
||||
norm_eps=getattr(self.config, "rms_norm_eps", 1e-6),
|
||||
quant_config=vllm_config.quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
|
||||
@@ -1,156 +0,0 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Qwen3Config
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
||||
from vllm.model_executor.models.qwen2 import Qwen2Model
|
||||
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader,
|
||||
PPMissingLayer, maybe_prefix)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
|
||||
|
||||
|
||||
class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
if quant_config is None:
|
||||
return
|
||||
|
||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
|
||||
assert isinstance(quant_config, AscendQuantConfig), \
|
||||
"Expected quant_config to be an instance of AscendQuantConfig"
|
||||
|
||||
if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
|
||||
AscendW8A8LinearMethod):
|
||||
self.input_layernorm = AddRMSNormW8A8Quant(
|
||||
config.hidden_size,
|
||||
layer=self.self_attn.qkv_proj,
|
||||
eps=config.rms_norm_eps)
|
||||
if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
|
||||
AscendW8A8LinearMethod):
|
||||
self.post_attention_layernorm = AddRMSNormW8A8Quant(
|
||||
config.hidden_size,
|
||||
layer=self.mlp.gate_up_proj,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
|
||||
ALL_DECODER_LAYER_TYPES = {
|
||||
"attention": CustomQwen3DecoderLayer,
|
||||
}
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"input_ids": 0,
|
||||
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
|
||||
# otherwise (seq_len, ).
|
||||
"positions": -1,
|
||||
"intermediate_tensors": 0,
|
||||
"inputs_embeds": 0,
|
||||
})
|
||||
class CustomQwen3Model(Qwen2Model):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
decoder_layer_type=CustomQwen3DecoderLayer)
|
||||
|
||||
|
||||
class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
# add `CustomQwen3Model` to init self.model
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.model = CustomQwen3Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
@@ -17,14 +17,14 @@
|
||||
# Adapted from vllm/model_executor/models/qwen3_moe.py
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
||||
get_tp_group)
|
||||
from vllm.forward_context import get_forward_context
|
||||
@@ -45,11 +45,8 @@ from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
|
||||
from vllm.model_executor.models.utils import (
|
||||
PPMissingLayer, extract_layer_index,
|
||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
|
||||
init_metadata_for_sp)
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
|
||||
@@ -101,7 +98,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
||||
self,
|
||||
hidden_states,
|
||||
attn_metadata=None,
|
||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
||||
):
|
||||
if attn_metadata is None:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
@@ -120,7 +116,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
||||
top_k=self.top_k,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
shared_experts=None,
|
||||
_metadata_for_padding=_metadata_for_padding,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
@@ -175,9 +170,14 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = Qwen3MoeSparseMoeBlock(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
if vllm_version_is("0.10.2"):
|
||||
self.mlp = Qwen3MoeSparseMoeBlock(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
@@ -189,60 +189,6 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
self.enable_sequence_parallelism = (
|
||||
vllm_config.compilation_config.pass_config.
|
||||
enable_sequence_parallelism if vllm_config is not None else False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# To prevent precision issues during the decoder phase when only prefilling enables SP
|
||||
if not self.enable_sequence_parallelism:
|
||||
self.self_attn.o_proj.reduce_results = True
|
||||
else:
|
||||
self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
|
||||
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
residual = _metadata_for_padding.padding_slice(residual)
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
|
||||
hidden_states)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
|
||||
hidden_states)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
if not self.use_aclgraph:
|
||||
hidden_states = self.mlp(
|
||||
hidden_states, _metadata_for_padding=_metadata_for_padding)
|
||||
else:
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class CustomQwen3MoeModel(Qwen3MoeModel):
|
||||
@@ -254,11 +200,8 @@ class CustomQwen3MoeModel(Qwen3MoeModel):
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
self.num_redundant_experts = parallel_config.num_redundant_experts
|
||||
else:
|
||||
eplb_config = parallel_config.eplb_config
|
||||
self.num_redundant_experts = eplb_config.num_redundant_experts
|
||||
eplb_config = parallel_config.eplb_config
|
||||
self.num_redundant_experts = eplb_config.num_redundant_experts
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.config = config
|
||||
@@ -281,60 +224,8 @@ class CustomQwen3MoeModel(Qwen3MoeModel):
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
_metadata_for_padding=_metadata_for_padding)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
|
||||
hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
@@ -357,7 +248,6 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
|
||||
# Set MoE hyperparameters
|
||||
self.expert_weights: list[torch.Tensor] = []
|
||||
|
||||
@@ -378,16 +268,3 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
||||
self.num_moe_layers = len(self.moe_layers)
|
||||
self.num_expert_groups = 1
|
||||
self.num_shared_experts = 0
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
_metadata_for_padding = init_metadata_for_sp(
|
||||
input_ids, self.enable_sequence_parallelism)
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds, _metadata_for_padding)
|
||||
return hidden_states
|
||||
|
||||
676
vllm_ascend/models/qwen3_next.py
Normal file
676
vllm_ascend/models/qwen3_next.py
Normal file
@@ -0,0 +1,676 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# mypy: ignore-errors
|
||||
"""Inference-only Qwen3Next model."""
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from vllm import envs
|
||||
from vllm.attention import AttentionBackend, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
|
||||
VllmConfig, get_current_vllm_config)
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fla.ops import RMSNormGated
|
||||
from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule
|
||||
from vllm.model_executor.layers.fla.ops.fused_recurrent import \
|
||||
fused_recurrent_gated_delta_rule
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.model_executor.layers.layernorm import \
|
||||
GemmaRMSNorm as Qwen3NextRMSNorm
|
||||
# yapf: enable
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import \
|
||||
mamba_v2_sharded_weight_loader
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, sharded_weight_loader)
|
||||
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
|
||||
from vllm.model_executor.models.utils import (
|
||||
PPMissingLayer, extract_layer_index, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.transformers_utils.configs import Qwen3NextConfig
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
||||
|
||||
from vllm.model_executor.models.qwen3_next import ( # isort: skip
|
||||
Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM,
|
||||
Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock,
|
||||
fused_gdn_gating)
|
||||
|
||||
|
||||
class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
|
||||
|
||||
@property
|
||||
def mamba_type(self) -> str:
|
||||
return "linear_attention"
|
||||
|
||||
def get_attn_backend(self) -> type["AttentionBackend"]:
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend
|
||||
return GDNAttentionBackend
|
||||
|
||||
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
|
||||
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
|
||||
self.model_config.dtype, self.cache_config.mamba_cache_dtype)
|
||||
|
||||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
return MambaStateShapeCalculator.gated_delta_net_state_shape(
|
||||
self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim,
|
||||
self.head_v_dim, self.conv_kernel_size, self.num_spec)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3NextConfig,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
speculative_config: Optional[SpeculativeConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_v_heads = config.linear_num_value_heads
|
||||
self.num_k_heads = config.linear_num_key_heads
|
||||
self.head_k_dim = config.linear_key_head_dim
|
||||
self.head_v_dim = config.linear_value_head_dim
|
||||
self.key_dim = self.head_k_dim * self.num_k_heads
|
||||
self.value_dim = self.head_v_dim * self.num_v_heads
|
||||
|
||||
self.conv_kernel_size = config.linear_conv_kernel_dim
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
self.activation = config.hidden_act
|
||||
self.act = ACT2FN[config.hidden_act]
|
||||
self.layer_norm_epsilon = config.rms_norm_eps
|
||||
self.prefix = prefix
|
||||
|
||||
self.config = config
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.quant_config = quant_config
|
||||
self.speculative_config = speculative_config
|
||||
self.num_spec = (self.speculative_config.num_speculative_tokens
|
||||
if self.speculative_config else 0)
|
||||
|
||||
# QKV
|
||||
self.conv_dim = self.key_dim * 2 + self.value_dim
|
||||
self.conv1d = ColumnParallelLinear(
|
||||
input_size=self.conv_kernel_size,
|
||||
output_size=self.conv_dim,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||
|
||||
# projection of the input hidden states
|
||||
self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
|
||||
self.projection_size_ba = self.num_v_heads * 2
|
||||
self.in_proj = MergedColumnParallelLinear(
|
||||
input_size=self.hidden_size,
|
||||
output_sizes=[self.projection_size_qkvz, self.projection_size_ba],
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
|
||||
query_key_settings = (self.key_dim, 0, False)
|
||||
value_settings = (self.value_dim, 0, False)
|
||||
|
||||
delattr(self.conv1d.weight, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.conv1d.weight, {
|
||||
"weight_loader":
|
||||
mamba_v2_sharded_weight_loader([
|
||||
query_key_settings,
|
||||
query_key_settings,
|
||||
value_settings,
|
||||
], self.tp_size, self.tp_rank)
|
||||
})
|
||||
|
||||
# selective projection used to make dt, B and C input dependent
|
||||
|
||||
# time step projection (discretization)
|
||||
# instantiate once and copy inv_dt in init_weights of PretrainedModel
|
||||
self.dt_bias = nn.Parameter(
|
||||
torch.ones(self.num_v_heads // self.tp_size), )
|
||||
self.A_log = nn.Parameter(
|
||||
torch.empty(
|
||||
divide(self.num_v_heads, self.tp_size),
|
||||
dtype=torch.float32,
|
||||
))
|
||||
|
||||
set_weight_attrs(self.A_log,
|
||||
{"weight_loader": sharded_weight_loader(0)})
|
||||
set_weight_attrs(self.dt_bias,
|
||||
{"weight_loader": sharded_weight_loader(0)})
|
||||
|
||||
self.norm = RMSNormGated(
|
||||
self.head_v_dim,
|
||||
eps=self.layer_norm_epsilon,
|
||||
norm_before_gate=True,
|
||||
device="npu",
|
||||
)
|
||||
|
||||
self.out_proj = RowParallelLinear(self.value_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj")
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
|
||||
if attn_metadata is None:
|
||||
# V1 profile run
|
||||
return
|
||||
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, GDNAttentionMetadata)
|
||||
has_initial_state = attn_metadata.has_initial_state
|
||||
spec_query_start_loc = attn_metadata.spec_query_start_loc
|
||||
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
|
||||
spec_sequence_masks = attn_metadata.spec_sequence_masks
|
||||
spec_token_masks = attn_metadata.spec_token_masks
|
||||
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
|
||||
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
|
||||
num_actual_tokens = (attn_metadata.num_prefill_tokens +
|
||||
attn_metadata.num_decode_tokens +
|
||||
attn_metadata.num_spec_decode_tokens)
|
||||
num_accepted_tokens = attn_metadata.num_accepted_tokens
|
||||
|
||||
# 1. Set up dimensions for reshapes later
|
||||
projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
|
||||
if spec_token_masks is not None:
|
||||
spec_token_masks = spec_token_masks[:num_actual_tokens]
|
||||
projected_states_qkvz, projected_states_ba = torch.split(
|
||||
projected_states,
|
||||
[
|
||||
self.projection_size_qkvz // self.tp_size,
|
||||
self.projection_size_ba // self.tp_size
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
query, key, value, z, b, a = self.fix_query_key_value_ordering(
|
||||
projected_states_qkvz, projected_states_ba)
|
||||
query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'),
|
||||
(query, key, value))
|
||||
mixed_qkv = torch.cat((query, key, value), dim=-1)
|
||||
|
||||
# 2. Convolution sequence transformation
|
||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||
self.conv1d.weight.size(2))
|
||||
|
||||
if spec_sequence_masks is not None:
|
||||
if (attn_metadata.num_prefills == 0
|
||||
and attn_metadata.num_decodes == 0):
|
||||
mixed_qkv_spec = mixed_qkv
|
||||
mixed_qkv_non_spec = None
|
||||
else:
|
||||
mixed_qkv_spec = mixed_qkv[spec_token_masks]
|
||||
mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
|
||||
else:
|
||||
mixed_qkv_spec = None
|
||||
mixed_qkv_non_spec = mixed_qkv
|
||||
|
||||
# 2.2: process the remaining part
|
||||
if attn_metadata.num_prefills > 0:
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
# pointed to by "mamba_cache_params.state_indices_tensor"
|
||||
mixed_qkv_non_spec = causal_conv1d_fn(
|
||||
mixed_qkv_non_spec.transpose(0, 1),
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_state,
|
||||
cache_indices=non_spec_state_indices_tensor,
|
||||
query_start_loc=non_spec_query_start_loc,
|
||||
).transpose(0, 1)
|
||||
elif attn_metadata.num_decodes > 0:
|
||||
mixed_qkv_non_spec = causal_conv1d_update(
|
||||
mixed_qkv_non_spec,
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=non_spec_state_indices_tensor[:attn_metadata
|
||||
.num_decodes],
|
||||
# validate_data=True,
|
||||
)
|
||||
else:
|
||||
mixed_qkv_non_spec = None
|
||||
|
||||
query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(
|
||||
mixed_qkv_spec)
|
||||
query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
|
||||
mixed_qkv_non_spec)
|
||||
|
||||
beta = b.sigmoid()
|
||||
g = fused_gdn_gating(self.A_log, a, self.dt_bias)
|
||||
g, beta = map(lambda x: rearrange(x, 'l d -> 1 l d'), (g, beta))
|
||||
|
||||
if spec_sequence_masks is not None:
|
||||
if (attn_metadata.num_prefills == 0
|
||||
and attn_metadata.num_decodes == 0):
|
||||
g_spec = g
|
||||
beta_spec = beta
|
||||
g_non_spec = None
|
||||
beta_non_spec = None
|
||||
else:
|
||||
g_spec = g[:, spec_token_masks]
|
||||
beta_spec = beta[:, spec_token_masks]
|
||||
g_non_spec = g[:, ~spec_token_masks]
|
||||
beta_non_spec = beta[:, ~spec_token_masks]
|
||||
else:
|
||||
g_spec = None
|
||||
beta_spec = None
|
||||
g_non_spec = g
|
||||
beta_non_spec = beta
|
||||
|
||||
# 3. Recurrent attention
|
||||
# 3.1: process the mutlti-query part
|
||||
if spec_sequence_masks is not None:
|
||||
core_attn_out_spec, last_recurrent_state = (
|
||||
fused_recurrent_gated_delta_rule(
|
||||
q=query_spec,
|
||||
k=key_spec,
|
||||
v=value_spec,
|
||||
g=g_spec,
|
||||
beta=beta_spec,
|
||||
initial_state=ssm_state,
|
||||
inplace_final_state=True,
|
||||
cu_seqlens=spec_query_start_loc[:attn_metadata.
|
||||
num_spec_decodes + 1],
|
||||
ssm_state_indices=spec_state_indices_tensor,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
))
|
||||
else:
|
||||
core_attn_out_spec, last_recurrent_state = None, None
|
||||
|
||||
# 3.2: process the remaining part
|
||||
if attn_metadata.num_prefills > 0:
|
||||
initial_state = ssm_state[
|
||||
non_spec_state_indices_tensor].contiguous()
|
||||
initial_state[~has_initial_state, ...] = 0
|
||||
|
||||
batch_size = initial_state.shape[0]
|
||||
core_attn_out = []
|
||||
last_recurrent_state = []
|
||||
|
||||
for b_idx in range(batch_size):
|
||||
start, end = non_spec_query_start_loc[
|
||||
b_idx], non_spec_query_start_loc[b_idx + 1]
|
||||
cur_q = query_non_spec[:, start:end, ...]
|
||||
cur_k = key_non_spec[:, start:end, ...]
|
||||
cur_v = value_non_spec[:, start:end, ...]
|
||||
cur_g = g_non_spec[:, start:end, ...]
|
||||
cur_b = beta_non_spec[:, start:end, ...]
|
||||
cur_state = initial_state[b_idx].unsqueeze(0)
|
||||
|
||||
(
|
||||
cur_core_attn_out_non_spec,
|
||||
cur_last_recurrent_state,
|
||||
) = chunk_gated_delta_rule(
|
||||
query=cur_q,
|
||||
key=cur_k,
|
||||
value=cur_v,
|
||||
g=cur_g,
|
||||
beta=cur_b,
|
||||
initial_state=cur_state,
|
||||
output_final_state=True,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
|
||||
core_attn_out.append(cur_core_attn_out_non_spec)
|
||||
last_recurrent_state.append(cur_last_recurrent_state)
|
||||
|
||||
tar_dtype = core_attn_out[0].dtype
|
||||
tar_device = core_attn_out[0].device
|
||||
tar_shape = list(core_attn_out[0].shape)
|
||||
tar_shape[1] = non_spec_query_start_loc[-1]
|
||||
core_attn_out_non_spec = torch.empty(tar_shape,
|
||||
dtype=tar_dtype,
|
||||
device=tar_device)
|
||||
for b_idx in range(batch_size):
|
||||
cur_core_attn_out = core_attn_out[b_idx]
|
||||
start, end = non_spec_query_start_loc[
|
||||
b_idx], non_spec_query_start_loc[b_idx + 1]
|
||||
core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out
|
||||
last_recurrent_state = torch.cat(last_recurrent_state, dim=0)
|
||||
|
||||
# Init cache
|
||||
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
|
||||
ssm_state.dtype)
|
||||
elif attn_metadata.num_decodes > 0:
|
||||
core_attn_out_non_spec, last_recurrent_state = (
|
||||
fused_recurrent_gated_delta_rule(
|
||||
q=query_non_spec,
|
||||
k=key_non_spec,
|
||||
v=value_non_spec,
|
||||
g=g_non_spec,
|
||||
beta=beta_non_spec,
|
||||
initial_state=ssm_state,
|
||||
inplace_final_state=True,
|
||||
cu_seqlens=non_spec_query_start_loc[:attn_metadata.
|
||||
num_decodes + 1],
|
||||
ssm_state_indices=non_spec_state_indices_tensor,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
))
|
||||
else:
|
||||
core_attn_out_non_spec, last_recurrent_state = None, None
|
||||
|
||||
# Merge core attention output
|
||||
if (spec_sequence_masks is not None
|
||||
and core_attn_out_non_spec is not None):
|
||||
core_attn_out = torch.empty(
|
||||
(1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
|
||||
dtype=core_attn_out_non_spec.dtype,
|
||||
device=core_attn_out_non_spec.device,
|
||||
)
|
||||
core_attn_out[:, spec_token_masks] = core_attn_out_spec
|
||||
core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
|
||||
elif spec_sequence_masks is not None:
|
||||
core_attn_out = core_attn_out_spec
|
||||
else:
|
||||
core_attn_out = core_attn_out_non_spec
|
||||
|
||||
z_shape_og = z.shape
|
||||
# reshape input data into 2D tensor
|
||||
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
core_attn_out = self.norm(core_attn_out, z)
|
||||
core_attn_out = core_attn_out.reshape(z_shape_og)
|
||||
core_attn_out = rearrange(core_attn_out, '... h d -> ... (h d)')
|
||||
|
||||
output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
|
||||
|
||||
|
||||
class CustomQwen3NextDecoderLayer(Qwen3NextDecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
layer_type: str,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
speculative_config = vllm_config.speculative_config
|
||||
|
||||
self.layer_type = layer_type
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
|
||||
if self.layer_type == "linear_attention":
|
||||
self.linear_attn = CustomQwen3NextGatedDeltaNet(
|
||||
config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
speculative_config=speculative_config,
|
||||
prefix=f'{prefix}.linear_attn')
|
||||
elif self.layer_type == "full_attention":
|
||||
self.self_attn = Qwen3NextAttention(
|
||||
config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f'{prefix}.self_attn',
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid layer_type {self.layer_type}")
|
||||
|
||||
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
|
||||
config.mlp_only_layers)
|
||||
if (self.layer_idx not in mlp_only_layers) and (
|
||||
config.num_experts > 0 and
|
||||
(self.layer_idx + 1) % config.decoder_sparse_step == 0):
|
||||
self.mlp = Qwen3NextSparseMoeBlock(vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = Qwen3NextMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Qwen3NextRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.layer_scale = getattr(config, "layer_scale", False)
|
||||
if self.layer_scale:
|
||||
self.attn_layer_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
1,
|
||||
1,
|
||||
config.hidden_size,
|
||||
dtype=config.torch_dtype,
|
||||
), )
|
||||
self.ffn_layer_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
1,
|
||||
1,
|
||||
config.hidden_size,
|
||||
dtype=config.torch_dtype,
|
||||
), )
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class CustomQwen3NextModel(Qwen3NextModel):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config: Qwen3NextConfig = vllm_config.model_config.hf_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
lora_config = vllm_config.lora_config
|
||||
eplb_config = parallel_config.eplb_config
|
||||
self.num_redundant_experts = eplb_config.num_redundant_experts
|
||||
|
||||
self.config = config
|
||||
lora_vocab = ((lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0)
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
|
||||
def get_layer(prefix: str):
|
||||
return CustomQwen3NextDecoderLayer(
|
||||
vllm_config,
|
||||
layer_type=config.layer_types[extract_layer_index(prefix)],
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
self.norm = Qwen3NextRMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
("in_proj", "in_proj_qkvz", 0),
|
||||
("in_proj", "in_proj_ba", 1),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
expert_params_mapping = self.get_expert_mapping()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if name.startswith("mtp."):
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
if "mlp.experts" in name:
|
||||
continue
|
||||
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
# name = apply_attn_prefix(name, params_dict)
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||
and name not in params_dict):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class CustomQwen3NextForCausalLM(Qwen3NextForCausalLM):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
lora_config = vllm_config.lora_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
assert not cache_config.enable_prefix_caching, \
|
||||
"Qwen3Next currently does not support prefix caching"
|
||||
assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1"
|
||||
self.quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.model = CustomQwen3NextModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
# Set MoE hyperparameters
|
||||
self.expert_weights = []
|
||||
|
||||
self.moe_layers: list[FusedMoE] = []
|
||||
example_layer = None
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
|
||||
assert isinstance(layer, Qwen3NextDecoderLayer)
|
||||
if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
|
||||
example_layer = layer.mlp
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
if example_layer is None:
|
||||
raise RuntimeError("No Qwen3Next layer found in the model.layers.")
|
||||
|
||||
self.num_moe_layers = len(self.moe_layers)
|
||||
self.num_expert_groups = 1
|
||||
self.num_shared_experts = 0
|
||||
self.num_logical_experts = example_layer.n_logical_experts
|
||||
self.num_physical_experts = example_layer.n_physical_experts
|
||||
self.num_local_physical_experts = example_layer.n_local_physical_experts
|
||||
self.num_routed_experts = example_layer.n_routed_experts
|
||||
self.num_redundant_experts = example_layer.n_redundant_experts
|
||||
@@ -20,6 +20,7 @@ import torch
|
||||
import vllm_ascend.ops.common_fused_moe # noqa
|
||||
import vllm_ascend.ops.fused_moe # noqa
|
||||
import vllm_ascend.ops.layernorm # noqa
|
||||
import vllm_ascend.ops.register_custom_ops # noqa
|
||||
import vllm_ascend.ops.vocab_parallel_embedding # noqa
|
||||
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
|
||||
from vllm_ascend.ops.rotary_embedding import (
|
||||
@@ -34,19 +35,20 @@ class dummyFusionOp:
|
||||
|
||||
|
||||
def register_dummy_fusion_op() -> None:
|
||||
torch.ops._C.rms_norm = dummyFusionOp(name="rms_norm")
|
||||
torch.ops._C.fused_add_rms_norm = dummyFusionOp(name="fused_add_rms_norm")
|
||||
torch.ops._C.static_scaled_fp8_quant = dummyFusionOp(
|
||||
torch.ops._C_ascend.rms_norm = dummyFusionOp(name="rms_norm")
|
||||
torch.ops._C_ascend.fused_add_rms_norm = dummyFusionOp(
|
||||
name="fused_add_rms_norm")
|
||||
torch.ops._C_ascend.static_scaled_fp8_quant = dummyFusionOp(
|
||||
name="static_scaled_fp8_quant")
|
||||
torch.ops._C.dynamic_scaled_fp8_quant = dummyFusionOp(
|
||||
torch.ops._C_ascend.dynamic_scaled_fp8_quant = dummyFusionOp(
|
||||
name="dynamic_scaled_fp8_quant")
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(
|
||||
torch.ops._C_ascend.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(
|
||||
name="dynamic_per_token_scaled_fp8_quant")
|
||||
torch.ops._C.rms_norm_static_fp8_quant = dummyFusionOp(
|
||||
torch.ops._C_ascend.rms_norm_static_fp8_quant = dummyFusionOp(
|
||||
name="rms_norm_static_fp8_quant")
|
||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(
|
||||
torch.ops._C_ascend.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(
|
||||
name="fused_add_rms_norm_static_fp8_quant")
|
||||
torch.ops._C.rms_norm_dynamic_per_token_quant = dummyFusionOp(
|
||||
torch.ops._C_ascend.rms_norm_dynamic_per_token_quant = dummyFusionOp(
|
||||
name="rms_norm_dynamic_per_token_quant")
|
||||
|
||||
|
||||
|
||||
@@ -35,8 +35,10 @@ class AscendSiluAndMul(SiluAndMul):
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
|
||||
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
|
||||
if is_310p():
|
||||
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
|
||||
else:
|
||||
out = torch_npu.npu_swiglu(x)
|
||||
torch.ops.vllm.maybe_wait_prefetch_done(out)
|
||||
return out
|
||||
|
||||
539
vllm_ascend/ops/casual_conv1d.py
Normal file
539
vllm_ascend/ops/casual_conv1d.py
Normal file
@@ -0,0 +1,539 @@
|
||||
# adapted from vllm/model_executor/layers/mamba/ops/casual_conv1d.py
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
|
||||
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
|
||||
# mypy: ignore-errors
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
def causal_conv1d_ref(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
initial_states: Optional[torch.Tensor] = None,
|
||||
return_final_states: bool = False,
|
||||
final_states_out: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen)
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
initial_states: (batch, dim, width - 1)
|
||||
final_states_out: (batch, dim, width - 1)
|
||||
out: (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
dtype_in = x.dtype
|
||||
x = x.to(weight.dtype)
|
||||
seqlen = x.shape[-1]
|
||||
dim, width = weight.shape
|
||||
|
||||
if initial_states is None:
|
||||
out = F.conv1d(x,
|
||||
weight.unsqueeze(1),
|
||||
bias,
|
||||
padding=width - 1,
|
||||
groups=dim)
|
||||
else:
|
||||
x = torch.cat([initial_states, x], dim=-1)
|
||||
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
|
||||
out = out[..., :seqlen]
|
||||
if return_final_states:
|
||||
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
|
||||
dtype_in) # (batch, dim, width - 1)
|
||||
if final_states_out is not None:
|
||||
final_states_out.copy_(final_states)
|
||||
else:
|
||||
final_states_out = final_states
|
||||
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
||||
return (out, None) if not return_final_states else (out, final_states_out)
|
||||
|
||||
|
||||
def causal_conv1d_fn(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
query_start_loc: Optional[torch.Tensor] = None,
|
||||
cache_indices: Optional[torch.Tensor] = None,
|
||||
has_initial_state: Optional[torch.Tensor] = None,
|
||||
conv_states: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
|
||||
sequences are concatenated from left to right for varlen
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
query_start_loc: (batch + 1) int32
|
||||
The cumulative sequence lengths of the sequences in
|
||||
the batch, used to index into sequence. prepended by 0.
|
||||
for example: query_start_loc = torch.Tensor([0,10,16,17]),
|
||||
x.shape=(dim,17)
|
||||
cache_indices: (batch) int32
|
||||
indicates the corresponding state index,
|
||||
like so: conv_state = conv_states[cache_indices[batch_id]]
|
||||
has_initial_state: (batch) bool
|
||||
indicates whether should the kernel take the current state as initial
|
||||
state for the calculations
|
||||
conv_states: (...,dim,width - 1) itype
|
||||
updated inplace if provided
|
||||
activation: either None or "silu" or "swish"
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padded
|
||||
entries that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
out: (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
bias = bias.contiguous() if bias is not None else None
|
||||
|
||||
out_ref = []
|
||||
out_ref_b = []
|
||||
seqlens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
seqlens = seqlens.tolist()
|
||||
splits = torch.split(x, seqlens, dim=-1)
|
||||
|
||||
for i in range(len(seqlens)):
|
||||
x_s = splits[i]
|
||||
if cache_indices[i] == PAD_SLOT_ID:
|
||||
continue
|
||||
out_ref_b.append(
|
||||
causal_conv1d_ref(
|
||||
x_s,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
return_final_states=True,
|
||||
final_states_out=conv_states[cache_indices[i]].unsqueeze(0),
|
||||
initial_states=conv_states[cache_indices[i]]
|
||||
if has_initial_state[i] else None))
|
||||
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
|
||||
out_ref_tensor = torch.cat(out_ref, dim=0)
|
||||
return out_ref_tensor
|
||||
|
||||
|
||||
@triton.jit()
|
||||
def _causal_conv1d_update_kernel(
|
||||
# Pointers to matrices
|
||||
x_ptr, # (batch, dim, seqlen)
|
||||
w_ptr, # (dim, width)
|
||||
bias_ptr,
|
||||
conv_state_ptr,
|
||||
cache_seqlens_ptr, # circular buffer
|
||||
conv_state_indices_ptr,
|
||||
num_accepted_tokens_ptr,
|
||||
intermediate_conv_window_ptr,
|
||||
o_ptr, # (batch, dim, seqlen)
|
||||
# Matrix dimensions
|
||||
batch: int,
|
||||
dim: tl.constexpr,
|
||||
seqlen: tl.constexpr,
|
||||
state_len: tl.constexpr,
|
||||
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
|
||||
# Strides
|
||||
stride_x_seq: tl.constexpr,
|
||||
stride_x_dim: tl.constexpr,
|
||||
stride_x_token: tl.constexpr,
|
||||
stride_w_dim: tl.constexpr,
|
||||
stride_w_width: tl.constexpr,
|
||||
stride_conv_state_seq: tl.constexpr,
|
||||
stride_conv_state_dim: tl.constexpr,
|
||||
stride_conv_state_tok: tl.constexpr,
|
||||
stride_state_indices: tl.constexpr,
|
||||
stride_inter_seq: tl.constexpr,
|
||||
stride_inter_step: tl.constexpr,
|
||||
stride_inter_dim: tl.constexpr,
|
||||
stride_inter_win: tl.constexpr,
|
||||
stride_o_seq: tl.constexpr,
|
||||
stride_o_dim: tl.constexpr,
|
||||
stride_o_token: tl.constexpr,
|
||||
# others
|
||||
pad_slot_id: tl.constexpr,
|
||||
# Meta-parameters
|
||||
HAS_BIAS: tl.constexpr,
|
||||
KERNEL_WIDTH: tl.constexpr,
|
||||
SILU_ACTIVATION: tl.constexpr,
|
||||
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||
IS_SPEC_DECODING: tl.constexpr,
|
||||
NP2_STATELEN: tl.constexpr,
|
||||
USE_PAD_SLOT: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SAVE_INTERMEDIATE: tl.constexpr,
|
||||
):
|
||||
# ruff: noqa: E501
|
||||
idx_seq = tl.program_id(0)
|
||||
if idx_seq >= batch:
|
||||
return
|
||||
|
||||
# [BLOCK_N,] elements along the feature-dimension (channel)
|
||||
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
if IS_CONTINUOUS_BATCHING:
|
||||
# mask = idx_seq < batch
|
||||
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_state_indices).to(
|
||||
tl.int64)
|
||||
else:
|
||||
conv_state_batch_coord = idx_seq
|
||||
if USE_PAD_SLOT: # noqa
|
||||
if conv_state_batch_coord == pad_slot_id:
|
||||
# not processing as this is not the actual sequence
|
||||
return
|
||||
|
||||
if IS_SPEC_DECODING:
|
||||
# The rolling of conv state:
|
||||
#
|
||||
# Before forward, the conv_state is:
|
||||
# [history1, history2, ..., historyM].
|
||||
#
|
||||
# After forward, the conv_state becomes:
|
||||
# [history2, ..., historyM, draft1, draft2, ..., draftN].
|
||||
#
|
||||
# After acceptance, it becomes:
|
||||
#
|
||||
# - accept 1 tokens: [history2, ..., historyM, draft1]
|
||||
# - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
|
||||
# - and so on.
|
||||
conv_state_token_offset = tl.load(num_accepted_tokens_ptr +
|
||||
idx_seq) - 1
|
||||
else:
|
||||
conv_state_token_offset = 0
|
||||
|
||||
# STEP 1: READ init_state data
|
||||
conv_states_base = (conv_state_ptr +
|
||||
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||
(idx_feats * stride_conv_state_dim))
|
||||
mask_w = idx_feats < dim
|
||||
|
||||
prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
|
||||
if KERNEL_WIDTH >= 2:
|
||||
conv_states_ptrs = prior_tokens # [BLOCK_N]
|
||||
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||
if KERNEL_WIDTH >= 3:
|
||||
conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N]
|
||||
col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||
if KERNEL_WIDTH >= 4:
|
||||
conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N]
|
||||
col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||
if KERNEL_WIDTH == 5:
|
||||
conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N]
|
||||
#col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||
|
||||
# STEP 2: assume state_len > seqlen
|
||||
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||
|
||||
# The conv_state updates works in a sliding window manner,
|
||||
# at each forward pass, the tokens are shift by 1, so we
|
||||
# load since idx_tokens + 1.
|
||||
conv_state_ptrs_source = (
|
||||
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
|
||||
conv_state_token_offset * stride_conv_state_tok +
|
||||
(idx_feats * stride_conv_state_dim)[None, :] +
|
||||
((idx_tokens + 1) * stride_conv_state_tok)[:, None]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
mask = ((conv_state_batch_coord < num_cache_lines)
|
||||
& ((idx_tokens + seqlen) < state_len)[:, None]
|
||||
& (idx_feats < dim)[None, :])
|
||||
conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
|
||||
|
||||
VAL = state_len - seqlen
|
||||
x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim
|
||||
) # [BLOCK_N]
|
||||
|
||||
x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
|
||||
mask_x = ((idx_tokens - VAL >= 0)[:, None]
|
||||
& (idx_tokens - VAL < seqlen)[:, None]
|
||||
& (idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
|
||||
tl.debug_barrier()
|
||||
|
||||
new_conv_state = tl.where(mask, conv_state, loaded_x)
|
||||
|
||||
conv_state_base = (conv_state_ptr +
|
||||
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
|
||||
conv_state_ptrs_target = (conv_state_base +
|
||||
(idx_tokens * stride_conv_state_tok)[:, None]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
|
||||
tl.store(conv_state_ptrs_target, new_conv_state, mask)
|
||||
|
||||
# STEP 3: init accumulator
|
||||
if HAS_BIAS:
|
||||
bias = bias_ptr + idx_feats
|
||||
mask_bias = idx_feats < dim
|
||||
acc_preload = tl.load(bias, mask=mask_bias,
|
||||
other=0.0).to(tl.float32) # [BLOCK_N]
|
||||
else:
|
||||
acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
|
||||
|
||||
# STEP 4:
|
||||
# PRE-LOAD WEIGHTS
|
||||
# first kernel column, configured for weights to handle BLOCK_N features in range
|
||||
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
|
||||
mask_w = idx_feats < dim
|
||||
if KERNEL_WIDTH >= 2:
|
||||
w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor
|
||||
w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||
w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor
|
||||
w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||
if KERNEL_WIDTH >= 3:
|
||||
w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor
|
||||
w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||
if KERNEL_WIDTH >= 4:
|
||||
w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor
|
||||
w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||
|
||||
x_base_1d = x_base # starting of chunk [BLOCK_N]
|
||||
mask_x_1d = idx_feats < dim
|
||||
|
||||
# STEP 5: compute each token
|
||||
for idx_token in tl.static_range(seqlen):
|
||||
acc = acc_preload
|
||||
|
||||
matrix_w = w_col0
|
||||
matrix_x = col0
|
||||
for j in tl.static_range(KERNEL_WIDTH):
|
||||
if KERNEL_WIDTH == 2:
|
||||
if j == 1: # KERNEL_WIDTH-1:
|
||||
matrix_w = w_col1
|
||||
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
|
||||
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
|
||||
elif KERNEL_WIDTH == 3:
|
||||
if j == 1:
|
||||
matrix_w = w_col1
|
||||
matrix_x = col1
|
||||
elif j == 2:
|
||||
matrix_w = w_col2
|
||||
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
|
||||
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
|
||||
elif KERNEL_WIDTH == 4:
|
||||
if j == 1:
|
||||
matrix_w = w_col1
|
||||
matrix_x = col1
|
||||
elif j == 2:
|
||||
matrix_w = w_col2
|
||||
matrix_x = col2
|
||||
elif j == 3:
|
||||
matrix_w = w_col3
|
||||
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
|
||||
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
|
||||
|
||||
acc += matrix_x * matrix_w # [BLOCK_N]
|
||||
|
||||
if KERNEL_WIDTH == 2:
|
||||
col0 = matrix_x
|
||||
elif KERNEL_WIDTH == 3:
|
||||
col0 = col1
|
||||
col1 = matrix_x
|
||||
elif KERNEL_WIDTH == 4:
|
||||
col0 = col1
|
||||
col1 = col2
|
||||
col2 = matrix_x
|
||||
|
||||
if SILU_ACTIVATION:
|
||||
acc = acc / (1 + tl.exp(-acc))
|
||||
# mask_1d = (idx_token < seqlen) & (
|
||||
# idx_feats < dim
|
||||
# ) # token-index # feature-index
|
||||
maskL = idx_feats < dim
|
||||
maskR = tl.full(maskL.shape, False, tl.int1)
|
||||
mask_1d = tl.where(idx_token < seqlen, maskL, maskR)
|
||||
|
||||
o_ptrs = (o_ptr + (idx_seq) * stride_o_seq +
|
||||
idx_token * stride_o_token + (idx_feats * stride_o_dim))
|
||||
|
||||
tl.store(o_ptrs, acc, mask=mask_1d)
|
||||
|
||||
if SAVE_INTERMEDIATE:
|
||||
# Save the window state after consuming this token
|
||||
# Layout: [seq(cache line), step, dim, win(K-1)]
|
||||
base_ptr = (intermediate_conv_window_ptr +
|
||||
conv_state_batch_coord * stride_inter_seq +
|
||||
idx_token * stride_inter_step +
|
||||
idx_feats * stride_inter_dim)
|
||||
if KERNEL_WIDTH >= 2:
|
||||
tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w)
|
||||
if KERNEL_WIDTH >= 3:
|
||||
tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w)
|
||||
if KERNEL_WIDTH >= 4:
|
||||
tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w)
|
||||
|
||||
|
||||
def causal_conv1d_update_npu(
|
||||
x: torch.Tensor,
|
||||
conv_state: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
activation: Union[bool, str, None] = None,
|
||||
cache_seqlens: Optional[torch.Tensor] = None,
|
||||
conv_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
intermediate_conv_window: Optional[torch.Tensor] = None,
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
metadata=None,
|
||||
validate_data=False,
|
||||
):
|
||||
"""
|
||||
x: (batch, dim) or (batch, dim, seqlen)
|
||||
[shape=2: single token prediction]
|
||||
[shape=3: single or multiple tokens prediction]
|
||||
conv_state: (..., dim, state_len), where state_len >= width - 1
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
cache_seqlens: (batch,), dtype int32.
|
||||
If not None, the conv_state is treated as a circular buffer.
|
||||
The conv_state will be updated by copying x to the conv_state
|
||||
starting at the index
|
||||
@cache_seqlens % state_len.
|
||||
conv_state_indices: (batch,), dtype int32
|
||||
If not None, the conv_state is a larger tensor along the batch dim,
|
||||
and we are selecting the batch coords specified by conv_state_indices.
|
||||
Useful for a continuous batching scenario.
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padded
|
||||
entries that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
out: (batch, dim) or (batch, dim, seqlen)
|
||||
"""
|
||||
if validate_data:
|
||||
assert cache_seqlens is None # not implemented yet - ok for vLLM
|
||||
assert pad_slot_id is not None
|
||||
assert x.stride(1) == 1
|
||||
if isinstance(activation, bool):
|
||||
activation = "silu" if activation is True else None
|
||||
elif activation is not None:
|
||||
assert activation in ["silu", "swish"]
|
||||
unsqueeze = x.dim() == 2
|
||||
if unsqueeze:
|
||||
# make it (batch, dim, seqlen) with seqlen == 1
|
||||
x = x.unsqueeze(-1)
|
||||
batch, dim, seqlen = x.shape
|
||||
_, width = weight.shape
|
||||
# conv_state: (..., dim, state_len), where state_len >= width - 1
|
||||
num_cache_lines, _, state_len = conv_state.size()
|
||||
|
||||
if validate_data:
|
||||
assert dim == weight.size(0)
|
||||
assert (
|
||||
conv_state.stride(-2) == 1
|
||||
), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
|
||||
assert state_len >= width - 1
|
||||
# when above happens, we don't shift-left to keep any records in conv_state
|
||||
assert dim == conv_state.size(1)
|
||||
if conv_state_indices is None:
|
||||
assert conv_state.size(0) >= batch
|
||||
else:
|
||||
assert (batch, ) == conv_state_indices.shape
|
||||
|
||||
assert num_cache_lines >= batch
|
||||
assert weight.stride(1) == 1 # Need this
|
||||
assert cache_seqlens is None # not needed for vLLM - circular buffer
|
||||
|
||||
# adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
|
||||
out = x
|
||||
stride_w_dim, stride_w_width = weight.stride()
|
||||
|
||||
stride_x_seq, stride_x_dim, stride_x_token = x.stride(
|
||||
) # X (batch, dim, seqlen)
|
||||
|
||||
stride_o_seq, stride_o_dim, stride_o_token = out.stride()
|
||||
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
|
||||
)
|
||||
stride_state_indices = (conv_state_indices.stride(0)
|
||||
if conv_state_indices is not None else 0)
|
||||
state_len = width - 1 + (seqlen - 1) # effective state_len needed
|
||||
np2_statelen = triton.next_power_of_2(state_len)
|
||||
|
||||
def grid(META):
|
||||
return (
|
||||
batch,
|
||||
triton.cdiv(dim, META["BLOCK_N"]),
|
||||
)
|
||||
|
||||
# prepare intermediate buffer strides if provided
|
||||
if intermediate_conv_window is not None:
|
||||
stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = (
|
||||
intermediate_conv_window.stride(0),
|
||||
intermediate_conv_window.stride(1),
|
||||
intermediate_conv_window.stride(2),
|
||||
intermediate_conv_window.stride(3),
|
||||
)
|
||||
else:
|
||||
stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0
|
||||
|
||||
_causal_conv1d_update_kernel[grid](
|
||||
# Pointers to matrices
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
conv_state,
|
||||
cache_seqlens,
|
||||
conv_state_indices,
|
||||
num_accepted_tokens,
|
||||
intermediate_conv_window
|
||||
if intermediate_conv_window is not None else x,
|
||||
out,
|
||||
# Matrix dimensions
|
||||
batch,
|
||||
dim,
|
||||
seqlen,
|
||||
state_len,
|
||||
num_cache_lines,
|
||||
# stride
|
||||
stride_x_seq,
|
||||
stride_x_dim,
|
||||
stride_x_token,
|
||||
stride_w_dim,
|
||||
stride_w_width,
|
||||
stride_istate_seq,
|
||||
stride_istate_dim,
|
||||
stride_istate_token,
|
||||
stride_state_indices,
|
||||
stride_inter_seq,
|
||||
stride_inter_step,
|
||||
stride_inter_dim,
|
||||
stride_inter_win,
|
||||
stride_o_seq,
|
||||
stride_o_dim,
|
||||
stride_o_token,
|
||||
# others
|
||||
pad_slot_id,
|
||||
# META
|
||||
HAS_BIAS=bias is not None,
|
||||
KERNEL_WIDTH=width,
|
||||
SILU_ACTIVATION=activation in ["silu", "swish"],
|
||||
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
|
||||
IS_SPEC_DECODING=num_accepted_tokens is not None,
|
||||
NP2_STATELEN=np2_statelen,
|
||||
USE_PAD_SLOT=pad_slot_id is not None,
|
||||
BLOCK_N=128,
|
||||
SAVE_INTERMEDIATE=intermediate_conv_window is not None,
|
||||
)
|
||||
if unsqueeze:
|
||||
out = out.squeeze(-1)
|
||||
return out
|
||||
@@ -14,212 +14,32 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
import os.path
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import CompilationLevel, get_current_vllm_config
|
||||
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
|
||||
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import \
|
||||
FusedMoEParallelConfig # isort: skip
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod)
|
||||
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
||||
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
|
||||
AlltoAllCommImpl,
|
||||
MC2CommImpl)
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
||||
setup_token_dispatchers
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, vllm_version_is
|
||||
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
|
||||
determine_default_log2phy_map)
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
|
||||
|
||||
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
||||
|
||||
|
||||
def fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
global_num_experts: Optional[int] = None,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
# For TorchAir graph
|
||||
is_torchair: bool = False,
|
||||
# For Cube/Vector parallel
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
# For load balance
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
) -> torch.Tensor:
|
||||
# Check constraints
|
||||
assert hidden_states.shape[1] == w1.shape[1], (
|
||||
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}")
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32, torch.float16, torch.bfloat16
|
||||
]
|
||||
if (use_int8_w8a8 or use_int4_w4a8):
|
||||
assert w1_scale is not None and w2_scale is not None, \
|
||||
"INT8 quantization requires weight scales."
|
||||
|
||||
w1_scale = w1_scale.to(torch.float32)
|
||||
down_scale = [w2_scale]
|
||||
down_output_dtype = w2_scale.dtype
|
||||
else:
|
||||
down_scale = None
|
||||
down_output_dtype = None
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
assert moe_comm_method is not None, "Missing communication context"
|
||||
|
||||
num_experts = w1.shape[0]
|
||||
|
||||
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = moe_comm_method.permute(
|
||||
hidden_states, topk_ids, topk_weights, expert_map, num_experts,
|
||||
use_int8_w8a8 or use_int4_w4a8)
|
||||
|
||||
gate_up_output = torch_npu.npu_grouped_matmul(
|
||||
x=[permuted_hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=expert_tokens,
|
||||
output_dtype=torch.int32 if use_int8_w8a8 else None,
|
||||
)[0]
|
||||
|
||||
if (use_int8_w8a8 or use_int4_w4a8):
|
||||
activated_output, activated_output_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=gate_up_output,
|
||||
weight_scale=w1_scale,
|
||||
activation_scale=dynamic_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=expert_tokens,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
activated_output_scale = [activated_output_scale]
|
||||
else:
|
||||
activated_output = torch_npu.npu_swiglu(gate_up_output)
|
||||
activated_output_scale = None
|
||||
|
||||
down_output = torch_npu.npu_grouped_matmul(
|
||||
x=[activated_output],
|
||||
weight=[w2],
|
||||
scale=down_scale,
|
||||
per_token_scale=activated_output_scale,
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=expert_tokens,
|
||||
output_dtype=down_output_dtype,
|
||||
)[0]
|
||||
|
||||
moe_comm_method.unpermute(down_output, hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def fused_experts_moge(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
||||
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
||||
topk_weights: Routing weights of shape (num_tokens, top_k).
|
||||
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
||||
top_k: Number of experts to select.
|
||||
expert_map: Expert mapping of shape (num_experts,).
|
||||
|
||||
Returns:
|
||||
hidden_states: Hidden states after routing.
|
||||
"""
|
||||
ep_size = moe_parallel_config.ep_size
|
||||
local_num_experts = global_num_experts // ep_size
|
||||
local_num_group = top_k // ep_size
|
||||
|
||||
bsz, _ = hidden_states.shape
|
||||
flatten_topk_ids = topk_ids.view(-1)
|
||||
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
||||
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
|
||||
sorted_hidden_states = hidden_states.index_select(
|
||||
0, sorted_topk_ids // local_num_group)
|
||||
|
||||
experts_id = torch.arange(0,
|
||||
local_num_experts,
|
||||
dtype=topk_ids.dtype,
|
||||
device=topk_ids.device)
|
||||
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
||||
torch.float32).sum(0)
|
||||
topk_scales = topk_weights.view(-1).index_select(
|
||||
0, sorted_topk_ids).unsqueeze(-1)
|
||||
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
||||
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||
x=[sorted_hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
if is_310p():
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
||||
torch.float16)
|
||||
else:
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
gate_up_out *= topk_scales
|
||||
|
||||
down_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
|
||||
unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
|
||||
final_hidden_states = unsorted_hidden_states.reshape(
|
||||
bsz, top_k // ep_size, -1).sum(1)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def unquantized_fused_moe_init_func(self, *args, **kwargs):
|
||||
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
|
||||
|
||||
@@ -235,67 +55,7 @@ def unquantized_fused_moe_init_func(self, *args, **kwargs):
|
||||
self.use_aclgraph = (vllm_config.compilation_config.level
|
||||
== CompilationLevel.PIECEWISE
|
||||
and not vllm_config.model_config.enforce_eager)
|
||||
|
||||
|
||||
def forward_oot_v01011(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
topk_weights, topk_ids, _ = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
if topk_ids.shape[1] < top_k or is_310p():
|
||||
assert global_num_experts is not None
|
||||
return fused_experts_moge(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
moe_parallel_config=self.moe.moe_parallel_config,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
self.transpose = True
|
||||
|
||||
|
||||
def forward_oot(
|
||||
@@ -321,7 +81,7 @@ def forward_oot(
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
topk_weights, topk_ids, _ = select_experts(
|
||||
topk_weights, topk_ids, row_idx = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
@@ -335,40 +95,35 @@ def forward_oot(
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
if topk_ids.shape[1] < top_k or is_310p():
|
||||
assert global_num_experts is not None
|
||||
return fused_experts_moge(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
moe_parallel_config=self.moe.moe_parallel_config,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
return moe_comm_method.fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
|
||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
|
||||
if self.transpose:
|
||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
|
||||
|
||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||
|
||||
self.transpose = False
|
||||
else:
|
||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data)
|
||||
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
|
||||
|
||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||
|
||||
if not is_310p():
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
@@ -378,119 +133,88 @@ def process_weights_after_loading(self, layer):
|
||||
|
||||
|
||||
class AscendFusedMoE(FusedMoE):
|
||||
moe_counter = -1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_experts,
|
||||
top_k,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
params_dtype=None,
|
||||
reduce_results=False,
|
||||
renormalize=True,
|
||||
use_grouped_topk=False,
|
||||
num_expert_group=None,
|
||||
topk_group=None,
|
||||
quant_config=None,
|
||||
tp_size=None,
|
||||
ep_size=None,
|
||||
dp_size=None,
|
||||
prefix="",
|
||||
custom_routing_function=None,
|
||||
scoring_func="softmax",
|
||||
routed_scaling_fator: float = 1.0,
|
||||
e_score_correction_bias=None,
|
||||
apply_router_weight_on_input=False,
|
||||
activation="silu",
|
||||
enable_eplb=False,
|
||||
num_redundant_experts=0,
|
||||
has_bias=False,
|
||||
):
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
super().__init__(
|
||||
num_experts,
|
||||
top_k,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
params_dtype,
|
||||
reduce_results,
|
||||
renormalize,
|
||||
use_grouped_topk,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
quant_config,
|
||||
tp_size,
|
||||
ep_size,
|
||||
dp_size,
|
||||
prefix,
|
||||
custom_routing_function,
|
||||
scoring_func,
|
||||
e_score_correction_bias,
|
||||
apply_router_weight_on_input,
|
||||
activation,
|
||||
enable_eplb,
|
||||
num_redundant_experts,
|
||||
has_bias,
|
||||
)
|
||||
else:
|
||||
super().__init__(
|
||||
num_experts,
|
||||
top_k,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
params_dtype,
|
||||
reduce_results,
|
||||
renormalize,
|
||||
use_grouped_topk,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
quant_config,
|
||||
tp_size,
|
||||
ep_size,
|
||||
dp_size,
|
||||
prefix,
|
||||
custom_routing_function,
|
||||
scoring_func,
|
||||
routed_scaling_fator,
|
||||
e_score_correction_bias,
|
||||
apply_router_weight_on_input,
|
||||
activation,
|
||||
enable_eplb,
|
||||
num_redundant_experts,
|
||||
has_bias,
|
||||
)
|
||||
|
||||
setup_token_dispatchers(self.moe_config.ep_size,
|
||||
top_k=self.top_k,
|
||||
num_experts=self.global_num_experts,
|
||||
num_local_experts=self.local_num_experts)
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
AscendFusedMoE.moe_counter += 1
|
||||
self.moe_instance_id = AscendFusedMoE.moe_counter
|
||||
self.moe_config.tp_group = get_tp_group()
|
||||
self.moe_config.dp_group = get_dp_group()
|
||||
self.moe_config.ep_group = get_ep_group()
|
||||
self.moe_config.mc2_group = get_mc2_group()
|
||||
ascend_config = get_ascend_config()
|
||||
self.dynamic_eplb = ascend_config.dynamic_eplb
|
||||
self.expert_map_path = ascend_config.expert_map_path
|
||||
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
|
||||
# static eplb initializing with expert_map_path
|
||||
if self.expert_map_path and os.path.exists(
|
||||
self.expert_map_path) and os.access(self.expert_map_path,
|
||||
os.R_OK):
|
||||
self.expert_load_balancer = ExpertLoadBalancer(
|
||||
self.expert_map_path, self.global_num_experts)
|
||||
self.local_num_experts, self.expert_map = (
|
||||
self.expert_load_balancer.get_rank_placement_map(
|
||||
self.moe_instance_id, self.ep_rank))
|
||||
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
|
||||
self.moe_instance_id, self.ep_rank).npu()
|
||||
self.global_redundant_expert_num = (
|
||||
self.expert_load_balancer.get_global_redundant_expert_num())
|
||||
else:
|
||||
# init moe.
|
||||
self.local_num_experts, self.expert_map = determine_expert_map(
|
||||
self.ep_size, self.ep_rank, self.global_num_experts)
|
||||
# dynamic eplb initializing with not expert_map_path
|
||||
if self.dynamic_eplb:
|
||||
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
|
||||
self.local_num_experts, self.expert_map = determine_default_expert_map(
|
||||
self.global_num_experts, self.ep_size, self.ep_rank,
|
||||
self.global_redundant_expert_num)
|
||||
self.log2phy = determine_default_log2phy_map(
|
||||
self.global_num_experts, self.ep_size, self.ep_rank,
|
||||
self.global_redundant_expert_num)
|
||||
local_num_experts = (torch.sum(
|
||||
self.expert_map != -1) if self.expert_map is not None else
|
||||
self.global_num_experts)
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
|
||||
|
||||
for method in {AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl}:
|
||||
setattr(
|
||||
self, method.__name__.lower(),
|
||||
method(moe_config=self.moe_config)) # type: ignore[abstract]
|
||||
setup_moe_comm_method(self.moe_config)
|
||||
|
||||
def update_expert_map(self, new_expert_map):
|
||||
self.expert_map = new_expert_map
|
||||
|
||||
def get_map(self):
|
||||
return self.expert_map
|
||||
|
||||
def get_log2phy_map(self):
|
||||
return self.logical_to_physical_map
|
||||
|
||||
def clear_moe_load(self):
|
||||
if self.moe_load is not None:
|
||||
self.moe_load.zero_()
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(
|
||||
self, final_hidden_states: torch.Tensor):
|
||||
"""NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`,
|
||||
and `alltoallcommimpl`, we do not need to all-reduce the final outputs since
|
||||
the outputs are already aggregated across tensor parallel ranks in the
|
||||
`finalize` function. In `allgathercommimpl`, we still need to all-reduce the
|
||||
outputs since each rank only has partial outputs.
|
||||
"""
|
||||
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states)
|
||||
|
||||
def forward_impl(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
assert self.quant_method is not None
|
||||
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_method_name = forward_context.moe_comm_method_name
|
||||
|
||||
# TODO: Can we refactor this logic to model_runner?
|
||||
# TODO: Adjusted logic to differentiate between A2 and A3, we check ep_size here since mc2 only support ep_size >= 16 on A3 now
|
||||
if self.moe_config.ep_size < 16:
|
||||
moe_comm_method_name = "allgathercommimpl"
|
||||
|
||||
forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
|
||||
|
||||
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
||||
hidden_states=hidden_states, router_logits=router_logits)
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
replace_allreduce=forward_context.sp_enabled)
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
@@ -514,6 +238,12 @@ class AscendFusedMoE(FusedMoE):
|
||||
logical_to_physical_map=self.logical_to_physical_map,
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
)
|
||||
if isinstance(final_hidden_states, tuple):
|
||||
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
|
||||
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load += expert_tokens if group_list_type else \
|
||||
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
|
||||
|
||||
final_hidden_states = forward_context.moe_comm_method.finalize(
|
||||
hidden_states=final_hidden_states,
|
||||
@@ -521,11 +251,118 @@ class AscendFusedMoE(FusedMoE):
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
def transpose_weight(self, loaded_weight, expert_data, shard_dim):
|
||||
# Ensure training and inference weight shapes match during RL weight updates
|
||||
if (
|
||||
loaded_weight.shape[1] != expert_data.shape[1] and \
|
||||
loaded_weight.shape[0] != expert_data.shape[0]
|
||||
):
|
||||
shard_dim = int(not shard_dim)
|
||||
loaded_weight = loaded_weight.transpose(0, 1).contiguous()
|
||||
return loaded_weight, shard_dim
|
||||
|
||||
def _load_w13(self,
|
||||
expert_data: torch.Tensor,
|
||||
shard_dim: int,
|
||||
shard_id: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
load_full: bool = False):
|
||||
# Index the loaded weight for tp sharding.
|
||||
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
|
||||
loaded_weight, shard_dim = self.transpose_weight(
|
||||
loaded_weight, expert_data, shard_dim)
|
||||
shard_size = expert_data.shape[shard_dim] // 2
|
||||
if not load_full:
|
||||
loaded_weight = loaded_weight.narrow(shard_dim,
|
||||
shard_size * tp_rank,
|
||||
shard_size)
|
||||
# Narrow parameter and load.
|
||||
# w1, gate_proj: Load into first logical weight of w13.
|
||||
if shard_id == "w1":
|
||||
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
|
||||
# w3, up_proj: Load into second logical weight of w13.
|
||||
else:
|
||||
assert shard_id == "w3"
|
||||
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _load_w2(self,
|
||||
expert_data: torch.Tensor,
|
||||
shard_dim: int,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
load_full: bool = False):
|
||||
# Index the loaded weight for tp sharding.
|
||||
# down_proj: "RowParallel" so tp sharding on input_dim
|
||||
# Narrow parameter and load.
|
||||
loaded_weight, shard_dim = self.transpose_weight(
|
||||
loaded_weight, expert_data, shard_dim)
|
||||
shard_size = expert_data.shape[shard_dim]
|
||||
if not load_full:
|
||||
loaded_weight = loaded_weight.narrow(shard_dim,
|
||||
shard_size * tp_rank,
|
||||
shard_size)
|
||||
# w2, down_proj: Load into only logical weight of w2.
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shared_experts: torch.nn.Module,
|
||||
use_overlapped: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
AscendFusedMoE.__init__(self, **kwargs)
|
||||
self._shared_experts = shared_experts
|
||||
self.use_overlapped = use_overlapped
|
||||
self.shared_expert_stream = None
|
||||
ascend_config = get_ascend_config()
|
||||
self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
|
||||
if self.multistream_overlap_shared_expert:
|
||||
self.shared_expert_stream = torch.npu.Stream()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
shared_out, fused_out = AscendFusedMoE.forward(
|
||||
self,
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
return shared_out, fused_out
|
||||
|
||||
def forward_impl(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
# Make sure the shared experts stream begins after hidden_states are ready.
|
||||
if self.multistream_overlap_shared_expert:
|
||||
self.shared_expert_stream.wait_stream( # type: ignore
|
||||
torch.npu.current_stream())
|
||||
with npu_stream_switch(self.shared_expert_stream,
|
||||
enabled=self.multistream_overlap_shared_expert):
|
||||
# Use a separate stream to run shared experts.
|
||||
shared_out = self._shared_experts(hidden_states)
|
||||
|
||||
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
|
||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||
fused_output = AscendFusedMoE.forward_impl(
|
||||
self,
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
# Make sure the default stream waits for the shared experts stream to finish.
|
||||
if self.multistream_overlap_shared_expert:
|
||||
torch.npu.current_stream().wait_stream(self.shared_expert_stream)
|
||||
return shared_out, fused_output
|
||||
|
||||
|
||||
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
|
||||
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
UnquantizedFusedMoEMethod.forward_oot = forward_oot_v01011
|
||||
else:
|
||||
UnquantizedFusedMoEMethod.forward_oot = forward_oot
|
||||
UnquantizedFusedMoEMethod.forward_oot = forward_oot
|
||||
|
||||
218
vllm_ascend/ops/fla.py
Normal file
218
vllm_ascend/ops/fla.py
Normal file
@@ -0,0 +1,218 @@
|
||||
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
||||
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
||||
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
||||
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
||||
# mypy: ignore-errors
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
from vllm.model_executor.layers.fla.ops.layernorm_guard import \
|
||||
layer_norm_fwd_kernel
|
||||
|
||||
|
||||
def _layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=None,
|
||||
out=None,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
M, N = x.shape
|
||||
if group_size is None:
|
||||
group_size = N
|
||||
assert N % group_size == 0
|
||||
ngroups = N // group_size
|
||||
assert x.stride(-1) == 1
|
||||
if z is not None:
|
||||
assert z.stride(-1) == 1
|
||||
assert z.shape == (M, N)
|
||||
assert weight.shape == (N, )
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N, )
|
||||
# allocate output
|
||||
if out is not None:
|
||||
assert out.shape == x.shape
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
assert out.stride(-1) == 1
|
||||
mean = (torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
||||
if not is_rms_norm else None)
|
||||
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
||||
if group_size > BLOCK_N:
|
||||
raise RuntimeError(
|
||||
"This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
grid = (M, ngroups)
|
||||
with torch.npu.device(x.device.index):
|
||||
layer_norm_fwd_kernel[grid](
|
||||
x,
|
||||
out,
|
||||
weight,
|
||||
bias,
|
||||
z,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
out.stride(0),
|
||||
z.stride(0) if z is not None else 0,
|
||||
M,
|
||||
group_size,
|
||||
eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
return out, mean, rstd
|
||||
|
||||
|
||||
class LayerNormFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if z is not None:
|
||||
assert z.shape == x_shape_og
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
if z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y, mean, rstd = _layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=z,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=is_rms_norm,
|
||||
)
|
||||
return y.reshape(x_shape_og)
|
||||
|
||||
|
||||
def torch_chunk_gated_delta_rule(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
g,
|
||||
beta,
|
||||
chunk_size=64,
|
||||
initial_state=None,
|
||||
output_final_state=False,
|
||||
use_qk_l2norm_in_kernel=False,
|
||||
):
|
||||
initial_dtype = query.dtype
|
||||
if use_qk_l2norm_in_kernel:
|
||||
query = F.normalize(query, p=2, dim=-1)
|
||||
key = F.normalize(key, p=2, dim=-1)
|
||||
query, key, value, beta, g = [
|
||||
x.transpose(1, 2).contiguous().to(torch.float32)
|
||||
for x in (query, key, value, beta, g)
|
||||
]
|
||||
|
||||
batch_size, sequence_length, num_heads, k_head_dim = key.shape
|
||||
v_head_dim = value.shape[-1]
|
||||
pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
|
||||
query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
|
||||
key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
|
||||
value = F.pad(value, (0, 0, 0, pad_size))
|
||||
beta = F.pad(beta, (0, pad_size))
|
||||
g = F.pad(g, (0, pad_size))
|
||||
tot_heads = num_heads + pad_size
|
||||
scale = 1 / (query.shape[-1]**0.5)
|
||||
query = query * scale
|
||||
|
||||
v_beta = value * beta.unsqueeze(-1)
|
||||
k_beta = key * beta.unsqueeze(-1)
|
||||
# reshape to chunks
|
||||
query, key, value, k_beta, v_beta = [
|
||||
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
|
||||
for x in (query, key, value, k_beta, v_beta)
|
||||
]
|
||||
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
|
||||
mask = torch.triu(torch.ones(chunk_size,
|
||||
chunk_size,
|
||||
dtype=torch.bool,
|
||||
device=query.device),
|
||||
diagonal=0)
|
||||
|
||||
# chunk decay
|
||||
g = g.cumsum(dim=-1)
|
||||
decay_mask = ((g.unsqueeze(-1) -
|
||||
g.unsqueeze(-2)).tril().exp().float()).tril()
|
||||
attn = -(
|
||||
(k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
|
||||
for i in range(1, chunk_size):
|
||||
row = attn[..., i, :i].clone()
|
||||
sub = attn[..., :i, :i].clone()
|
||||
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
||||
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
|
||||
value = attn @ v_beta
|
||||
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
|
||||
|
||||
last_recurrent_state = (torch.zeros(batch_size, sequence_length,
|
||||
k_head_dim, v_head_dim).to(value) if
|
||||
initial_state is None else initial_state.to(value))
|
||||
|
||||
core_attn_out = torch.zeros_like(value)
|
||||
mask = torch.triu(torch.ones(chunk_size,
|
||||
chunk_size,
|
||||
dtype=torch.bool,
|
||||
device=query.device),
|
||||
diagonal=1)
|
||||
|
||||
# for each chunk
|
||||
for i in range(0, tot_heads // chunk_size):
|
||||
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
|
||||
attn = (q_i @ k_i.transpose(-1, -2) *
|
||||
decay_mask[:, :, i]).masked_fill_(mask, 0)
|
||||
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
|
||||
v_new = v_i - v_prime
|
||||
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
|
||||
core_attn_out[:, :, i] = attn_inter + attn @ v_new
|
||||
last_recurrent_state = (
|
||||
last_recurrent_state * g[:, :, i, -1, None, None].exp() +
|
||||
(k_i *
|
||||
(g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(
|
||||
-1, -2) @ v_new)
|
||||
|
||||
if not output_final_state:
|
||||
last_recurrent_state = None
|
||||
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0],
|
||||
core_attn_out.shape[1], -1,
|
||||
core_attn_out.shape[-1])
|
||||
core_attn_out = core_attn_out[:, :, :num_heads]
|
||||
core_attn_out = core_attn_out.transpose(1,
|
||||
2).contiguous().to(initial_dtype)
|
||||
return core_attn_out, last_recurrent_state
|
||||
@@ -19,13 +19,9 @@ import os
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
||||
get_tp_group)
|
||||
from vllm.forward_context import get_forward_context
|
||||
@@ -39,70 +35,16 @@ from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizationConfig
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.distributed.communication_op import \
|
||||
data_parallel_reduce_scatter
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
|
||||
determine_default_log2phy_map)
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dispose_tensor,
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
|
||||
get_all_reduce_merge_state,
|
||||
get_rm_router_logits_state, is_310p)
|
||||
|
||||
|
||||
def unified_fused_experts_eager(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w1_scale_bias: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale_bias: Optional[torch.Tensor] = None,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[Any] = None,
|
||||
shared_dequant_scale: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
token_dispatcher = get_forward_context().token_dispatcher
|
||||
|
||||
results = token_dispatcher.token_dispatch(
|
||||
hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
shared_experts=shared_experts,
|
||||
shared_gate_up=shared_gate_up,
|
||||
shared_dequant_scale=shared_dequant_scale,
|
||||
mc2_mask=mc2_mask,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
with_quant=with_quant)
|
||||
|
||||
expert_output = unified_apply_mlp(
|
||||
hidden_states=results["hidden_states"],
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=results["group_list"],
|
||||
dynamic_scale=results.get("dynamic_scale"),
|
||||
group_list_type=results.get("group_list_type"),
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias,
|
||||
topk_scales=results.get("topk_scales"),
|
||||
with_quant=with_quant)
|
||||
final_hidden_states = token_dispatcher.token_combine(expert_output)
|
||||
return final_hidden_states
|
||||
get_rm_router_logits_state, is_310p,
|
||||
vllm_version_is)
|
||||
|
||||
|
||||
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
@@ -115,6 +57,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
get_ascend_config()
|
||||
self.dynamic_eplb = get_ascend_config().dynamic_eplb
|
||||
|
||||
try:
|
||||
device_group = get_mc2_group().device_group
|
||||
@@ -182,17 +125,19 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
if enable_force_load_balance and not self.use_aclgraph:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
return unified_fused_experts_eager(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
expert_map=expert_map,
|
||||
shared_experts=shared_experts,
|
||||
mc2_mask=kwargs.get(
|
||||
"mc2_mask", None),
|
||||
with_quant=False)
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
return moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
shared_experts=shared_experts,
|
||||
need_trans=True,
|
||||
dynamic_eplb=self.dynamic_eplb)
|
||||
|
||||
|
||||
class AscendFusedMoE(FusedMoE):
|
||||
@@ -290,42 +235,67 @@ class AscendFusedMoE(FusedMoE):
|
||||
self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
expert_map_path = ascend_config.expert_map_path
|
||||
if expert_map_path and os.path.exists(expert_map_path):
|
||||
# moe expert load balance
|
||||
expert_load_balancer = ExpertLoadBalancer(expert_map_path,
|
||||
self.global_num_experts)
|
||||
self.local_num_experts, self.expert_map = \
|
||||
expert_load_balancer.get_rank_placement_map(
|
||||
self.moe_instance_id,
|
||||
get_ep_group().rank_in_group)
|
||||
self.log2phy = expert_load_balancer.get_rank_log2phy_map(
|
||||
self.moe_instance_id,
|
||||
get_ep_group().rank_in_group)
|
||||
self.global_redundant_expert_num = \
|
||||
expert_load_balancer.get_global_redundant_expert_num()
|
||||
self.dynamic_eplb = ascend_config.dynamic_eplb
|
||||
self.expert_map_path = ascend_config.expert_map_path
|
||||
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
|
||||
self.global_num_experts = num_experts + self.global_redundant_expert_num
|
||||
# static eplb initializing with expert_map_path
|
||||
if self.expert_map_path and os.path.exists(
|
||||
self.expert_map_path) and os.access(self.expert_map_path,
|
||||
os.R_OK):
|
||||
self.expert_load_balancer = ExpertLoadBalancer(
|
||||
self.expert_map_path, self.global_num_experts)
|
||||
self.local_num_experts, self.expert_map = (
|
||||
self.expert_load_balancer.get_rank_placement_map(
|
||||
self.moe_instance_id, self.ep_rank))
|
||||
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
|
||||
self.moe_instance_id, self.ep_rank).npu()
|
||||
self.global_redundant_expert_num = (
|
||||
self.expert_load_balancer.get_global_redundant_expert_num())
|
||||
else:
|
||||
# Create a tensor of size num_experts filled with -1
|
||||
# init moe.
|
||||
self.local_num_experts, self.expert_map = determine_expert_map(
|
||||
self.ep_size,
|
||||
get_ep_group().rank_in_group, self.global_num_experts)
|
||||
self.ep_size, self.ep_rank, self.global_num_experts)
|
||||
# dynamic eplb initializing with not expert_map_path
|
||||
if self.dynamic_eplb:
|
||||
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
|
||||
self.local_num_experts, self.expert_map = determine_default_expert_map(
|
||||
self.global_num_experts, self.ep_size, self.ep_rank,
|
||||
self.global_redundant_expert_num)
|
||||
self.log2phy = determine_default_log2phy_map(
|
||||
self.global_num_experts, self.ep_size, self.ep_rank,
|
||||
self.global_redundant_expert_num)
|
||||
local_num_experts = (torch.sum(self.expert_map != -1)
|
||||
if self.expert_map is not None else num_experts)
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
|
||||
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
|
||||
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
||||
raise ValueError("Only softmax scoring function is supported for "
|
||||
"non-grouped topk.")
|
||||
moe = FusedMoEConfig.make(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
# TODO (bnell): this needs to be fixed for quantized types.
|
||||
in_dtype=params_dtype,
|
||||
quant_config=quant_config)
|
||||
|
||||
if vllm_version_is("0.10.2"):
|
||||
moe = FusedMoEConfig.make(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
# TODO (bnell): this needs to be fixed for quantized types.
|
||||
in_dtype=params_dtype,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
moe = FusedMoEConfig(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
in_dtype=params_dtype,
|
||||
)
|
||||
self.moe_config = moe
|
||||
# TODO: The self.moe_config.tp_size here is not correct, fixme soon
|
||||
|
||||
if quant_config is None:
|
||||
self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
|
||||
@@ -337,6 +307,11 @@ class AscendFusedMoE(FusedMoE):
|
||||
local_num_experts = torch.sum(self.expert_map != -1) \
|
||||
if self.expert_map is not None else num_experts
|
||||
|
||||
self.moe_load = None
|
||||
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
|
||||
|
||||
moe_quant_params = {
|
||||
"num_experts": local_num_experts,
|
||||
"hidden_size": hidden_size,
|
||||
@@ -354,34 +329,27 @@ class AscendFusedMoE(FusedMoE):
|
||||
# NOTE: self.tp_group is not expert_tp_group
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
self.token_dispatcher = None
|
||||
|
||||
ep_size = (get_ep_group().world_size if
|
||||
vllm_config.parallel_config.enable_expert_parallel else 1)
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
||||
setup_token_dispatchers
|
||||
setup_token_dispatchers(
|
||||
ep_size,
|
||||
top_k=self.top_k,
|
||||
num_experts=self.global_num_experts,
|
||||
num_global_redundant_experts=self.global_redundant_expert_num,
|
||||
num_local_experts=self.local_num_experts)
|
||||
self.moe_config.tp_group = get_tp_group()
|
||||
self.moe_config.dp_group = get_dp_group()
|
||||
self.moe_config.ep_group = get_ep_group()
|
||||
self.moe_config.mc2_group = get_mc2_group()
|
||||
self.moe_config.num_global_redundant_experts = self.global_redundant_expert_num
|
||||
|
||||
def naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
assert (len(x.shape) == 2)
|
||||
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
||||
buffer[start:end, :].copy_(x)
|
||||
for idx in range(self.dp_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
||||
end = cu_tokens_across_dp_cpu[idx]
|
||||
get_dp_group().broadcast(buffer[start:end, :], idx)
|
||||
return buffer
|
||||
setup_moe_comm_method(self.moe_config)
|
||||
|
||||
def update_expert_map(self, new_expert_map):
|
||||
self.expert_map = new_expert_map
|
||||
|
||||
def get_map(self):
|
||||
return self.expert_map
|
||||
|
||||
def get_log2phy_map(self):
|
||||
return self.logical_to_physical_map
|
||||
|
||||
def clear_moe_load(self):
|
||||
if self.moe_load is not None:
|
||||
self.moe_load.zero_()
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -391,8 +359,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
top_k: Optional[int] = None,
|
||||
shared_experts: Optional[Any] = None,
|
||||
gate=None,
|
||||
replace_allreduce: bool = False,
|
||||
_metadata_for_padding: Optional[MetadataForPadding] = None):
|
||||
replace_allreduce: bool = False):
|
||||
|
||||
assert self.quant_method is not None
|
||||
|
||||
@@ -401,10 +368,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
else:
|
||||
real_top_k = self.top_k
|
||||
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
|
||||
forward_context = get_forward_context()
|
||||
fused_moe_state = forward_context.fused_moe_state
|
||||
mc2_mask = forward_context.mc2_mask
|
||||
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
|
||||
quantized_x_for_share, dynamic_scale_for_share = None, None
|
||||
@@ -413,74 +377,16 @@ class AscendFusedMoE(FusedMoE):
|
||||
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
||||
shared_hidden_states = shared_experts(hidden_states)
|
||||
|
||||
mc2_mask = forward_context.mc2_mask
|
||||
|
||||
enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if enable_sp:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask
|
||||
chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
|
||||
mc2_mask = chunk_mc2_mask[tp_rank]
|
||||
if forward_context.sp_enabled:
|
||||
replace_allreduce = True
|
||||
|
||||
if (fused_moe_state not in [
|
||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||
FusedMoEState.NaiveMulticast
|
||||
] and not replace_allreduce):
|
||||
if fused_moe_state in {FusedMoEState.MC2}:
|
||||
padding_size = forward_context.padded_num_tokens
|
||||
else:
|
||||
# TODO: Determine if we can remove the padding
|
||||
padding_size = tp_size
|
||||
if num_tokens < padding_size and not self.enable_shared_expert_dp:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states, (0, 0, 0, padding_size - num_tokens))
|
||||
router_logits = nn.functional.pad(
|
||||
router_logits, (0, 0, 0, padding_size - num_tokens))
|
||||
if tp_size > 1:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
if not self.enable_shared_expert_dp:
|
||||
chunk_hidden_states = torch.tensor_split(hidden_states,
|
||||
tp_size,
|
||||
dim=0)
|
||||
chunk_router_logits = torch.tensor_split(router_logits,
|
||||
tp_size,
|
||||
dim=0)
|
||||
hidden_states = chunk_hidden_states[tp_rank]
|
||||
router_logits = chunk_router_logits[tp_rank]
|
||||
|
||||
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
|
||||
mc2_mask = chunk_mc2_mask[tp_rank]
|
||||
|
||||
if self.dp_size > 1:
|
||||
if fused_moe_state == FusedMoEState.AllGather:
|
||||
# NOTE: When in torchair graph, it has been padded in model_runner_v1
|
||||
max_tokens_across_dp = forward_context.max_tokens_across_dp
|
||||
if num_tokens < max_tokens_across_dp:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states,
|
||||
(0, 0, 0, max_tokens_across_dp - num_tokens))
|
||||
if not self.rm_router_logits:
|
||||
router_logits = nn.functional.pad(
|
||||
router_logits,
|
||||
(0, 0, 0, max_tokens_across_dp - num_tokens))
|
||||
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
||||
if self.rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states)
|
||||
else:
|
||||
router_logits = get_dp_group().all_gather(router_logits, 0)
|
||||
|
||||
elif fused_moe_state == FusedMoEState.NaiveMulticast:
|
||||
cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_dp_cpu
|
||||
hidden_states = self.naive_multicast(hidden_states,
|
||||
cu_tokens_across_dp_cpu)
|
||||
if self.rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states)
|
||||
else:
|
||||
router_logits = self.naive_multicast(
|
||||
router_logits, cu_tokens_across_dp_cpu)
|
||||
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
enable_shared_expert_dp=self.enable_shared_expert_dp,
|
||||
rm_router_logits=self.rm_router_logits,
|
||||
replace_allreduce=replace_allreduce,
|
||||
gate=gate)
|
||||
|
||||
# Matrix multiply.
|
||||
e_hidden_states = self.quant_method.apply(
|
||||
@@ -503,53 +409,27 @@ class AscendFusedMoE(FusedMoE):
|
||||
global_redundant_expert_num=self.global_redundant_expert_num,
|
||||
shared_experts=None,
|
||||
mc2_mask=mc2_mask,
|
||||
token_dispatcher=self.token_dispatcher,
|
||||
quantized_x_for_share=quantized_x_for_share,
|
||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||
)
|
||||
|
||||
group_list_type = None
|
||||
|
||||
if shared_experts:
|
||||
if isinstance(e_hidden_states, tuple):
|
||||
if isinstance(e_hidden_states,
|
||||
tuple) and len(e_hidden_states) == 2:
|
||||
e_hidden_states, shared_hidden_states = e_hidden_states
|
||||
|
||||
if (fused_moe_state not in [
|
||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||
FusedMoEState.NaiveMulticast
|
||||
] and not replace_allreduce and not self.enable_shared_expert_dp):
|
||||
if tp_size > 1:
|
||||
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
||||
self.tp_group)
|
||||
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
dispose_tensor(e_hidden_states)
|
||||
else:
|
||||
final_hidden_states = e_hidden_states
|
||||
if num_tokens < padding_size:
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
elif self.dp_size > 1 and not self.enable_shared_expert_dp:
|
||||
if fused_moe_state == FusedMoEState.NaiveMulticast:
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
||||
final_hidden_states = get_dp_group().all_reduce(
|
||||
e_hidden_states)
|
||||
final_hidden_states = final_hidden_states[start:end, :]
|
||||
dispose_tensor(e_hidden_states)
|
||||
elif fused_moe_state == FusedMoEState.AllGather:
|
||||
final_hidden_states = data_parallel_reduce_scatter(
|
||||
e_hidden_states, dim=0)
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
dispose_tensor(e_hidden_states)
|
||||
else:
|
||||
final_hidden_states = e_hidden_states
|
||||
else:
|
||||
final_hidden_states = e_hidden_states
|
||||
if isinstance(e_hidden_states, tuple) and len(e_hidden_states) == 3:
|
||||
e_hidden_states, group_list_type, expert_tokens = e_hidden_states
|
||||
|
||||
if tp_size > 1 and not self.all_reduce_merge and fused_moe_state in [
|
||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||
FusedMoEState.NaiveMulticast
|
||||
]:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
if self.dynamic_eplb and group_list_type is not None:
|
||||
self.moe_load += expert_tokens if group_list_type else \
|
||||
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
|
||||
|
||||
final_hidden_states = forward_context.moe_comm_method.finalize(
|
||||
hidden_states=e_hidden_states,
|
||||
reduce_results=(not self.all_reduce_merge))
|
||||
|
||||
if shared_experts:
|
||||
return final_hidden_states, shared_hidden_states
|
||||
|
||||
@@ -15,50 +15,124 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||
|
||||
|
||||
class AddRMSNormW8A8Quant(RMSNorm):
|
||||
# Fuse AddRmsNorm and W8A8 quantization ops together
|
||||
def _addrmsnorm_forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
layer: Optional[torch.nn.Module] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
|
||||
if layer is not None and not is_310p():
|
||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||
x,
|
||||
residual,
|
||||
self.weight,
|
||||
layer.aclnn_input_scale,
|
||||
layer.aclnn_input_offset,
|
||||
epsilon=self.variance_epsilon)
|
||||
else:
|
||||
if is_310p():
|
||||
orig_dtype = residual.dtype
|
||||
x = x + residual.to(x.dtype)
|
||||
residual = x.to(orig_dtype)
|
||||
x, _ = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||
x, residual, self.weight, self.variance_epsilon)
|
||||
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
||||
return x, residual
|
||||
|
||||
|
||||
class AscendRMSNorm(RMSNorm):
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
import torch_npu
|
||||
|
||||
if residual is not None:
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||
assert x.size(0) == residual.size(0)
|
||||
x, residual = _addrmsnorm_forward_oot(
|
||||
self, x, residual, self.next_need_quant_fusion_linear)
|
||||
return x, residual
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
return x
|
||||
|
||||
@property
|
||||
def next_need_quant_fusion_linear(self):
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
if not forward_context.addrmsnorm_quant_fusion_enabled or \
|
||||
forward_context.layer_idx == forward_context.num_hidden_layers:
|
||||
return None
|
||||
except AssertionError:
|
||||
return None
|
||||
|
||||
next_linear = None
|
||||
model_instance = forward_context.model_instance
|
||||
layer_idx = forward_context.layer_idx
|
||||
fusion_linear = forward_context.fusion_linear
|
||||
next_linear = None
|
||||
if fusion_linear == "qkv_dense":
|
||||
next_linear = model_instance.model.layers[
|
||||
layer_idx].self_attn.qkv_proj
|
||||
forward_context.fusion_linear = "gate_up_dense"
|
||||
elif fusion_linear == "gate_up_dense":
|
||||
next_linear = model_instance.model.layers[
|
||||
layer_idx].mlp.gate_up_proj
|
||||
forward_context.fusion_linear = "qkv_dense"
|
||||
# if prefetch_mlp_weight enabled, following accumulation operation
|
||||
# does not need to be repeated
|
||||
if not forward_context.prefetch_mlp_enabled:
|
||||
forward_context.layer_idx += 1
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
if next_linear is not None and \
|
||||
not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod):
|
||||
next_linear = None
|
||||
return next_linear
|
||||
|
||||
|
||||
class AscendQuantRMSNorm(AscendRMSNorm):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
layer: torch.nn.Module,
|
||||
eps: float = 1e-6,
|
||||
var_hidden_size: Optional[int] = None,
|
||||
has_weight: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
|
||||
self.layer = layer
|
||||
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
|
||||
requires_grad=False)
|
||||
|
||||
def forward(
|
||||
def forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
import torch_npu
|
||||
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if residual is not None:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||
x,
|
||||
residual,
|
||||
self.weight,
|
||||
self.layer.aclnn_input_scale,
|
||||
self.layer.aclnn_input_offset,
|
||||
epsilon=self.variance_epsilon)
|
||||
return x, residual
|
||||
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
return x
|
||||
x, residual = super().forward_oot(x, residual)
|
||||
return x.add_(self.bias), residual
|
||||
return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)
|
||||
|
||||
|
||||
class AscendRMSNorm(RMSNorm):
|
||||
class AscendGemmaRMSNorm(GemmaRMSNorm):
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
@@ -73,13 +147,13 @@ class AscendRMSNorm(RMSNorm):
|
||||
orig_dtype = residual.dtype
|
||||
x = x + residual.to(x.dtype)
|
||||
residual = x.to(orig_dtype)
|
||||
x, _ = torch_npu.npu_rms_norm(x, self.weight,
|
||||
x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
|
||||
self.variance_epsilon)
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||
x, residual, self.weight, self.variance_epsilon)
|
||||
x, residual, 1.0 + self.weight, self.variance_epsilon)
|
||||
return x, residual
|
||||
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
|
||||
self.variance_epsilon)
|
||||
return x
|
||||
|
||||
@@ -1,45 +1,159 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
This file is a part of the vllm-ascend project.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
To customize linear communication groups or forward of classes in this file,
|
||||
extend new linear operations in linear_op.py.
|
||||
The classes in this file should not be modified, including AscendQKVParallelLinear,
|
||||
AscendMergedColumnParallelLinear, AscendMergedColumnParallelLinear,
|
||||
AscendRowParallelLinear and AscendColumnParallelLinear.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
|
||||
ColumnParallelLinear,
|
||||
LinearBase,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.distributed import divide
|
||||
from vllm.model_executor.layers.linear import ( # noqa
|
||||
WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
|
||||
MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase,
|
||||
RowParallelLinear, UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizationConfig
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import (
|
||||
get_mlp_tensor_model_parallel_rank,
|
||||
get_mlp_tensor_model_parallel_world_size, get_mlp_tp_group)
|
||||
from vllm_ascend.ops.linear_op import (get_column_parallel_op,
|
||||
get_row_parallel_op)
|
||||
|
||||
|
||||
class AscendMlpColumnParallelLinear(ColumnParallelLinear):
|
||||
"""Linear layer with column parallelism.
|
||||
# TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group
|
||||
class AscendLinearBase(LinearBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.skip_bias_add = skip_bias_add
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
self.quant_config = quant_config
|
||||
self.prefix = prefix
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[
|
||||
QuantizeMethodBase] = UnquantizedLinearMethod()
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self,
|
||||
prefix=prefix)
|
||||
self.return_bias = return_bias
|
||||
self.disable_tp = disable_tp
|
||||
|
||||
|
||||
class AscendQKVParallelLinear(QKVParallelLinear):
|
||||
"""Linear layers for the attention's QKV transformation.
|
||||
|
||||
Linear layers for the linear transformation of the query, key, and value
|
||||
vectors in the attention layer. The weight matrix is concatenated along
|
||||
the output dimension. The layer is parallelized along the head dimension.
|
||||
When the number of key/value heads is smaller than the number of query
|
||||
heads (e.g., multi-query/grouped-query attention), the key/value head may
|
||||
be replicated while the query heads are partitioned.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
head_size: int,
|
||||
total_num_heads: int,
|
||||
total_num_kv_heads: Optional[int] = None,
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
self.custom_op, _, tp_size = get_column_parallel_op(
|
||||
disable_tp, prefix, self)
|
||||
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = head_size
|
||||
self.total_num_heads = total_num_heads
|
||||
if total_num_kv_heads is None:
|
||||
total_num_kv_heads = total_num_heads
|
||||
self.total_num_kv_heads = total_num_kv_heads
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.num_heads = divide(self.total_num_heads, tp_size)
|
||||
if tp_size >= self.total_num_kv_heads:
|
||||
self.num_kv_heads = 1
|
||||
self.num_kv_head_replicas = divide(tp_size,
|
||||
self.total_num_kv_heads)
|
||||
else:
|
||||
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
|
||||
self.num_kv_head_replicas = 1
|
||||
input_size = self.hidden_size
|
||||
output_size = (self.num_heads +
|
||||
2 * self.num_kv_heads) * tp_size * self.head_size
|
||||
self.output_sizes = [
|
||||
self.num_heads * self.head_size * tp_size, # q_proj
|
||||
self.num_kv_heads * self.head_size * tp_size, # k_proj
|
||||
self.num_kv_heads * self.head_size * tp_size, # v_proj
|
||||
]
|
||||
AscendColumnParallelLinear.__init__(self,
|
||||
input_size=input_size,
|
||||
output_size=output_size,
|
||||
bias=bias,
|
||||
gather_output=False,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.custom_op is not None:
|
||||
return self.custom_op.apply(input_)
|
||||
|
||||
return super().forward(input_)
|
||||
|
||||
|
||||
class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
|
||||
"""Packed linear layers with column parallelism.
|
||||
|
||||
Similar to ColumnParallelLinear, but the weight matrix is concatenated
|
||||
along the output dimension. When the weight matrix is loaded, the
|
||||
different partitions are sharded separately.
|
||||
|
||||
Use the MLP tensor parallelism group in the MLP module,
|
||||
and the original TP group in other modules.
|
||||
@@ -48,73 +162,46 @@ class AscendMlpColumnParallelLinear(ColumnParallelLinear):
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
output_sizes: list[int],
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
output_sizes: Optional[list[int]] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
# Divide the weight matrix along the last dimension.
|
||||
if prefix.find("gate_up_proj") != -1:
|
||||
self.tp_size = get_mlp_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_mlp_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = True
|
||||
else:
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = False
|
||||
self.input_size_per_partition = input_size
|
||||
self.output_size_per_partition = divide(output_size, self.tp_size)
|
||||
self.output_partition_sizes = [self.output_size_per_partition]
|
||||
# If QKV or MergedColumn, use output size of each partition.
|
||||
if hasattr(self, "output_sizes"):
|
||||
self.output_partition_sizes = [
|
||||
divide(output_size, self.tp_size)
|
||||
for output_size in self.output_sizes
|
||||
]
|
||||
LinearBase.__init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias)
|
||||
self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
|
||||
disable_tp, prefix, self)
|
||||
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
|
||||
self.output_sizes = output_sizes
|
||||
assert all(output_size % self.tp_size == 0
|
||||
for output_size in output_sizes)
|
||||
AscendColumnParallelLinear.__init__(self,
|
||||
input_size=input_size,
|
||||
output_size=sum(output_sizes),
|
||||
bias=bias,
|
||||
gather_output=gather_output,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
self.gather_output = gather_output
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.custom_op is not None:
|
||||
return self.custom_op.apply(input_)
|
||||
|
||||
if output_sizes is None:
|
||||
output_sizes = [output_size]
|
||||
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size_per_partition,
|
||||
output_partition_sizes=self.output_partition_sizes,
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=(
|
||||
self.weight_loader_v2 if self.quant_method.__class__.__name__
|
||||
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
return super().forward(input_)
|
||||
|
||||
|
||||
class AscendMlpRowParallelLinear(RowParallelLinear):
|
||||
class AscendRowParallelLinear(RowParallelLinear):
|
||||
"""Linear layer with row parallelism.
|
||||
Use the MLP tensor parallelism group in the MLP module,
|
||||
and the original TP group in other modules.
|
||||
@@ -133,28 +220,25 @@ class AscendMlpRowParallelLinear(RowParallelLinear):
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
if prefix.find("down_proj") != -1:
|
||||
self.tp_size = get_mlp_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_mlp_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = True
|
||||
else:
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = False
|
||||
self.custom_op, self.tp_rank, self.tp_size = get_row_parallel_op(
|
||||
disable_tp, prefix, self)
|
||||
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
|
||||
# Divide the weight matrix along the first dimension.
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
self.output_size_per_partition = output_size
|
||||
self.output_partition_sizes = [output_size]
|
||||
|
||||
LinearBase.__init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias)
|
||||
AscendLinearBase.__init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
@@ -184,66 +268,22 @@ class AscendMlpRowParallelLinear(RowParallelLinear):
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
if self.custom_op is not None:
|
||||
self.custom_op.update_attrs()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
is_prefill: bool = True,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.enable_mlp_optimze:
|
||||
tp_rank = get_mlp_tensor_model_parallel_rank()
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = get_mlp_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
bias_ = None if (self.tp_rank > 0
|
||||
or self.skip_bias_add) else self.bias
|
||||
output_parallel = self.quant_method.apply(self,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
output = get_mlp_tp_group().reduce_scatter(output_parallel, 0)
|
||||
# output = output[:num_tokens,:]
|
||||
# dispose_tensor(output_parallel)
|
||||
else:
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
if self.custom_op is not None:
|
||||
return self.custom_op.apply(input_)
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
bias_ = None if (self.tp_rank > 0
|
||||
or self.skip_bias_add) else self.bias
|
||||
output_parallel = self.quant_method.apply(self,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
return super().forward(input_)
|
||||
|
||||
|
||||
class AscendMlpMergedColumnParallelLinear(MergedColumnParallelLinear):
|
||||
"""Packed linear layers with column parallelism.
|
||||
|
||||
Similar to ColumnParallelLinear, but the weight matrix is concatenated
|
||||
along the output dimension. When the weight matrix is loaded, the
|
||||
different partitions are sharded separately.
|
||||
class AscendColumnParallelLinear(ColumnParallelLinear):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
Use the MLP tensor parallelism group in the MLP module,
|
||||
and the original TP group in other modules.
|
||||
@@ -252,58 +292,76 @@ class AscendMlpMergedColumnParallelLinear(MergedColumnParallelLinear):
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_sizes: list[int],
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
output_sizes: Optional[list[int]] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
self.output_sizes = output_sizes
|
||||
if prefix.find("gate_up_proj") != -1:
|
||||
self.tp_size = get_mlp_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_mlp_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = True
|
||||
self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
|
||||
disable_tp, prefix, self)
|
||||
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
|
||||
self.input_size_per_partition = input_size
|
||||
self.output_size_per_partition = divide(output_size, self.tp_size)
|
||||
self.output_partition_sizes = [self.output_size_per_partition]
|
||||
# If QKV or MergedColumn, use output size of each partition.
|
||||
if hasattr(self, "output_sizes"):
|
||||
self.output_partition_sizes = [
|
||||
divide(output_size, self.tp_size)
|
||||
for output_size in self.output_sizes
|
||||
]
|
||||
|
||||
AscendLinearBase.__init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
self.gather_output = gather_output
|
||||
|
||||
if output_sizes is None:
|
||||
output_sizes = [output_size]
|
||||
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size_per_partition,
|
||||
output_partition_sizes=self.output_partition_sizes,
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=(
|
||||
self.weight_loader_v2 if self.quant_method.__class__.__name__
|
||||
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = False
|
||||
assert all(output_size % self.tp_size == 0
|
||||
for output_size in output_sizes)
|
||||
AscendMlpColumnParallelLinear.__init__(self,
|
||||
input_size=input_size,
|
||||
output_size=sum(output_sizes),
|
||||
bias=bias,
|
||||
gather_output=gather_output,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias)
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
if self.custom_op is not None:
|
||||
self.custom_op.update_attrs()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
# self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
if self.enable_mlp_optimze:
|
||||
input2_ = get_mlp_tp_group().all_gather(input_, 0)
|
||||
output = self.quant_method.apply(self, input2_, bias)
|
||||
else:
|
||||
output_parallel = self.quant_method.apply(self, input_, bias)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
if self.custom_op is not None:
|
||||
return self.custom_op.apply(input_)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
return super().forward(input_)
|
||||
|
||||
459
vllm_ascend/ops/linear_op.py
Normal file
459
vllm_ascend/ops/linear_op.py
Normal file
@@ -0,0 +1,459 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This file extends the functionality of linear operations by encapsulating custom
|
||||
communication groups and forward functions into classes (linear ops).
|
||||
|
||||
Current class inheritance structure:
|
||||
CustomTensorParallelOp
|
||||
├── CustomColumnParallelOp
|
||||
│ ├── MLPColumnParallelOp
|
||||
│ ├── DenseOptimMergedColumnParallelOp
|
||||
│ └── DenseOptimQKVParallelOp
|
||||
└── CustomRowParallelOp
|
||||
├── MLPRowParallelOp
|
||||
├── OProjRowParallelOp
|
||||
├── MatmulAllreduceRowParallelOp
|
||||
└── DenseOptimRowParallelOp
|
||||
|
||||
How to extend a new linear op? Taking column parallel op as an example:
|
||||
1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
|
||||
2. [Optional] The default communication group is the TP group. If a custom communication group is needed, override the comm_group method
|
||||
3. Override the apply method according to requirements, which will replace the original linear.forward
|
||||
4. Add selection logic for MyColumnParallelOp in the get_column_parallel_op method, typically based on prefix and configuration judgments
|
||||
Row parallel op follows a similar approach - inherit from RowColumnParallelOp and register the new class in get_row_parallel_op.
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.distributed import split_tensor_along_last_dim
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
|
||||
get_otp_group)
|
||||
from vllm_ascend.utils import (dense_optim_enable, enable_sp,
|
||||
matmul_allreduce_enable, mlp_tp_enable,
|
||||
oproj_tp_enable)
|
||||
|
||||
|
||||
class CustomTensorParallelOp:
|
||||
|
||||
def __init__(self, layer):
|
||||
self.layer = layer
|
||||
self.bias = None
|
||||
self.skip_bias_add = None
|
||||
self.return_bias = None
|
||||
self.quant_method = None
|
||||
|
||||
# Custom communication group, while determining weight sharding
|
||||
@property
|
||||
def comm_group(self):
|
||||
return get_tp_group()
|
||||
|
||||
@property
|
||||
def tp_rank(self):
|
||||
return self.comm_group.rank_in_group
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.comm_group.world_size
|
||||
|
||||
# Update the attributes required by apply(), obtaining them from the layer.
|
||||
# Call this after the layer completes its initialization, specifically at the end of layer.init().
|
||||
def update_attrs(self):
|
||||
if hasattr(self.layer, "bias"):
|
||||
self.bias = self.layer.bias
|
||||
self.skip_bias_add = self.layer.skip_bias_add
|
||||
self.return_bias = self.layer.return_bias
|
||||
self.quant_method = self.layer.quant_method
|
||||
self.prefix = self.layer.prefix
|
||||
|
||||
def apply_impl(self, input_):
|
||||
raise NotImplementedError
|
||||
|
||||
# Replace layer.forward to customize the layer computation process.
|
||||
def apply(self, input_):
|
||||
output, output_bias = self.apply_impl(input_)
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class CustomColumnParallelOp(CustomTensorParallelOp):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
self.gather_output = None
|
||||
|
||||
def update_attrs(self):
|
||||
super().update_attrs()
|
||||
self.gather_output = self.layer.gather_output
|
||||
|
||||
|
||||
class CustomRowParallelOp(CustomTensorParallelOp):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
self.reduce_results = None
|
||||
self.input_is_parallel = None
|
||||
self.input_size_per_partition = None
|
||||
|
||||
def update_attrs(self):
|
||||
super().update_attrs()
|
||||
self.input_is_parallel = self.layer.input_is_parallel
|
||||
self.reduce_results = self.layer.reduce_results
|
||||
self.input_size_per_partition = self.layer.input_size_per_partition
|
||||
|
||||
def apply(self, input_):
|
||||
output, output_bias = self.apply_impl(input_)
|
||||
if dense_optim_enable():
|
||||
torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class MLPColumnParallelOp(CustomColumnParallelOp):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
|
||||
@property
|
||||
def comm_group(self):
|
||||
return get_mlp_tp_group()
|
||||
|
||||
def apply_impl(
|
||||
self,
|
||||
input_: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
input_parallel = self.comm_group.all_gather(input_, 0)
|
||||
output = self.quant_method.apply(self.layer, input_parallel, bias)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class SequenceMergedColumnParallelOp(CustomColumnParallelOp):
|
||||
|
||||
def apply_impl(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
Implemented multiple optimization projects for dense models, such as FlashComm and
|
||||
communication-computation fusion.
|
||||
"""
|
||||
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
|
||||
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
|
||||
output_parallel = self.quant_method.apply(self.layer, input_, bias)
|
||||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = self.comm_group.all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class SequenceQKVParallelOp(CustomColumnParallelOp):
|
||||
|
||||
def __init__(self, layer, prefix):
|
||||
super().__init__(layer)
|
||||
self.prefix = prefix
|
||||
|
||||
def apply_impl(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
Implemented multiple optimization projects for dense models, such as FlashComm and
|
||||
communication-computation fusion.
|
||||
"""
|
||||
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
|
||||
layer_num = self.prefix.split('.')[2]
|
||||
|
||||
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
input_, layer_num != '0')
|
||||
output_parallel = self.quant_method.apply(self.layer, input_, bias)
|
||||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = self.comm_group.all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class MLPRowParallelOp(CustomRowParallelOp):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
|
||||
@property
|
||||
def comm_group(self):
|
||||
return get_mlp_tp_group()
|
||||
|
||||
def apply_impl(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
|
||||
assert self.quant_method is not None
|
||||
bias_ = None if (self.tp_rank > 0
|
||||
or self.skip_bias_add) else self.layer.bias
|
||||
output_parallel = self.quant_method.apply(self.layer,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
output = self.comm_group.reduce_scatter(output_parallel, 0)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class OProjRowParallelOp(CustomRowParallelOp):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
|
||||
@property
|
||||
def comm_group(self):
|
||||
return get_otp_group()
|
||||
|
||||
def apply_impl(
|
||||
self,
|
||||
input_: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
|
||||
# Prepare tensors for all-to-all communication
|
||||
local_batch_size = input_parallel.size(0)
|
||||
chunk_size = self.input_size_per_partition
|
||||
total_batch_size = local_batch_size * self.tp_size
|
||||
|
||||
# Reshape tensor for efficient cross-device transfer:
|
||||
# [batch, dim] -> [tp_size, batch, chunk] -> flattened
|
||||
send_buf = (input_parallel.reshape(-1,
|
||||
self.tp_size, chunk_size).transpose(
|
||||
0, 1).contiguous().view(-1))
|
||||
|
||||
# Create receive buffer
|
||||
recv_buf = torch.empty(total_batch_size * chunk_size,
|
||||
dtype=input_parallel.dtype,
|
||||
device=input_parallel.device)
|
||||
|
||||
# Perform all-to-all communication
|
||||
dist.all_to_all_single(recv_buf,
|
||||
send_buf,
|
||||
group=self.comm_group.device_group)
|
||||
input_parallel = recv_buf.view(total_batch_size, chunk_size)
|
||||
|
||||
# Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
assert self.quant_method is not None
|
||||
output_parallel = self.quant_method.apply(self.layer,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
|
||||
# otp-specific: Combine partial results across devices
|
||||
output = self.comm_group.reduce_scatter(output_parallel, dim=0)
|
||||
output = output.view(input_.shape[0], self.layer.output_size)
|
||||
|
||||
# Handle bias return based on configuration
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
def update_attrs(self):
|
||||
super().update_attrs()
|
||||
self.input_is_parallel = self.layer.input_is_parallel
|
||||
self.input_size_per_partition = self.layer.input_size_per_partition
|
||||
|
||||
|
||||
class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
|
||||
_HCOMM_INFO = None
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group)
|
||||
|
||||
def apply_impl(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
"""Calculate the output tensor of forward by considering
|
||||
fusing communication and computation."""
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output = torch_npu.npu_mm_all_reduce_base(input_parallel,
|
||||
self.weight_t,
|
||||
self.hcomm_info,
|
||||
bias=bias_)
|
||||
else:
|
||||
assert self.quant_method is not None
|
||||
output = self.quant_method.apply(self.layer,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
@classmethod
|
||||
def get_hcomm_info(cls, group: ProcessGroup) -> str:
|
||||
"""Get the HCCL communication information for the given group."""
|
||||
if cls._HCOMM_INFO is not None:
|
||||
return cls._HCOMM_INFO
|
||||
|
||||
rank = torch.distributed.get_rank(group)
|
||||
if torch.__version__ > "2.0":
|
||||
global_rank = torch.distributed.get_global_rank(group, rank)
|
||||
cls._HCOMM_INFO = group._get_backend(
|
||||
torch.device("npu")).get_hccl_comm_name(global_rank)
|
||||
else:
|
||||
cls._HCOMM_INFO = group.get_hccl_comm_name(rank)
|
||||
return cls._HCOMM_INFO
|
||||
|
||||
def update_attrs(self):
|
||||
super().update_attrs()
|
||||
self.weight_t = self.layer.weight.t()
|
||||
|
||||
|
||||
class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
|
||||
def __init__(self, layer, prefix):
|
||||
super().__init__(layer)
|
||||
self.prefix = prefix
|
||||
|
||||
def apply_impl(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
Implemented multiple optimization projects for dense models, such as FlashComm and
|
||||
communication-computation fusion.
|
||||
"""
|
||||
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
|
||||
assert self.quant_method is not None
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
|
||||
if self.tp_size == 1 or not self.reduce_results:
|
||||
output = self.quant_method.apply(self.layer,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
else:
|
||||
output_parallel = self.quant_method.apply(self.layer,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
def update_attrs(self):
|
||||
super().update_attrs()
|
||||
self.input_is_parallel = self.layer.input_is_parallel
|
||||
self.reduce_results = self.layer.reduce_results
|
||||
|
||||
|
||||
def get_column_parallel_op(
|
||||
disable_tp, prefix, layer
|
||||
) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
|
||||
SequenceQKVParallelOp]], int, int]:
|
||||
if disable_tp:
|
||||
return None, 0, 1
|
||||
|
||||
custom_op: Optional[Union[
|
||||
MLPColumnParallelOp,
|
||||
SequenceMergedColumnParallelOp,
|
||||
SequenceQKVParallelOp,
|
||||
]] = None
|
||||
if "gate_up_proj" in prefix and mlp_tp_enable():
|
||||
custom_op = MLPColumnParallelOp(layer)
|
||||
elif "gate_up_proj" in prefix and enable_sp():
|
||||
custom_op = SequenceMergedColumnParallelOp(layer)
|
||||
elif enable_sp():
|
||||
custom_op = SequenceQKVParallelOp(layer, prefix)
|
||||
|
||||
if custom_op is not None:
|
||||
return custom_op, custom_op.tp_rank, custom_op.tp_size
|
||||
|
||||
return None, get_tp_group().rank_in_group, get_tp_group().world_size
|
||||
|
||||
|
||||
def get_row_parallel_op(
|
||||
disable_tp, prefix, layer
|
||||
) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
|
||||
MatmulAllreduceRowParallelOp,
|
||||
SequenceRowParallelOp]], int, int]:
|
||||
if disable_tp:
|
||||
return None, 0, 1
|
||||
|
||||
custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
|
||||
MatmulAllreduceRowParallelOp,
|
||||
SequenceRowParallelOp]] = None
|
||||
if "down_proj" in prefix and mlp_tp_enable():
|
||||
custom_op = MLPRowParallelOp(layer)
|
||||
elif "o_proj" in prefix and oproj_tp_enable():
|
||||
custom_op = OProjRowParallelOp(layer)
|
||||
elif matmul_allreduce_enable():
|
||||
custom_op = MatmulAllreduceRowParallelOp(layer)
|
||||
elif enable_sp():
|
||||
custom_op = SequenceRowParallelOp(layer, prefix)
|
||||
|
||||
if custom_op is not None:
|
||||
return custom_op, custom_op.tp_rank, custom_op.tp_size
|
||||
|
||||
return None, get_tp_group().rank_in_group, get_tp_group().world_size
|
||||
0
vllm_ascend/ops/moe/__init__.py
Normal file
0
vllm_ascend/ops/moe/__init__.py
Normal file
@@ -1,5 +1,7 @@
|
||||
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -12,7 +14,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.distributed as dist
|
||||
@@ -60,3 +62,52 @@ def async_all_to_all(input_,
|
||||
group=group,
|
||||
async_op=True)
|
||||
return input_, a2a_out, handle
|
||||
|
||||
|
||||
def _gather_along_first_dim(input_, group, output_split_sizes=None):
|
||||
"""Gather tensors and concatenate along the first dimension.
|
||||
|
||||
Args:
|
||||
input_tensor (torch.Tensor):
|
||||
A tensor to be gathered.
|
||||
output_split_sizes (List[int], optional):
|
||||
A list specifying the sizes of the output splits along the first dimension.
|
||||
If None, equal splitting is assumed. Default: None.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Gathered tensor.
|
||||
"""
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
dim_size = list(input_.size())
|
||||
if output_split_sizes is None:
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
torch.distributed.all_gather_into_tensor(output,
|
||||
input_.contiguous(),
|
||||
group=group)
|
||||
else:
|
||||
dim_size[0] = sum(output_split_sizes)
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
output_tensor_list = list(
|
||||
torch.split(output, output_split_sizes, dim=0))
|
||||
torch.distributed.all_gather(output_tensor_list, input_, group=group)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def gather_from_sequence_parallel_region(
|
||||
input_,
|
||||
group,
|
||||
output_split_sizes=None,
|
||||
):
|
||||
"""Wrapper for autograd function: forward: AG, backward: RS <first dim>"""
|
||||
return _gather_along_first_dim(input_, group, output_split_sizes)
|
||||
459
vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py
Normal file
459
vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py
Normal file
@@ -0,0 +1,459 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_dp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
|
||||
in distributed environments. Subclasses implement specific communication strategies
|
||||
(e.g., AllGather, All2All, MC2, Naive Multicast) to handle tensor padding, slicing,
|
||||
broadcasting, and reduction across TP/DP/EP groups.
|
||||
|
||||
Attributes:
|
||||
moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info,
|
||||
sizes, ranks, and communication settings.
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
self.moe_config = moe_config
|
||||
|
||||
@abstractmethod
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
rm_router_logits: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Prepare tensors before MoE computation. May involve:
|
||||
- Padding to align communication boundaries
|
||||
- Slicing across tensor-parallel ranks
|
||||
- Broadcasting across data-parallel ranks
|
||||
- Recomputing router logits if needed
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
|
||||
router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
|
||||
enable_shared_expert_dp (bool): Skip DP communication for shared experts
|
||||
rm_router_logits (bool): Discard input router_logits and recompute via gate
|
||||
replace_allreduce (bool): Bypass default all-reduce behavior
|
||||
gate (nn.Module, optional): Gate network to recompute router_logits if needed
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- processed hidden_states (may be padded/sliced/broadcasted)
|
||||
- processed router_logits (may be recomputed or broadcasted)
|
||||
- optional communication mask (e.g., mc2_mask for sparse ops)
|
||||
"""
|
||||
raise NotImplementedError("Prepare not implemented.")
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalize MoE output. May involve:
|
||||
- Gathering sliced tensors across TP ranks
|
||||
- Reducing or scattering across DP ranks
|
||||
- Unpadding to original token count
|
||||
- Applying all-reduce across TP/EP if requested
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): MoE layer output, possibly padded or sliced
|
||||
reduce_results (bool): Whether to apply all-reduce across TP/EP groups
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Final output with shape [original_num_tokens, hidden_size]
|
||||
"""
|
||||
raise NotImplementedError("Finalize function not implemented.")
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using MC2 (Memory-Centric Communication).
|
||||
Designed for Ascend or environments requiring explicit padding and slicing control.
|
||||
Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
super().__init__(moe_config)
|
||||
self._restore_tp_across_dp()
|
||||
|
||||
def _restore_tp_across_dp(self):
|
||||
"""
|
||||
Restore original TP configuration.
|
||||
vLLM flattens TP and DP into a single dimension; this method recovers
|
||||
the true TP world size and rank for correct tensor slicing.
|
||||
"""
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
rm_router_logits: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch `mc2_mask` and target padding length from forward context.
|
||||
2. Pad `hidden_states` and `router_logits` to target length if needed.
|
||||
3. If TP > 1, split tensors along token dimension and select current TP rank's slice.
|
||||
4. Split and return corresponding `mc2_mask`.
|
||||
|
||||
Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
|
||||
Returns:
|
||||
Tuple of (hidden_states, router_logits, mc2_mask), possibly sliced/padded.
|
||||
"""
|
||||
self.replace_allreduce = replace_allreduce
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
forward_context = get_forward_context()
|
||||
mc2_mask = forward_context.mc2_mask
|
||||
if self.tp_size > 1:
|
||||
# Also slice mc2_mask
|
||||
split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
|
||||
mc2_mask = split_mc2_mask[self.tp_rank]
|
||||
|
||||
if not self.replace_allreduce:
|
||||
self.num_tokens, _ = hidden_states.shape
|
||||
target_pad_length = forward_context.padded_num_tokens
|
||||
pad_size = target_pad_length - self.num_tokens
|
||||
|
||||
# Pad if necessary (unless shared expert DP is enabled)
|
||||
if pad_size > 0 and not self.enable_shared_expert_dp:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
# Slice across TP ranks
|
||||
if self.tp_size > 1 and not self.enable_shared_expert_dp:
|
||||
split_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
self.split_hidden_states = split_hidden_states # Save for finalize
|
||||
|
||||
return hidden_states, router_logits, mc2_mask
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If TP > 1, all-gather slices from all TP ranks to reconstruct full tensor.
|
||||
2. Unpad to original token count if padding was applied.
|
||||
3. Return tensor with shape [original_num_tokens, hidden_size].
|
||||
|
||||
Skips communication and unpadding if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
"""
|
||||
if not (self.enable_shared_expert_dp or self.replace_allreduce):
|
||||
if self.tp_size > 1:
|
||||
# All-gather across TP group
|
||||
dist.all_gather(list(self.split_hidden_states), hidden_states,
|
||||
self.moe_config.tp_group.device_group)
|
||||
hidden_states = torch.cat(self.split_hidden_states, dim=0)
|
||||
|
||||
# TODO: It is a quick bugfix for the memory explosion issue in eager mode.
|
||||
# If the cache is not cleared after `self.split_hidden_states` is created,
|
||||
# it can lead to the memory explosion in eager mode.
|
||||
del self.split_hidden_states
|
||||
|
||||
# Unpad if necessary
|
||||
if self.num_tokens < hidden_states.shape[0]:
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using All-to-All style slicing.
|
||||
Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
|
||||
Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
super().__init__(moe_config)
|
||||
self._restore_tp_across_dp()
|
||||
|
||||
def _restore_tp_across_dp(self):
|
||||
"""Restore original TP configuration (same as MC2)."""
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
rm_router_logits: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Pad hidden_states and router_logits to next multiple of TP size.
|
||||
2. If TP > 1, split along token dim and select current TP rank's slice.
|
||||
3. Save splits for later all-gather in finalize.
|
||||
|
||||
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
|
||||
Returns:
|
||||
Tuple of (hidden_states, router_logits, None) — no mask used in All2All.
|
||||
"""
|
||||
self.replace_allreduce = replace_allreduce
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
|
||||
if not (self.replace_allreduce or self.enable_shared_expert_dp):
|
||||
self.num_tokens, _ = hidden_states.shape
|
||||
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
|
||||
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
if self.tp_size > 1:
|
||||
split_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
self.split_hidden_states = split_hidden_states
|
||||
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
|
||||
return hidden_states, router_logits, None
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If TP > 1, all-gather slices to reconstruct full tensor.
|
||||
2. Unpad to original token count.
|
||||
3. Return [original_num_tokens, hidden_size] tensor.
|
||||
|
||||
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
"""
|
||||
if not (self.enable_shared_expert_dp or self.replace_allreduce):
|
||||
if self.tp_size > 1:
|
||||
dist.all_gather(list(self.split_hidden_states), hidden_states,
|
||||
self.moe_config.tp_group.device_group)
|
||||
hidden_states = torch.cat(self.split_hidden_states, dim=0)
|
||||
|
||||
# TODO: It is a quick bugfix for the memory explosion issue in eager mode.
|
||||
# If the cache is not cleared after `self.split_hidden_states` is created,
|
||||
# it can lead to the memory explosion in eager mode.
|
||||
del self.split_hidden_states
|
||||
|
||||
if self.num_tokens < hidden_states.shape[0]:
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using All-Gather + Reduce-Scatter.
|
||||
Designed for DP > 1: gather inputs across DP ranks before MoE, scatter outputs after.
|
||||
Uses `max_tokens_across_dp` from forward_context for padding alignment.
|
||||
"""
|
||||
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
rm_router_logits: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch max token count across DP group from forward context.
|
||||
2. Pad local tensors to that size.
|
||||
3. All-gather across DP group to form global input tensor.
|
||||
4. Optionally recompute router_logits using gate if `rm_router_logits=True`.
|
||||
|
||||
Returns:
|
||||
Tuple of (global_hidden_states, global_router_logits, None)
|
||||
"""
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
|
||||
if self.moe_config.dp_size > 1:
|
||||
forward_context = get_forward_context()
|
||||
max_tokens_across_dp = forward_context.max_tokens_across_dp
|
||||
|
||||
self.num_tokens = hidden_states.shape[0]
|
||||
pad_size = max_tokens_across_dp - self.num_tokens
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
if not rm_router_logits:
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
# All-gather across DP group
|
||||
hidden_states = self.moe_config.dp_group.all_gather(
|
||||
hidden_states, 0)
|
||||
if rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states) # Recompute globally
|
||||
else:
|
||||
router_logits = self.moe_config.dp_group.all_gather(
|
||||
router_logits, 0)
|
||||
|
||||
return hidden_states, router_logits, None
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
|
||||
2. Slice to original local token count.
|
||||
3. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce.
|
||||
|
||||
Returns:
|
||||
Tensor with shape [original_local_num_tokens, hidden_size]
|
||||
"""
|
||||
if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
|
||||
hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
|
||||
if reduce_results and (self.moe_config.tp_size > 1
|
||||
or self.moe_config.ep_size > 1):
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using Naive Multicast (point-to-point broadcast).
|
||||
Will be used in prefill when using allgather in decode. Each DP rank broadcasts its slice to all others.
|
||||
Uses `cu_tokens_across_dp_cpu` (cumulative tokens) to locate slice boundaries.
|
||||
"""
|
||||
|
||||
def _naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
"""
|
||||
Naive multicast implementation:
|
||||
1. Create global buffer sized by total tokens across DP.
|
||||
2. Current rank copies its slice into its designated buffer region.
|
||||
3. Each rank broadcasts its slice to all others via P2P.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Local tensor [local_tokens, hidden_size]
|
||||
cu_tokens_across_dp_cpu (torch.Tensor): Cumulative token counts per DP rank
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Global tensor [total_tokens, hidden_size]
|
||||
"""
|
||||
assert len(x.shape) == 2, "Input must be 2D [tokens, features]"
|
||||
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
|
||||
# Copy local slice into buffer
|
||||
start = 0 if self.moe_config.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.moe_config.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.moe_config.dp_rank]
|
||||
buffer[start:end, :].copy_(x)
|
||||
|
||||
# Broadcast each slice to all ranks
|
||||
for idx in range(self.moe_config.dp_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
||||
end = cu_tokens_across_dp_cpu[idx]
|
||||
get_dp_group().broadcast(buffer[start:end, :], idx)
|
||||
return buffer
|
||||
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
rm_router_logits: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch cumulative token boundaries from forward context.
|
||||
2. Multicast hidden_states and router_logits to form global tensors.
|
||||
3. Optionally recompute router_logits globally if `rm_router_logits=True`.
|
||||
|
||||
Returns:
|
||||
Tuple of (global_hidden_states, global_router_logits, None)
|
||||
"""
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
|
||||
if self.moe_config.dp_size > 1:
|
||||
if vllm_version_is("0.10.2"):
|
||||
self.cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_dp_cpu
|
||||
else:
|
||||
self.cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_sp(1)
|
||||
hidden_states = self._naive_multicast(hidden_states,
|
||||
self.cu_tokens_across_dp_cpu)
|
||||
if rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states)
|
||||
else:
|
||||
router_logits = self._naive_multicast(
|
||||
router_logits, self.cu_tokens_across_dp_cpu)
|
||||
|
||||
return hidden_states, router_logits, None
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If DP > 1 and not shared expert:
|
||||
- All-reduce across DP
|
||||
- Slice to current rank's token range using cu_tokens_across_dp_cpu
|
||||
2. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce.
|
||||
|
||||
Returns:
|
||||
Tensor with shape [local_num_tokens, hidden_size]
|
||||
"""
|
||||
if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
|
||||
start = 0 if self.moe_config.dp_rank == 0 else self.cu_tokens_across_dp_cpu[
|
||||
self.moe_config.dp_rank - 1]
|
||||
end = self.cu_tokens_across_dp_cpu[self.moe_config.dp_rank]
|
||||
hidden_states = get_dp_group().all_reduce(
|
||||
hidden_states) # Sum across DP
|
||||
hidden_states = hidden_states[start:end, :]
|
||||
|
||||
if reduce_results and (self.moe_config.tp_size > 1
|
||||
or self.moe_config.ep_size > 1):
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
273
vllm_ascend/ops/moe/moe_comm_method.py
Normal file
273
vllm_ascend/ops/moe/moe_comm_method.py
Normal file
@@ -0,0 +1,273 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
|
||||
FusedMoEPrepareAndFinalizeWithAll2All,
|
||||
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
|
||||
FusedMoEPrepareAndFinalizeWithNaiveMulticast)
|
||||
from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV,
|
||||
TokenDispatcherWithAllGather,
|
||||
TokenDispatcherWithMC2,
|
||||
TokenDispatcherWithMoge)
|
||||
|
||||
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
|
||||
|
||||
|
||||
def get_moe_comm_method(
|
||||
moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]:
|
||||
return _MoECommMethods.get(moe_comm_type)
|
||||
|
||||
|
||||
def setup_moe_comm_method(moe_config):
|
||||
_MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
|
||||
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
|
||||
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
|
||||
_MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl(
|
||||
moe_config)
|
||||
|
||||
|
||||
class MoECommMethod(ABC):
|
||||
"""Base class for MoE communication methods."""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
self.model_type = get_current_vllm_config(
|
||||
).model_config.hf_config.model_type
|
||||
self.moe_config = moe_config
|
||||
self.mc2_mask = None
|
||||
|
||||
self.token_dispatcher = self._get_token_dispatcher()
|
||||
self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize(
|
||||
)
|
||||
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
rm_router_logits: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare(
|
||||
hidden_states, router_logits, enable_shared_expert_dp,
|
||||
rm_router_logits, replace_allreduce, gate)
|
||||
self.mc2_mask = mc2_mask
|
||||
return hidden_states, router_logits
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
hidden_states = self.fused_moe_prepare_finalize.finalize(
|
||||
hidden_states, reduce_results)
|
||||
return hidden_states
|
||||
|
||||
def fused_experts(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
global_num_experts: Optional[int] = None,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
# For TorchAir graph
|
||||
is_torchair: bool = False,
|
||||
# For Cube/Vector parallel
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
# For load balance
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
need_trans: bool = False,
|
||||
dynamic_eplb: bool = False):
|
||||
# Check constraints
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32, torch.float16, torch.bfloat16
|
||||
]
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
assert moe_comm_method is not None, "Missing communication context"
|
||||
|
||||
results = self.token_dispatcher.token_dispatch(
|
||||
hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
shared_experts=shared_experts,
|
||||
quantized_x_for_share=quantized_x_for_share,
|
||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||
mc2_mask=self.mc2_mask,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
with_quant=use_int8_w8a8 or use_int4_w4a8)
|
||||
|
||||
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \
|
||||
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales")
|
||||
|
||||
mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=expert_tokens,
|
||||
dynamic_scale=dynamic_scale,
|
||||
group_list_type=group_list_type,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias,
|
||||
topk_scales=topk_scales,
|
||||
with_quant=use_int8_w8a8
|
||||
or use_int4_w4a8,
|
||||
fusion=use_int8_w8a8,
|
||||
need_trans=need_trans)
|
||||
|
||||
final_hidden_states = self.token_dispatcher.token_combine(
|
||||
hidden_states=mlp_output)
|
||||
|
||||
if dynamic_eplb:
|
||||
return (final_hidden_states, group_list_type, expert_tokens)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
@abstractmethod
|
||||
def _get_token_dispatcher(self):
|
||||
raise NotImplementedError(
|
||||
"_get_token_dispatcher function not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def _get_fused_moe_prepare_finalize(self):
|
||||
raise NotImplementedError(
|
||||
"_get_fused_moe_prepare_finalize function not implemented.")
|
||||
|
||||
|
||||
class AllGatherCommImpl(MoECommMethod):
|
||||
"""This implementation is the same as NativeAllGatherCommImpl,
|
||||
but uses NPU-specific ops for better performance.
|
||||
|
||||
This implementation should be compatible with all scenarios, and
|
||||
thus it is the default implementation for MoE communication methods.
|
||||
It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
|
||||
and `torch_npu.npu_moe_token_unpermute` for post-processing
|
||||
to handle the token-to-expert mapping and communication efficiently.
|
||||
|
||||
NOTE(Yizhou): TBH, it is really weird that we were supposed to use
|
||||
`torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
|
||||
or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
|
||||
for pre-processing and post-processing, respectively.
|
||||
But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
|
||||
use `torch_npu.npu_moe_token_unpermute` instead.
|
||||
This is a workaround and should be removed after the issue is fixed.
|
||||
"""
|
||||
|
||||
def _get_token_dispatcher(self):
|
||||
if self.model_type == "PanguProMoE":
|
||||
return TokenDispatcherWithMoge(
|
||||
top_k=self.moe_config.experts_per_token,
|
||||
num_experts=self.moe_config.num_experts,
|
||||
num_local_experts=self.moe_config.num_local_experts)
|
||||
else:
|
||||
return TokenDispatcherWithAllGather(
|
||||
top_k=self.moe_config.experts_per_token,
|
||||
num_experts=self.moe_config.num_experts,
|
||||
num_local_experts=self.moe_config.num_local_experts)
|
||||
|
||||
def _get_fused_moe_prepare_finalize(self):
|
||||
return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
|
||||
|
||||
|
||||
class MC2CommImpl(MoECommMethod):
|
||||
"""This implementation is for the scenarios listed below:
|
||||
1. `enable_expert_parallel=True`.
|
||||
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
|
||||
3. `enable_expert_parallel=False` is not supported.
|
||||
|
||||
This implementation uses the MC2 communication method, which is optimized for
|
||||
Communication and Computation parallelism on Ascend devices.
|
||||
"""
|
||||
|
||||
def _get_token_dispatcher(self):
|
||||
return TokenDispatcherWithMC2()
|
||||
|
||||
def _get_fused_moe_prepare_finalize(self):
|
||||
return FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
|
||||
|
||||
|
||||
class AlltoAllCommImpl(MoECommMethod):
|
||||
"""This implementation is for the scenarios listed below:
|
||||
1. `enable_expert_parallel=True`.
|
||||
2. `npu_grouped_matmul` is available.
|
||||
|
||||
This implementation uses all-to-all communication to exchange tokens
|
||||
between data parallel ranks before and after the MLP computation. It should
|
||||
have better performance than AllGatherCommImpl when DP size > 1.
|
||||
"""
|
||||
|
||||
def _get_token_dispatcher(self):
|
||||
return TokenDispatcherWithAll2AllV(
|
||||
top_k=self.moe_config.experts_per_token,
|
||||
num_experts=self.moe_config.num_experts,
|
||||
num_local_experts=self.moe_config.num_local_experts)
|
||||
|
||||
def _get_fused_moe_prepare_finalize(self):
|
||||
return FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
|
||||
|
||||
class NaiveMulticastCommImpl(MoECommMethod):
|
||||
"""This implementation is the same as NativeAllGatherCommImpl,
|
||||
but uses NPU-specific ops for better performance.
|
||||
|
||||
This implementation should be compatible with all scenarios, and
|
||||
thus it is the default implementation for MoE communication methods.
|
||||
It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
|
||||
and `torch_npu.npu_moe_token_unpermute` for post-processing
|
||||
to handle the token-to-expert mapping and communication efficiently.
|
||||
|
||||
NOTE(Yizhou): TBH, it is really weird that we were supposed to use
|
||||
`torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
|
||||
or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
|
||||
for pre-processing and post-processing, respectively.
|
||||
But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
|
||||
use `torch_npu.npu_moe_token_unpermute` instead.
|
||||
This is a workaround and should be removed after the issue is fixed.
|
||||
"""
|
||||
|
||||
def _get_token_dispatcher(self):
|
||||
return TokenDispatcherWithAllGather(
|
||||
top_k=self.moe_config.experts_per_token,
|
||||
num_experts=self.moe_config.num_experts,
|
||||
num_local_experts=self.moe_config.num_local_experts)
|
||||
|
||||
def _get_fused_moe_prepare_finalize(self):
|
||||
return FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
|
||||
@@ -18,22 +18,52 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch.nn.functional import pad
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.utils import dispose_tensor, is_310p
|
||||
|
||||
|
||||
def cumsum_group_list(group_list: torch.Tensor,
|
||||
group_list_type: int,
|
||||
active_num: int = 0,
|
||||
expert_num: int = 0) -> torch.Tensor:
|
||||
if group_list_type not in [0, 1, 2]:
|
||||
raise ValueError(
|
||||
f"group_list_type should be in [0, 1, 2], but received {group_list_type}"
|
||||
)
|
||||
|
||||
if group_list_type == 0:
|
||||
return group_list
|
||||
if group_list_type == 1:
|
||||
return group_list.cumsum(dim=0)
|
||||
|
||||
experts = pad(group_list[:, 0], (1, 0))
|
||||
tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
|
||||
cumsum_group_list = torch.full(size=(expert_num, ),
|
||||
fill_value=active_num,
|
||||
dtype=group_list.dtype,
|
||||
device=group_list.device)
|
||||
|
||||
for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
|
||||
if end > start:
|
||||
cumsum_group_list[start:end] = tokens[i]
|
||||
|
||||
return cumsum_group_list
|
||||
|
||||
|
||||
def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
fusion: bool = False) -> torch.Tensor:
|
||||
if dynamic_scale is None:
|
||||
unquantized_hidden_states = hidden_states
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||
@@ -47,33 +77,40 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
bias1, bias2 = None, None
|
||||
_output_dtype = w2_scale.dtype
|
||||
|
||||
is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
|
||||
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
|
||||
if w1_scale_bias is None and is_mc2:
|
||||
w1_scale = w1_scale.to(torch.float32)
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=3,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=w1_scale,
|
||||
activation_scale=pertoken_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=group_list,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
|
||||
if fusion:
|
||||
# gmm1: gate_up_proj & act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight=w1,
|
||||
group_list=cumsum_group_list(group_list, group_list_type),
|
||||
weight_scale=w1_scale,
|
||||
x_scale=pertoken_scale)
|
||||
else:
|
||||
if w1_scale.dtype != torch.float32:
|
||||
w1_scale = w1_scale.to(torch.float32)
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=3,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32)[0]
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=w1_scale,
|
||||
activation_scale=pertoken_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=group_list,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
@@ -92,29 +129,37 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
[group_list[:1],
|
||||
torch.diff(group_list, dim=0)])
|
||||
group_list_type = 1
|
||||
bias1 = [w1_scale_bias]
|
||||
bias1 = [w1_scale_bias] if not fusion else w1_scale_bias
|
||||
bias2 = [w2_scale_bias]
|
||||
# TODO w4a8 scene: dynamic acquisition of dtype in the future
|
||||
_output_dtype = torch.bfloat16
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
scale=[w1_scale],
|
||||
bias=bias1,
|
||||
per_token_scale=[pertoken_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
|
||||
if fusion:
|
||||
# gmm1: gate_up_proj & act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight=w1,
|
||||
bias=bias1,
|
||||
group_list=cumsum_group_list(group_list, group_list_type),
|
||||
weight_scale=w1_scale,
|
||||
x_scale=pertoken_scale)
|
||||
else:
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
scale=[w1_scale.to(w2_scale.dtype)],
|
||||
bias=bias1,
|
||||
per_token_scale=[pertoken_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
# act_fn: swiglu
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
@@ -127,17 +172,22 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unquant_apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
w1 = w1.transpose(1, 2)
|
||||
def unquant_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
topk_scales: Optional[torch.Tensor] = None,
|
||||
need_trans: bool = True) -> torch.Tensor:
|
||||
|
||||
if need_trans:
|
||||
w1 = w1.transpose(1, 2)
|
||||
w2 = w2.transpose(1, 2)
|
||||
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
@@ -155,7 +205,6 @@ def unquant_apply_mlp(
|
||||
if topk_scales is not None:
|
||||
gate_up_out *= topk_scales
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
@@ -178,7 +227,9 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
topk_scales: Optional[torch.Tensor] = None,
|
||||
with_quant: bool = False) -> torch.Tensor:
|
||||
with_quant: bool = False,
|
||||
fusion: bool = False,
|
||||
need_trans: bool = True) -> torch.Tensor:
|
||||
if with_quant:
|
||||
return quant_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
@@ -189,11 +240,13 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
|
||||
dynamic_scale=dynamic_scale,
|
||||
group_list_type=group_list_type,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias)
|
||||
w2_scale_bias=w2_scale_bias,
|
||||
fusion=fusion)
|
||||
else:
|
||||
return unquant_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type,
|
||||
topk_scales=topk_scales)
|
||||
topk_scales=topk_scales,
|
||||
need_trans=need_trans)
|
||||
@@ -22,42 +22,17 @@
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.distributed.tensor_parallel import \
|
||||
gather_from_sequence_parallel_region
|
||||
from vllm_ascend.ops.comm_utils import async_all_to_all
|
||||
from vllm_ascend.ops.moe.comm_utils import (
|
||||
async_all_to_all, gather_from_sequence_parallel_region)
|
||||
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
||||
|
||||
_Dispatchers: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def _register_token_dispatcher(dispatcher: Any):
|
||||
_Dispatchers[dispatcher.__class__.__name__] = dispatcher
|
||||
|
||||
|
||||
def get_token_dispatcher(name: str):
|
||||
return _Dispatchers.get(name)
|
||||
|
||||
|
||||
def setup_token_dispatchers(ep_size: int, **kwargs):
|
||||
existing_dispatchers = set(_Dispatchers.keys())
|
||||
|
||||
if ep_size == 1 and "TokenDispatcherWithAllGather" not in existing_dispatchers:
|
||||
_register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
|
||||
elif ep_size < 16 and "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
|
||||
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
|
||||
elif ep_size >= 16:
|
||||
if "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
|
||||
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
|
||||
if "TokenDispatcherWithMC2" not in existing_dispatchers:
|
||||
_register_token_dispatcher(TokenDispatcherWithMC2(**kwargs))
|
||||
|
||||
|
||||
class MoETokenDispatcher(ABC):
|
||||
|
||||
@@ -90,9 +65,9 @@ class MoETokenDispatcher(ABC):
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
@@ -158,6 +133,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": 0,
|
||||
"expert_token_nums_type": 0,
|
||||
}
|
||||
|
||||
stage1_kwargs = {
|
||||
@@ -189,9 +165,9 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
@@ -215,6 +191,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
|
||||
if self.with_quant:
|
||||
if shared_experts is not None:
|
||||
share_up_out, _ = shared_experts.gate_up_proj(
|
||||
(quantized_x_for_share, dynamic_scale_for_share))
|
||||
shared_gate_up, shared_dequant_scale = share_up_out[
|
||||
0], share_up_out[1]
|
||||
|
||||
shared_act_out = shared_experts.act_fn(
|
||||
(shared_gate_up, shared_dequant_scale))
|
||||
self.shared_act, self.swiglu_out_scale = \
|
||||
@@ -224,7 +205,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
if shared_experts is not None:
|
||||
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
|
||||
self.shared_act = shared_experts.act_fn(shared_gate_up)
|
||||
group_list_type = 1
|
||||
group_list_type = 0
|
||||
return {
|
||||
"group_list_type": group_list_type,
|
||||
"hidden_states": expand_x,
|
||||
@@ -291,6 +272,16 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
**kwargs_mc2
|
||||
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
|
||||
**kwargs_mc2)
|
||||
|
||||
# these values are no longer used, so they need to be set to None for memory release.
|
||||
self.output = None
|
||||
self.assist_info_for_combine = None
|
||||
self.ep_recv_counts = None
|
||||
self.topk_ids = None
|
||||
self.topk_weights = None
|
||||
self.mc2_mask = None
|
||||
self.expert_map = None
|
||||
|
||||
if self.shared_experts is None:
|
||||
return hidden_states
|
||||
else:
|
||||
@@ -300,6 +291,9 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
else:
|
||||
shared_hidden_states, _ = self.shared_experts.down_proj(
|
||||
self.shared_act)
|
||||
self.shared_act = None
|
||||
self.shared_experts = None
|
||||
self.swiglu_out_scale = None
|
||||
return hidden_states, shared_hidden_states
|
||||
|
||||
|
||||
@@ -328,9 +322,9 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
@@ -338,8 +332,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
self.original_shape = hidden_states.shape
|
||||
|
||||
num_tokens = hidden_states.shape[:-1].numel()
|
||||
dtype = hidden_states.dtype
|
||||
device = hidden_states.device
|
||||
self.expert_map = expert_map
|
||||
self.topk_weights = topk_weights
|
||||
self.topk_ids = topk_ids
|
||||
@@ -353,144 +345,65 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
hidden_states = hidden_states * \
|
||||
topk_weights.to(hidden_states.dtype)
|
||||
|
||||
if expert_map is not None:
|
||||
# Generate token indices and flatten
|
||||
token_indices = (torch.arange(
|
||||
num_tokens, device=device,
|
||||
dtype=torch.int64).unsqueeze(1).expand(-1,
|
||||
self.top_k).reshape(-1))
|
||||
|
||||
# Flatten token-to-expert mappings and map to local experts
|
||||
weights_flat = topk_weights.view(-1)
|
||||
experts_flat = topk_ids.view(-1)
|
||||
local_experts_flat = expert_map[experts_flat]
|
||||
|
||||
# Filter valid token-expert pairs
|
||||
self.mask = local_experts_flat != -1
|
||||
filtered_weights = torch.where(
|
||||
self.mask, weights_flat,
|
||||
torch.zeros_like(weights_flat)).to(dtype)
|
||||
filtered_experts = torch.where(
|
||||
self.mask, local_experts_flat,
|
||||
torch.full_like(local_experts_flat,
|
||||
self.num_experts_local)).to(topk_ids.dtype)
|
||||
|
||||
# Sort by local expert IDs
|
||||
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
|
||||
self.sorted_token_indices = token_indices[sort_indices]
|
||||
self.sorted_weights = filtered_weights[sort_indices]
|
||||
|
||||
# Compute token counts with minlength of num_experts
|
||||
# This is equivalent to but faster than:
|
||||
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
|
||||
token_counts = torch.zeros(self.num_experts_local + 1,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
|
||||
token_counts.scatter_add_(0, filtered_experts.to(torch.int64),
|
||||
ones)
|
||||
token_counts = token_counts[:self.num_experts_local]
|
||||
|
||||
# Rearrange hidden_states
|
||||
sorted_hidden_states = hidden_states[self.sorted_token_indices]
|
||||
if self.with_quant:
|
||||
group_list_type = 1
|
||||
expert_tokens = token_counts
|
||||
else:
|
||||
expert_tokens = torch.cumsum(token_counts,
|
||||
dim=0,
|
||||
dtype=torch.int64)
|
||||
group_list_type = 0
|
||||
global_num_experts = len(expert_map)
|
||||
mask = (expert_map[topk_ids] != -1)
|
||||
self.topk_weights = topk_weights * mask
|
||||
first_expert_idx = get_ep_group(
|
||||
).rank_in_group * self.num_experts_local
|
||||
last_expert_idx = first_expert_idx + self.num_experts_local
|
||||
else:
|
||||
active_num = self.max_num_tokens if self.max_num_tokens is not None else num_tokens
|
||||
sorted_hidden_states, self.expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
expert_idx=topk_ids,
|
||||
active_num=active_num)
|
||||
first_expert_idx = 0
|
||||
last_expert_idx = self.num_experts_local
|
||||
global_num_experts = self.num_experts_local
|
||||
|
||||
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||
expanded_expert_idx, self.num_experts_local)
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 0
|
||||
sorted_hidden_states, self.expanded_row_idx, expert_tokens, pertoken_scale = (
|
||||
torch_npu.npu_moe_init_routing_v2(
|
||||
hidden_states,
|
||||
topk_ids,
|
||||
active_num=num_tokens * self.top_k,
|
||||
expert_num=global_num_experts,
|
||||
expert_tokens_num_type=1,
|
||||
expert_tokens_num_flag=True,
|
||||
active_expert_range=[first_expert_idx, last_expert_idx],
|
||||
quant_mode=1 if self.with_quant else -1,
|
||||
))
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 1 # `count` mode
|
||||
return {
|
||||
"group_list_type": group_list_type,
|
||||
"hidden_states": sorted_hidden_states,
|
||||
"group_list": expert_tokens,
|
||||
"dynamic_scale": pertoken_scale if self.with_quant else None,
|
||||
}
|
||||
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
assert self.original_shape is not None
|
||||
dtype = hidden_states.dtype
|
||||
device = hidden_states.device
|
||||
if self.expert_map is not None:
|
||||
assert self.mask is not None
|
||||
assert self.sorted_token_indices is not None
|
||||
assert self.sorted_weights is not None
|
||||
final_hidden_states = torch_npu.npu_moe_token_unpermute(
|
||||
permuted_tokens=hidden_states,
|
||||
sorted_indices=self.expanded_row_idx,
|
||||
probs=self.topk_weights)
|
||||
if len(self.original_shape) == 3:
|
||||
final_hidden_states = final_hidden_states.view(self.original_shape)
|
||||
|
||||
weighted_down_out = hidden_states * \
|
||||
self.sorted_weights.unsqueeze(1)
|
||||
|
||||
final_hidden_states = torch.zeros(*self.original_shape,
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
# TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
||||
# This created multiple NaN and index_add_ will mix them up which harms accuracy
|
||||
# remove this mask and filter after it being fixed
|
||||
num_valid_tokens = self.mask.sum()
|
||||
valid_token_mask = torch.arange(
|
||||
0, self.sorted_token_indices.shape[0],
|
||||
device=device).unsqueeze(1) < num_valid_tokens
|
||||
valid_output = torch.where(
|
||||
valid_token_mask, weighted_down_out,
|
||||
torch.zeros_like(weighted_down_out)).to(dtype)
|
||||
final_hidden_states.index_add_(0, self.sorted_token_indices,
|
||||
valid_output)
|
||||
else:
|
||||
if self.with_quant:
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
hidden_states,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=self.topk_weights,
|
||||
expanded_src_to_dst_row=self.expanded_row_idx,
|
||||
export_for_source_row=self.topk_ids,
|
||||
)
|
||||
if len(self.original_shape) == 3:
|
||||
final_hidden_states = final_hidden_states.view(
|
||||
self.original_shape)
|
||||
else:
|
||||
scales = torch.ones_like(
|
||||
self.topk_weights
|
||||
) if self.apply_router_weight_on_input else self.topk_weights
|
||||
# TODO: Reorder device memory 2 times here, replace the current
|
||||
# implementation here when suitable operators become available.
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
hidden_states,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=scales,
|
||||
expanded_src_to_dst_row=self.expanded_row_idx,
|
||||
export_for_source_row=self.topk_ids,
|
||||
)
|
||||
# these values are no longer used, so they need to be set to None for memory release.
|
||||
self.expert_map = None
|
||||
self.topk_weights = None
|
||||
self.topk_ids = None
|
||||
self.expanded_row_idx = None
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
# mypy: disable-error-code="override"
|
||||
class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
|
||||
class TokenDispatcherWithMoge(MoETokenDispatcher):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.apply_router_weight_on_input = False
|
||||
self.local_ep = 1
|
||||
self.local_num_experts = self.num_experts // self.local_ep
|
||||
self.local_num_group = self.top_k // self.local_ep
|
||||
self.local_num_experts = self.num_experts // self.ep_size
|
||||
self.local_num_group = self.top_k // self.ep_size
|
||||
self.bsz = None
|
||||
|
||||
def token_dispatch(self,
|
||||
@@ -501,23 +414,12 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
if self.apply_router_weight_on_input:
|
||||
assert (topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
hidden_states = hidden_states * \
|
||||
topk_weights.to(hidden_states.dtype)
|
||||
|
||||
self.bsz, _ = hidden_states.shape
|
||||
flatten_topk_ids = topk_ids.view(-1)
|
||||
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
||||
@@ -551,7 +453,7 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
|
||||
unsorted_hidden_states = hidden_states.index_select(
|
||||
0, unsorted_topk_ids)
|
||||
final_hidden_states = unsorted_hidden_states.reshape(
|
||||
self.bsz, self.top_k // self.local_ep, -1).sum(1)
|
||||
self.bsz, self.top_k // self.ep_size, -1).sum(1)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
@@ -613,9 +515,9 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
@@ -681,9 +583,14 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
|
||||
output = self._combine_postprocess(permutated_local_input_tokens)
|
||||
|
||||
# these values are no longer used, so they need to be set to None for memory release.
|
||||
self.input_splits = None
|
||||
self.output_splits = None
|
||||
self.num_global_tokens_per_local_expert = None
|
||||
self.topk_weights = None
|
||||
self.reversed_local_input_permutation_mapping = None
|
||||
self.reversed_global_input_permutation_mapping = None
|
||||
self.global_input_tokens_local_experts_indices = None
|
||||
|
||||
return output
|
||||
|
||||
@@ -745,6 +652,10 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
|
||||
self.expert_ids_per_ep_rank,
|
||||
self.num_global_tokens_per_local_expert.ravel())
|
||||
else:
|
||||
# TODO: This full synchronization can be a performance bottleneck.
|
||||
# A more granular sync (e.g., blocking D2H copies) should be investigated.
|
||||
torch.npu.synchronize()
|
||||
|
||||
return num_tokens_per_local_expert
|
||||
|
||||
201
vllm_ascend/ops/register_custom_ops.py
Normal file
201
vllm_ascend/ops/register_custom_ops.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_reduce_scatter)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
|
||||
|
||||
def _maybe_chunk_residual_impl(x: torch.Tensor,
|
||||
residual: torch.Tensor) -> torch.Tensor:
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
except AssertionError:
|
||||
return residual
|
||||
|
||||
if x.size(0) != residual.size(0):
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
assert sp_enabled is True, ("Currently, this situation only occurs "
|
||||
"when sp is enabled")
|
||||
pad_size = forward_context.pad_size
|
||||
if pad_size > 0:
|
||||
residual = F.pad(residual, (0, 0, 0, pad_size))
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]
|
||||
|
||||
return residual
|
||||
|
||||
|
||||
def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
|
||||
label: bool) -> torch.Tensor:
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
except AssertionError:
|
||||
return x
|
||||
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
if sp_enabled and label:
|
||||
x = tensor_model_parallel_all_gather(x, 0)
|
||||
pad_size = forward_context.pad_size
|
||||
if pad_size > 0:
|
||||
x = x[:-pad_size, :]
|
||||
return x
|
||||
|
||||
|
||||
def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
except AssertionError:
|
||||
return tensor_model_parallel_all_reduce(x)
|
||||
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
if sp_enabled:
|
||||
pad_size = forward_context.pad_size
|
||||
if pad_size > 0:
|
||||
x = F.pad(x, (0, 0, 0, pad_size))
|
||||
return tensor_model_parallel_reduce_scatter(x, 0)
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(x)
|
||||
|
||||
|
||||
def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
|
||||
prefix: str) -> None:
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
except AssertionError:
|
||||
return
|
||||
|
||||
if not forward_context.prefetch_mlp_enabled:
|
||||
return
|
||||
model_instance = forward_context.model_instance
|
||||
prefetch_stream = forward_context.prefetch_stream
|
||||
layer_idx = int(prefix.split('.')[2])
|
||||
|
||||
# start point of gate_up_proj weight prefetch
|
||||
if prefix.split('.')[-2] == "self_attn":
|
||||
forward_context.prefetch_mlp_gate_up_proj = True
|
||||
if forward_context.prefetch_mlp_gate_up_proj:
|
||||
prefetch_stream.wait_stream(torch.npu.current_stream())
|
||||
|
||||
with torch.npu.stream(prefetch_stream):
|
||||
mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
|
||||
torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \
|
||||
x_dependency, mlp_gate_up_prefetch_size)
|
||||
return
|
||||
|
||||
|
||||
def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
|
||||
prefix: str) -> None:
|
||||
return
|
||||
|
||||
|
||||
def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
except AssertionError:
|
||||
return
|
||||
|
||||
if not forward_context.prefetch_mlp_enabled:
|
||||
return
|
||||
forward_context.prefetch_mlp_down_proj = True
|
||||
model_instance = forward_context.model_instance
|
||||
prefetch_stream = forward_context.prefetch_stream
|
||||
layer_idx = forward_context.layer_idx
|
||||
|
||||
# start point of down_proj weight prefetch
|
||||
prefetch_stream.wait_stream(torch.npu.current_stream())
|
||||
|
||||
with torch.npu.stream(prefetch_stream):
|
||||
mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
|
||||
torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \
|
||||
x_dependency, mlp_down_prefetch_size)
|
||||
forward_context.layer_idx += 1
|
||||
return
|
||||
|
||||
|
||||
def _maybe_prefetch_mlp_down_proj_impl_fake(
|
||||
x_dependency: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
except AssertionError:
|
||||
return
|
||||
|
||||
if not forward_context.prefetch_mlp_enabled:
|
||||
return
|
||||
if forward_context.prefetch_mlp_gate_up_proj or \
|
||||
forward_context.prefetch_mlp_down_proj:
|
||||
prefetch_stream = forward_context.prefetch_stream
|
||||
# wait until prefetch done
|
||||
torch.npu.current_stream().wait_stream(prefetch_stream)
|
||||
forward_context.prefetch_mlp_gate_up_proj = False
|
||||
forward_context.prefetch_mlp_down_proj = False
|
||||
return
|
||||
|
||||
|
||||
def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
def _maybe_all_reduce_tensor_model_parallel_impl(
|
||||
final_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
|
||||
return final_hidden_states
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
|
||||
direct_register_custom_op(op_name="maybe_chunk_residual",
|
||||
op_func=_maybe_chunk_residual_impl,
|
||||
fake_impl=lambda x, residual: residual,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
|
||||
op_func=_maybe_all_gather_and_maybe_unpad_impl,
|
||||
fake_impl=lambda x, label: x,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(op_name="maybe_pad_and_reduce",
|
||||
op_func=_maybe_pad_and_reduce_impl,
|
||||
fake_impl=lambda x: x,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj",
|
||||
op_func=_maybe_prefetch_mlp_gate_up_proj_impl,
|
||||
fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj",
|
||||
op_func=_maybe_prefetch_mlp_down_proj_impl,
|
||||
fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(op_name="maybe_wait_prefetch_done",
|
||||
op_func=_maybe_wait_prefetch_done_impl,
|
||||
fake_impl=_maybe_wait_prefetch_done_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
|
||||
op_func=_maybe_all_reduce_tensor_model_parallel_impl,
|
||||
fake_impl=lambda x: x,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
@@ -20,6 +20,7 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
|
||||
@@ -37,34 +38,39 @@ def _rope_forward_oot(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None,
|
||||
is_neox_style: bool,
|
||||
offsets: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
query_shape, key_shape = query.shape, key.shape
|
||||
if self.cos_sin_cache.device != query.device:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
|
||||
if self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||
neox_style = self.is_neox_style
|
||||
if is_neox_style_override is not None:
|
||||
neox_style = is_neox_style_override
|
||||
# adopt custom kernel path for rotary_embedding
|
||||
if _custom_rotary_embedding_enabled(query, neox_style,
|
||||
if _custom_rotary_embedding_enabled(query, is_neox_style,
|
||||
self.head_size) and not is_310p():
|
||||
query, key = torch.ops._C.rotary_embedding(
|
||||
query, key = torch.ops._C_ascend.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
is_neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
if offsets is not None:
|
||||
raise NotImplementedError(
|
||||
"Batched rotary embedding is currently not supported on NPU.")
|
||||
else:
|
||||
if self.rotary_dim < self.head_size:
|
||||
if self.cos is not None and \
|
||||
self.sin is not None:
|
||||
# If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
|
||||
# This method requires head_size and rotary_dim equal 128 and neox_style is True
|
||||
query = query.contiguous().view(1, query.shape[0], -1,
|
||||
self.head_size)
|
||||
key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
|
||||
torch_npu.npu_apply_rotary_pos_emb(query, key, self.cos, self.sin)
|
||||
elif self.rotary_dim < self.head_size:
|
||||
num_tokens = query.shape[0]
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
@@ -80,25 +86,26 @@ def _rope_forward_oot(
|
||||
k_rot,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
is_neox_style,
|
||||
)
|
||||
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
|
||||
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
|
||||
q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
|
||||
k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
|
||||
return q, k
|
||||
# TODO: Remove the contiguous in the future.
|
||||
query = query.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(key.shape[0], -1)
|
||||
torch_npu._npu_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
else:
|
||||
# TODO: Remove the contiguous in the future.
|
||||
query = query.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(key.shape[0], -1)
|
||||
torch_npu._npu_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
is_neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
|
||||
|
||||
class AscendRotaryEmbedding(RotaryEmbedding):
|
||||
@@ -112,6 +119,8 @@ class AscendRotaryEmbedding(RotaryEmbedding):
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
self.cos = None
|
||||
self.sin = None
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style, dtype)
|
||||
|
||||
@@ -123,14 +132,25 @@ class AscendRotaryEmbedding(RotaryEmbedding):
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None,
|
||||
):
|
||||
return _rope_forward_oot(
|
||||
self,
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
offsets,
|
||||
is_neox_style_override,
|
||||
)
|
||||
is_neox_style = self.is_neox_style
|
||||
if is_neox_style_override is not None:
|
||||
is_neox_style = is_neox_style_override
|
||||
forward_context = get_forward_context()
|
||||
is_first_layer = forward_context.is_first_layer
|
||||
# Generate cos and sin outside layers to avoid repeated calculation.
|
||||
if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
|
||||
-1] == 128:
|
||||
if is_first_layer:
|
||||
cos_sin = self.cos_sin_cache.index_select(0, positions)
|
||||
last_dim = cos_sin.size()[-1]
|
||||
cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(
|
||||
1, 1, 2).chunk(2, dim=-2)
|
||||
# BSNH
|
||||
self.cos = cos.view(1, -1, 1, last_dim).contiguous()
|
||||
self.sin = sin.view(1, -1, 1, last_dim).contiguous()
|
||||
forward_context.is_first_layer = False
|
||||
return _rope_forward_oot(self, positions, query, key, is_neox_style,
|
||||
offsets)
|
||||
|
||||
|
||||
class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
@@ -168,8 +188,10 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
super(DeepseekScalingRotaryEmbedding,
|
||||
self).__init__(head_size, rotary_dim, max_position_embeddings,
|
||||
base, is_neox_style, dtype)
|
||||
self.max_seq_len = max_position_embeddings
|
||||
self._set_cos_sin_cache(seq_len=max_position_embeddings,
|
||||
|
||||
# NOTE: For ascend friendly computing, reorder sin and cos cache
|
||||
self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)
|
||||
self._set_cos_sin_cache(self.max_seq_len,
|
||||
device=NPUPlatform.device_type,
|
||||
dtype=dtype)
|
||||
|
||||
@@ -275,8 +297,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
def _set_cos_sin_cache(self, max_seq_len, device, dtype):
|
||||
dim = self.rotary_dim
|
||||
|
||||
freq_extra = 1.0 / (self.base**(
|
||||
@@ -297,9 +318,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
inv_freq_mask) + freq_extra * inv_freq_mask
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(seq_len * self.scaling_factor,
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
t = torch.arange(max_seq_len, device=device, dtype=torch.float32)
|
||||
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
|
||||
@@ -317,16 +336,13 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
max_seq_len: Optional[int] = None):
|
||||
if max_seq_len is not None and max_seq_len > self.max_seq_len:
|
||||
self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
|
||||
offsets: Optional[torch.Tensor] = None):
|
||||
if len(key.shape) == 2:
|
||||
key = key[:, None, :]
|
||||
# Note: we implement the non neox_style method with shuffle the last dim and neox style
|
||||
# calculation method which is also more compute friendly to the ascend machine
|
||||
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
|
||||
neox_style = True
|
||||
is_neox_style = True
|
||||
if self.is_neox_style is False:
|
||||
b, h_q, d = query.shape
|
||||
query = query.view(b, h_q, d // 2,
|
||||
@@ -334,6 +350,6 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
b, h_k, d = key.shape
|
||||
key = key.view(b, h_k, d // 2, 2).transpose(3,
|
||||
2).reshape(b, h_k, d)
|
||||
q_pe, k_pe = _rope_forward_oot(self, positions, query, key, offsets,
|
||||
neox_style)
|
||||
q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
|
||||
is_neox_style, offsets)
|
||||
return q_pe, k_pe
|
||||
|
||||
384
vllm_ascend/ops/sigmoid_gating.py
Normal file
384
vllm_ascend/ops/sigmoid_gating.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
# mypy: ignore-errors
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, tldevice, triton
|
||||
|
||||
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
|
||||
div = tldevice.fast_dividef
|
||||
exp = tldevice.fast_expf
|
||||
log = tldevice.fast_logf
|
||||
log2 = tldevice.fast_log2f
|
||||
else:
|
||||
|
||||
@triton.jit
|
||||
def div_normal(x, y):
|
||||
return x / y
|
||||
|
||||
div = div_normal
|
||||
exp = tl.exp
|
||||
log = tl.log
|
||||
log2 = tl.log2
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'USE_INITIAL_STATE':
|
||||
lambda args: args['h0'] is not None,
|
||||
'IS_VARLEN':
|
||||
lambda args: args['cu_seqlens'] is not None,
|
||||
"IS_CONTINUOUS_BATCHING":
|
||||
lambda args: args['ssm_state_indices'] is not None,
|
||||
"IS_SPEC_DECODING":
|
||||
lambda args: args['num_accepted_tokens'] is not None,
|
||||
})
|
||||
@triton.jit(do_not_specialize=['N', 'T'])
|
||||
def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
o,
|
||||
h0,
|
||||
ht,
|
||||
cu_seqlens,
|
||||
ssm_state_indices,
|
||||
num_accepted_tokens,
|
||||
scale,
|
||||
N: tl.constexpr, # num of sequences
|
||||
T: tl.constexpr, # num of tokens
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
HV: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
stride_init_state_token: tl.constexpr,
|
||||
stride_final_state_token: tl.constexpr,
|
||||
stride_indices_seq: tl.constexpr,
|
||||
stride_indices_tok: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace
|
||||
IS_BETA_HEADWISE: tl.
|
||||
constexpr, # whether beta is headwise vector or scalar,
|
||||
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||
IS_SPEC_DECODING: tl.constexpr,
|
||||
):
|
||||
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_n, i_hv = i_nh // HV, i_nh % HV
|
||||
i_h = i_hv // (HV // H)
|
||||
if IS_VARLEN:
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
|
||||
all = T
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_n * T, i_n * T + T
|
||||
all = B * T
|
||||
|
||||
if T == 0:
|
||||
# no tokens to process for this sequence
|
||||
return
|
||||
|
||||
o_k = i_k * BK + tl.arange(0, BK)
|
||||
o_v = i_v * BV + tl.arange(0, BV)
|
||||
|
||||
mask_k = o_k < K
|
||||
mask_v = o_v < V
|
||||
mask_h = mask_k[:, None] & mask_v[None, :]
|
||||
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
if IS_CONTINUOUS_BATCHING:
|
||||
if IS_SPEC_DECODING:
|
||||
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
|
||||
else:
|
||||
i_t = 0
|
||||
p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq +
|
||||
i_t).to(tl.int64) * stride_init_state_token
|
||||
else:
|
||||
p_h0 = h0 + bos * HV * K * V
|
||||
p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
|
||||
|
||||
for i_t in range(0, T):
|
||||
p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t
|
||||
p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t
|
||||
p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t
|
||||
if IS_BETA_HEADWISE:
|
||||
p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t
|
||||
else:
|
||||
p_beta = beta + bos * HV + i_hv + HV * i_t
|
||||
p_g = g + bos * HV + i_hv + HV * i_t
|
||||
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + HV * V * i_t
|
||||
|
||||
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
|
||||
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
|
||||
b_g = tl.load(p_g).to(tl.float32)
|
||||
|
||||
if USE_QK_L2NORM_IN_KERNEL:
|
||||
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
|
||||
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
|
||||
b_q = b_q * scale
|
||||
# [BK, BV]
|
||||
# b_h *= tl.exp(b_g)
|
||||
b_h *= exp(b_g)
|
||||
# [BV]
|
||||
b_v -= tl.sum(b_h * b_k[:, None], 0)
|
||||
if IS_BETA_HEADWISE:
|
||||
b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
|
||||
else:
|
||||
b_beta = tl.load(p_beta).to(tl.float32)
|
||||
b_v *= b_beta
|
||||
# [BK, BV]
|
||||
b_h += b_k[:, None] * b_v[None, :]
|
||||
# [BV]
|
||||
b_o = tl.sum(b_h * b_q[:, None], 0)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
|
||||
|
||||
# keep the states for multi-query tokens
|
||||
if INPLACE_FINAL_STATE:
|
||||
p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq +
|
||||
i_t).to(tl.int64) * stride_final_state_token
|
||||
else:
|
||||
p_ht = ht + (bos + i_t) * stride_final_state_token
|
||||
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule_fwd(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
ssm_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, H, K, V = *k.shape, v.shape[-1]
|
||||
HV = v.shape[2]
|
||||
N = B if cu_seqlens is None else len(cu_seqlens) - 1
|
||||
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
assert NK == 1, "NK > 1 is not supported yet"
|
||||
num_stages = 3
|
||||
num_warps = 1
|
||||
|
||||
o = q.new_empty(NK, *v.shape)
|
||||
if inplace_final_state:
|
||||
final_state = initial_state
|
||||
else:
|
||||
final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
|
||||
|
||||
stride_init_state_token = initial_state.stride(0)
|
||||
stride_final_state_token = final_state.stride(0)
|
||||
|
||||
if ssm_state_indices is None:
|
||||
stride_indices_seq, stride_indices_tok = 1, 1
|
||||
elif ssm_state_indices.ndim == 1:
|
||||
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
|
||||
else:
|
||||
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
|
||||
|
||||
# print("N: ", N)
|
||||
# print("T: ", T)
|
||||
# print("B: ", B)
|
||||
# print("H: ", H)
|
||||
# print("HV: ", HV)
|
||||
# print("K: ", K)
|
||||
# print("V: ", V)
|
||||
# print("BK: ", BK)
|
||||
# print("BV: ", BV)
|
||||
|
||||
grid = (NK, NV, N * HV)
|
||||
fused_recurrent_gated_delta_rule_fwd_kernel[grid](
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
beta=beta,
|
||||
o=o,
|
||||
h0=initial_state,
|
||||
ht=final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
ssm_state_indices=ssm_state_indices,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
scale=scale,
|
||||
N=N,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
HV=HV,
|
||||
K=K,
|
||||
V=V,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
stride_init_state_token=stride_init_state_token,
|
||||
stride_final_state_token=stride_final_state_token,
|
||||
stride_indices_seq=stride_indices_seq,
|
||||
stride_indices_tok=stride_indices_tok,
|
||||
IS_BETA_HEADWISE=beta.ndim == v.ndim,
|
||||
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
|
||||
INPLACE_FINAL_STATE=inplace_final_state,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
o = o.squeeze(0)
|
||||
return o, final_state
|
||||
|
||||
|
||||
class FusedRecurrentFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
ssm_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
o, final_state = fused_recurrent_gated_delta_rule_fwd(
|
||||
q=q.contiguous(),
|
||||
k=k.contiguous(),
|
||||
v=v.contiguous(),
|
||||
g=g.contiguous(),
|
||||
beta=beta.contiguous(),
|
||||
scale=scale,
|
||||
initial_state=initial_state,
|
||||
inplace_final_state=inplace_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
ssm_state_indices=ssm_state_indices,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
|
||||
)
|
||||
|
||||
return o, final_state
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor = None,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
ssm_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
q (torch.Tensor):
|
||||
queries of shape `[B, T, H, K]`.
|
||||
k (torch.Tensor):
|
||||
keys of shape `[B, T, H, K]`.
|
||||
v (torch.Tensor):
|
||||
values of shape `[B, T, HV, V]`.
|
||||
GVA is applied if `HV > H`.
|
||||
g (torch.Tensor):
|
||||
g (decays) of shape `[B, T, HV]`.
|
||||
beta (torch.Tensor):
|
||||
betas of shape `[B, T, HV]`.
|
||||
scale (Optional[int]):
|
||||
Scale factor for the RetNet attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `[N, HV, K, V]` for `N` input sequences.
|
||||
For equal-length input sequences, `N` equals the batch size `B`.
|
||||
Default: `None`.
|
||||
inplace_final_state: bool:
|
||||
Whether to store the final state in-place to save memory.
|
||||
Default: `True`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
ssm_state_indices (Optional[torch.Tensor]):
|
||||
Indices to map the input sequences to the initial/final states.
|
||||
num_accepted_tokens (Optional[torch.Tensor]):
|
||||
Number of accepted tokens for each sequence during decoding.
|
||||
Returns:
|
||||
o (torch.Tensor):
|
||||
Outputs of shape `[B, T, HV, V]`.
|
||||
final_state (torch.Tensor):
|
||||
Final state of shape `[N, HV, K, V]`.
|
||||
Examples::
|
||||
>>> import torch
|
||||
>>> import torch.nn.functional as F
|
||||
>>> from einops import rearrange
|
||||
>>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
|
||||
# inputs with equal lengths
|
||||
>>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
|
||||
>>> q = torch.randn(B, T, H, K, device='cuda')
|
||||
>>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
|
||||
>>> v = torch.randn(B, T, HV, V, device='cuda')
|
||||
>>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
|
||||
>>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
|
||||
>>> h0 = torch.randn(B, HV, K, V, device='cuda')
|
||||
>>> o, ht = fused_gated_recurrent_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
)
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> o_var, ht_var = fused_gated_recurrent_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
cu_seqlens=cu_seqlens
|
||||
)
|
||||
"""
|
||||
if cu_seqlens is not None and q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||
f"Please flatten variable-length inputs before processing.")
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
else:
|
||||
assert scale > 0, "scale must be positive"
|
||||
if beta is None:
|
||||
beta = torch.ones_like(q[..., 0])
|
||||
o, final_state = FusedRecurrentFunction.apply(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
scale,
|
||||
initial_state,
|
||||
inplace_final_state,
|
||||
cu_seqlens,
|
||||
ssm_state_indices,
|
||||
num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel,
|
||||
)
|
||||
return o, final_state
|
||||
@@ -97,6 +97,7 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
# Divide the weight matrix along the vocaburaly dimension.
|
||||
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
|
||||
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
|
||||
@@ -252,3 +253,16 @@ class AscendLogitsProcessor(LogitsProcessor):
|
||||
logits = logits[..., :self.org_vocab_size]
|
||||
|
||||
return logits
|
||||
|
||||
def forward(
|
||||
self,
|
||||
lm_head: VocabParallelEmbedding,
|
||||
hidden_states: torch.Tensor,
|
||||
# keep this for version compatibility
|
||||
sampling_metadata=None, # type: ignore
|
||||
embedding_bias: Optional[torch.Tensor] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return LogitsProcessor.forward(self,
|
||||
lm_head,
|
||||
hidden_states,
|
||||
embedding_bias=embedding_bias)
|
||||
|
||||
@@ -46,6 +46,27 @@
|
||||
# Need a PR to vllm to support get port from environment.
|
||||
# Future Plan:
|
||||
# Remove those patch when vllm merged them
|
||||
# 2. `torch.distributed.all_reduce`, `torch.distributed.broadcast`
|
||||
# Why:
|
||||
# tensor alignment for 310p
|
||||
# How:
|
||||
# rewrite all_reduce and broadcast in torch.distributed
|
||||
# Related PR (if no, explain why):
|
||||
# No, not ready yet.
|
||||
# Future Plan:
|
||||
# Find a better way to support tensor alignment for 310p without this patch.
|
||||
#
|
||||
# ** File: platform/patch_common/patch_multimodal_merge.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.models.utils._merge_multimodal_embeddings`
|
||||
# Why:
|
||||
# '_merge_multimodal_embeddings' func of vllm is incompatible with Ascend.
|
||||
# How:
|
||||
# Replace with CPU operation that can be executed asynchronously.
|
||||
# Related PR (if no, explain why):
|
||||
# This is a bug by Ascend only. It can' be fixed in vLLM.
|
||||
# Future Plan:
|
||||
# Identify this pattern in torch-npu and remove this patch.
|
||||
#
|
||||
# * Worker Patch:
|
||||
# ===============
|
||||
@@ -86,19 +107,15 @@
|
||||
# - https://github.com/vllm-project/vllm/pull/21591
|
||||
# Future Plan:
|
||||
# Revert it when vLLM merge #21591 and release new version
|
||||
# ** File: worker/patch_common/patch_linear.py **
|
||||
# ** File: worker/patch_common/patch_logits.py **
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.layers.linear.RowParallelLinear`
|
||||
# 1. `vllm._custom_ops.apply_repetition_penalties`
|
||||
# Why:
|
||||
# We need to fuse matmul and allreuce in `RowParallelLinear`
|
||||
# to improve performance.
|
||||
# apply_repetition_penalties in vLLM use tensor.is_cuda to check if tensor is on cuda. But the value is always True
|
||||
# on ascend, thus we need to patch apply_repetition_penalties.
|
||||
# How:
|
||||
# Create a new class `AscendRowParallelLinear` that inherits from `RowParallelLinear`.
|
||||
# In this class, we override the `forward` method to use
|
||||
# torch_npu.npu_mm_all_reduce_base to replace matmul and allreduce.
|
||||
# Remove the related cuda check in apply_repetition_penalties.
|
||||
# Related PR (if no, explain why):
|
||||
# - https://github.com/vllm-project/vllm-ascend/pull/1926
|
||||
# - this is a bug by Ascend only. It can' be fixed in vLLM.
|
||||
# Future Plan:
|
||||
# Validate more models in all kinds of scenario,
|
||||
# if performance is always improved, we can enable this patch by default and remove the env
|
||||
# variable `VLLM_ASCEND_ENABLE_FUSE_MATMUL_ALLREDUCE` in the future.
|
||||
# Fix this bug in torch-npu, bump torch-npu version and remove this patch.
|
||||
|
||||
@@ -15,4 +15,10 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import vllm_ascend.patch.platform.patch_common.patch_config # noqa
|
||||
import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa
|
||||
import vllm_ascend.patch.platform.patch_common.patch_mamba_config # noqa
|
||||
import vllm_ascend.patch.platform.patch_common.patch_multimodal_merge # noqa
|
||||
import vllm_ascend.patch.platform.patch_common.patch_transformers_utils # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_attentionspec # noqa
|
||||
|
||||
313
vllm_ascend/patch/platform/patch_common/patch_config.py
Normal file
313
vllm_ascend/patch/platform/patch_common/patch_config.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import ast
|
||||
|
||||
import vllm.envs as envs
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.speculative import SpeculativeConfig
|
||||
from vllm.logger import logger
|
||||
|
||||
|
||||
# mypy: ignore-errors
|
||||
@property
|
||||
def is_deepseek_mla(self: ModelConfig):
|
||||
if not hasattr(self.hf_text_config, "model_type"):
|
||||
return False
|
||||
elif self.hf_text_config.model_type in \
|
||||
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp',
|
||||
'kimi_k2', 'longcat_flash', 'deepseek_v32'):
|
||||
return self.hf_text_config.kv_lora_rank is not None
|
||||
elif self.hf_text_config.model_type == 'eagle':
|
||||
# if the model is an EAGLE module, check for the
|
||||
# underlying architecture
|
||||
return self.hf_text_config.model.model_type in \
|
||||
('deepseek_v2', 'deepseek_v3', 'deepseek_v32') \
|
||||
and self.hf_text_config.kv_lora_rank is not None
|
||||
return False
|
||||
|
||||
|
||||
@staticmethod
|
||||
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
|
||||
if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
|
||||
hf_config.model_type = "deepseek_mtp"
|
||||
if hf_config.model_type == "deepseek_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update({
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["DeepSeekMTPModel"]
|
||||
})
|
||||
|
||||
if hf_config.architectures[0] == "MiMoForCausalLM":
|
||||
hf_config.model_type = "mimo_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update({
|
||||
"num_hidden_layers": 0,
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["MiMoMTPModel"]
|
||||
})
|
||||
|
||||
if hf_config.architectures[0] == "Glm4MoeForCausalLM":
|
||||
hf_config.model_type = "glm4_moe_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update({
|
||||
"num_hidden_layers": 0,
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["Glm4MoeMTPModel"]
|
||||
})
|
||||
|
||||
if hf_config.model_type == "ernie4_5_moe":
|
||||
hf_config.model_type = "ernie_mtp"
|
||||
if hf_config.model_type == "ernie_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update({
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["ErnieMTPModel"]
|
||||
})
|
||||
|
||||
if hf_config.model_type == "qwen3_next":
|
||||
hf_config.model_type = "qwen3_next_mtp"
|
||||
if hf_config.model_type == "qwen3_next_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update({
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["Qwen3NextMTP"]
|
||||
})
|
||||
if hf_config.model_type == "longcat_flash":
|
||||
hf_config.model_type = "longcat_flash_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
|
||||
hf_config.update({
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["LongCatFlashMTPModel"]
|
||||
})
|
||||
|
||||
return hf_config
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
# Note: "method" is a new parameter that helps to extend the
|
||||
# configuration of non-model-based proposers, and the "model" parameter
|
||||
# will be used to set the draft model, eagle head, or additional weight
|
||||
# when needed. If users do not specify "method", the speculative method
|
||||
# will be detected automatically if possible. If the speculative method
|
||||
# can not be detected, it will be considered as the "draft_model" by
|
||||
# default.
|
||||
|
||||
if self.model is None and self.num_speculative_tokens is not None:
|
||||
# TODO(Shangming): Refactor mtp configuration logic when supporting
|
||||
if (self.target_model_config
|
||||
and self.target_model_config.hf_text_config.model_type
|
||||
in ("deepseek_v3", "deepseek_v32", "mimo", "ernie4_5_moe",
|
||||
"qwen3_next")):
|
||||
# use the draft model from the same model:
|
||||
self.model = self.target_model_config.model
|
||||
# Align the quantization of draft model for cases such as
|
||||
# --quantization fp8 with a bf16 checkpoint.
|
||||
if not self.quantization:
|
||||
self.quantization = self.target_model_config.quantization
|
||||
elif self.method in ("ngram", "[ngram]"):
|
||||
self.model = "ngram"
|
||||
else:
|
||||
raise ValueError("num_speculative_tokens was provided but without "
|
||||
"speculative model.")
|
||||
|
||||
# Automatically configure the method for ngram when "model" is used
|
||||
# instead of "method"
|
||||
if self.method is None and (self.model is not None
|
||||
and self.model in ("ngram", "[ngram]")):
|
||||
self.method = "ngram"
|
||||
|
||||
if self.method in ("ngram", "[ngram]"):
|
||||
# Unified to "ngram" internally
|
||||
self.method = "ngram"
|
||||
# Set default values if not provided
|
||||
if (self.prompt_lookup_min is None and self.prompt_lookup_max is None):
|
||||
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
|
||||
self.prompt_lookup_min = 5
|
||||
self.prompt_lookup_max = 5
|
||||
elif self.prompt_lookup_min is None:
|
||||
assert self.prompt_lookup_max is not None
|
||||
self.prompt_lookup_min = self.prompt_lookup_max
|
||||
elif self.prompt_lookup_max is None:
|
||||
assert self.prompt_lookup_min is not None
|
||||
self.prompt_lookup_max = self.prompt_lookup_min
|
||||
|
||||
# Validate values
|
||||
if self.prompt_lookup_min < 1:
|
||||
raise ValueError(
|
||||
f"prompt_lookup_min={self.prompt_lookup_min} must be > 0")
|
||||
if self.prompt_lookup_max < 1:
|
||||
raise ValueError(
|
||||
f"prompt_lookup_max={self.prompt_lookup_max} must be > 0")
|
||||
if self.prompt_lookup_min > self.prompt_lookup_max:
|
||||
raise ValueError(
|
||||
f"prompt_lookup_min={self.prompt_lookup_min} must "
|
||||
f"be <= prompt_lookup_max={self.prompt_lookup_max}")
|
||||
|
||||
# TODO: current we still need extract vocab_size from target model
|
||||
# config, in future, we may try refactor it out, and set
|
||||
# draft related config as None here.
|
||||
self.draft_model_config = self.target_model_config
|
||||
self.draft_parallel_config = self.target_parallel_config
|
||||
else:
|
||||
self.prompt_lookup_max = 0
|
||||
self.prompt_lookup_min = 0
|
||||
|
||||
if self.model is not None:
|
||||
# TODO: Move this import to the top once `ModelConfig`
|
||||
# lives in `vllm.config.model`.
|
||||
from vllm.config import ModelConfig
|
||||
self.draft_model_config = ModelConfig(
|
||||
model=self.model,
|
||||
runner="draft",
|
||||
tokenizer=self.target_model_config.tokenizer,
|
||||
tokenizer_mode=self.target_model_config.tokenizer_mode,
|
||||
trust_remote_code=self.target_model_config.trust_remote_code,
|
||||
allowed_local_media_path=self.target_model_config.
|
||||
allowed_local_media_path,
|
||||
allowed_media_domains=self.target_model_config.
|
||||
allowed_media_domains,
|
||||
dtype=self.target_model_config.dtype,
|
||||
seed=self.target_model_config.seed,
|
||||
revision=self.revision,
|
||||
code_revision=self.code_revision,
|
||||
tokenizer_revision=self.target_model_config.tokenizer_revision,
|
||||
spec_target_max_model_len=self.target_model_config.
|
||||
max_model_len,
|
||||
quantization=self.quantization,
|
||||
enforce_eager=self.target_model_config.enforce_eager,
|
||||
max_logprobs=self.target_model_config.max_logprobs,
|
||||
hf_overrides=SpeculativeConfig.hf_config_override,
|
||||
)
|
||||
|
||||
# Automatically detect the method
|
||||
if self.method in ('eagle', 'eagle3'):
|
||||
pass
|
||||
# examples:
|
||||
# yuhuili/EAGLE-LLaMA3-Instruct-8B
|
||||
# yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
|
||||
# AngelSlim/Qwen3-8B_eagle3
|
||||
elif "eagle-" in self.draft_model_config.model.lower():
|
||||
self.method = "eagle"
|
||||
elif "eagle3" in self.draft_model_config.model.lower():
|
||||
self.method = "eagle3"
|
||||
elif self.draft_model_config.hf_config.model_type == "medusa":
|
||||
self.method = "medusa"
|
||||
elif (self.draft_model_config.hf_config.model_type ==
|
||||
"mlp_speculator"):
|
||||
self.method = "mlp_speculator"
|
||||
elif (self.draft_model_config.hf_config.model_type
|
||||
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
|
||||
self.method = "deepseek_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"All Deepseek MTP models only have " \
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
)
|
||||
elif (self.draft_model_config.hf_config.model_type == "ernie_mtp"):
|
||||
self.method = "ernie_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"All Ernie MTP models only have " \
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
)
|
||||
elif (self.draft_model_config.hf_config.model_type ==
|
||||
"qwen3_next_mtp"):
|
||||
self.method = "qwen3_next_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"All Qwen3Next MTP models only have " \
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
)
|
||||
elif (self.draft_model_config.hf_config.model_type
|
||||
in ("longcat_flash_mtp")):
|
||||
self.method = "longcat_flash_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"LongCat MTP models only have " \
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
)
|
||||
else:
|
||||
self.method = "draft_model"
|
||||
raise NotImplementedError(
|
||||
"Speculative decoding with draft model is not "
|
||||
"supported yet. Please consider using other "
|
||||
"speculative decoding methods such as ngram, medusa, "
|
||||
"eagle, or deepseek_mtp.")
|
||||
|
||||
# Replace hf_config for EAGLE draft_model
|
||||
if self.method in ("eagle", "eagle3"):
|
||||
if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
|
||||
raise ValueError(
|
||||
"Chunked prefill and EAGLE are not compatible "
|
||||
"when using V0.")
|
||||
|
||||
from vllm.transformers_utils.configs import SpeculatorsConfig
|
||||
from vllm.transformers_utils.configs.eagle import EAGLEConfig
|
||||
|
||||
if isinstance(self.draft_model_config.hf_config,
|
||||
(EAGLEConfig, SpeculatorsConfig)):
|
||||
pass
|
||||
else:
|
||||
eagle_config = EAGLEConfig(
|
||||
self.draft_model_config.hf_config,
|
||||
method=self.method,
|
||||
model_type="eagle")
|
||||
self.draft_model_config.hf_config = eagle_config
|
||||
|
||||
if (self.num_speculative_tokens is not None
|
||||
and hasattr(self.draft_model_config.hf_config,
|
||||
"num_lookahead_tokens")):
|
||||
self.draft_model_config.hf_config.num_lookahead_tokens = \
|
||||
self.num_speculative_tokens
|
||||
|
||||
n_predict = getattr(self.draft_model_config.hf_config, "n_predict",
|
||||
None)
|
||||
if n_predict is not None:
|
||||
if self.num_speculative_tokens is None:
|
||||
# Default to max value defined in draft model config.
|
||||
self.num_speculative_tokens = n_predict
|
||||
elif self.num_speculative_tokens > n_predict and \
|
||||
self.num_speculative_tokens % n_predict != 0:
|
||||
# Ensure divisibility for MTP module reuse.
|
||||
raise ValueError(
|
||||
f"num_speculative_tokens:{self.num_speculative_tokens}"
|
||||
f" must be divisible by {n_predict=}")
|
||||
|
||||
if self.speculative_token_tree is None:
|
||||
# Generate chain of tokens.
|
||||
self.speculative_token_tree = str([
|
||||
(i + 1) * (0, ) for i in range(self.num_speculative_tokens)
|
||||
])
|
||||
else:
|
||||
# Sort the token tree breadth-first.
|
||||
tree_choices = ast.literal_eval(self.speculative_token_tree)
|
||||
self.speculative_token_tree = str(
|
||||
sorted(tree_choices, key=lambda t: (len(t), t)))
|
||||
|
||||
self.draft_tensor_parallel_size = \
|
||||
SpeculativeConfig._verify_and_get_draft_tp(
|
||||
self.target_parallel_config,
|
||||
self.draft_tensor_parallel_size,
|
||||
self.draft_model_config.hf_config
|
||||
)
|
||||
|
||||
self.draft_model_config.max_model_len = (
|
||||
SpeculativeConfig._maybe_override_draft_max_model_len(
|
||||
self.max_model_len,
|
||||
self.draft_model_config.max_model_len,
|
||||
self.target_model_config.max_model_len,
|
||||
))
|
||||
|
||||
self.draft_parallel_config = (
|
||||
SpeculativeConfig.create_draft_parallel_config(
|
||||
self.target_parallel_config,
|
||||
self.draft_tensor_parallel_size))
|
||||
|
||||
|
||||
ModelConfig.is_deepseek_mla = is_deepseek_mla
|
||||
SpeculativeConfig.__post_init__ = __post_init__
|
||||
SpeculativeConfig.hf_config_override = hf_config_override
|
||||
100
vllm_ascend/patch/platform/patch_common/patch_mamba_config.py
Normal file
100
vllm_ascend/patch/platform/patch_common/patch_mamba_config.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# mypy: ignore-errors
|
||||
import vllm.model_executor.models.config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.model_executor.models.config import MambaModelConfig
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
|
||||
|
||||
@classmethod
|
||||
def verify_and_update_config(cls, vllm_config) -> None:
|
||||
"""
|
||||
Ensure that page size of attention layers is greater than or
|
||||
equal to the mamba layers. If not, automatically set the attention
|
||||
block size to ensure that it is. If the attention page size is
|
||||
strictly greater than the mamba page size, we pad the mamba page size
|
||||
to make them equal.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM Config
|
||||
"""
|
||||
logger = init_logger(__name__)
|
||||
# Enable FULL_AND_PIECEWISE by default
|
||||
MambaModelConfig.verify_and_update_config(vllm_config)
|
||||
ascend_config = get_ascend_config()
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
if cache_config.cache_dtype == "auto":
|
||||
kv_cache_dtype = model_config.dtype
|
||||
else:
|
||||
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
||||
|
||||
# get attention page size (for 1 token)
|
||||
attn_page_size_1_token = FullAttentionSpec(
|
||||
block_size=1,
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
dtype=kv_cache_dtype,
|
||||
use_mla=model_config.use_mla or ascend_config.use_sfa).page_size_bytes
|
||||
|
||||
model_cls, _ = ModelRegistry.resolve_model_cls(
|
||||
model_config.architecture,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
# get mamba page size
|
||||
mamba_page_size = MambaSpec(
|
||||
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
|
||||
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
|
||||
block_size=model_config.max_model_len,
|
||||
).page_size_bytes
|
||||
|
||||
block_alignment_bytes = 64
|
||||
|
||||
# some attention backends (e.g. FA) only support setting
|
||||
# block size to multiple of 16, so let's suggest a value
|
||||
# that would work (note: FA is currently not compatible
|
||||
# with mamba layers, use FlashInfer instead).
|
||||
attn_block_size = block_alignment_bytes * cdiv(
|
||||
mamba_page_size, block_alignment_bytes * attn_page_size_1_token)
|
||||
|
||||
# override attention block size if either (a) the
|
||||
# user has not set it or (b) the user has set it
|
||||
# too small.
|
||||
if (cache_config.block_size is None
|
||||
or cache_config.block_size < attn_block_size):
|
||||
cache_config.block_size = attn_block_size
|
||||
logger.info(
|
||||
"Setting attention block size to %d tokens "
|
||||
"to ensure that attention page size is >= mamba page size.",
|
||||
attn_block_size)
|
||||
|
||||
# compute new attention page size
|
||||
attn_page_size = \
|
||||
cache_config.block_size * attn_page_size_1_token
|
||||
|
||||
assert attn_page_size >= mamba_page_size
|
||||
|
||||
if attn_page_size == mamba_page_size:
|
||||
# don't need to pad mamba page size
|
||||
return
|
||||
|
||||
# pad mamba page size to exactly match attention
|
||||
if (cache_config.mamba_page_size_padded is None
|
||||
or cache_config.mamba_page_size_padded != attn_page_size):
|
||||
cache_config.mamba_page_size_padded = (attn_page_size)
|
||||
mamba_padding_pct = 100 * (attn_page_size -
|
||||
mamba_page_size) / mamba_page_size
|
||||
logger.info(
|
||||
"Padding mamba page size by %.2f%% to ensure "
|
||||
"that mamba page size and attention page size are "
|
||||
"exactly equal.", mamba_padding_pct)
|
||||
|
||||
|
||||
vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config
|
||||
@@ -0,0 +1,58 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
import torch
|
||||
import vllm
|
||||
from vllm.model_executor.models.utils import (_embedding_count_expression,
|
||||
_flatten_embeddings)
|
||||
from vllm.multimodal import NestedTensors
|
||||
|
||||
|
||||
def _merge_multimodal_embeddings(
|
||||
inputs_embeds: torch.Tensor,
|
||||
is_multimodal: torch.Tensor,
|
||||
multimodal_embeddings: NestedTensors,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
|
||||
positions in ``inputs_embeds`` corresponding to placeholder tokens in
|
||||
``input_ids``.
|
||||
|
||||
Note:
|
||||
This updates ``inputs_embeds`` in place.
|
||||
"""
|
||||
flattened = _flatten_embeddings(multimodal_embeddings)
|
||||
try:
|
||||
inputs_embeds[is_multimodal] = flattened
|
||||
except RuntimeError as e:
|
||||
num_expected_tokens = is_multimodal.sum().item()
|
||||
assert isinstance(num_expected_tokens, int)
|
||||
|
||||
if flattened.shape[0] != num_expected_tokens:
|
||||
expr = _embedding_count_expression(multimodal_embeddings)
|
||||
raise ValueError(
|
||||
f"Attempted to assign {expr} = {flattened.shape[0]} "
|
||||
f"multimodal tokens to {num_expected_tokens} placeholders"
|
||||
) from e
|
||||
else:
|
||||
raise ValueError("Error during masked scatter operation") from e
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
vllm.model_executor.models.utils._merge_multimodal_embeddings = _merge_multimodal_embeddings
|
||||
@@ -0,0 +1,200 @@
|
||||
import vllm.transformers_utils.configs
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
from vllm.transformers_utils import config
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class DeepseekV3Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the DeepSeek-V3.
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 129280):
|
||||
Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`DeepseekV3Model`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||
Dimension of the MLP representations.
|
||||
moe_intermediate_size (`int`, *optional*, defaults to 1407):
|
||||
Dimension of the MoE representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_nextn_predict_layers (`int`, *optional*, defaults to 1):
|
||||
Number of nextn predict layers in the DeepSeekV3 Model.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
n_shared_experts (`int`, *optional*, defaults to None):
|
||||
Number of shared experts, None means dense model.
|
||||
n_routed_experts (`int`, *optional*, defaults to None):
|
||||
Number of routed experts, None means dense model.
|
||||
routed_scaling_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor or routed experts.
|
||||
topk_method (`str`, *optional*, defaults to `gready`):
|
||||
Topk method used in routed gate.
|
||||
n_group (`int`, *optional*, defaults to None):
|
||||
Number of groups for routed experts.
|
||||
topk_group (`int`, *optional*, defaults to None):
|
||||
Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
|
||||
num_experts_per_tok (`int`, *optional*, defaults to None):
|
||||
Number of selected experts, None means dense model.
|
||||
moe_layer_freq (`int`, *optional*, defaults to 1):
|
||||
The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
|
||||
first_k_dense_replace (`int`, *optional*, defaults to 0):
|
||||
Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
|
||||
\--k dense layers--/
|
||||
norm_topk_prob (`bool`, *optional*, defaults to False):
|
||||
Whether to normalize the weights of the routed experts.
|
||||
scoring_func (`str`, *optional*, defaults to 'softmax'):
|
||||
Method of computing expert weights.
|
||||
aux_loss_alpha (`float`, *optional*, defaults to 0.001):
|
||||
Auxiliary loss weight coefficient.
|
||||
seq_aux = (`bool`, *optional*, defaults to True):
|
||||
Whether to compute the auxiliary loss for each individual sample.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*):
|
||||
Padding token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
Beginning of stream token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
End of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
||||
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||
`max_position_embeddings` to the expected new maximum.
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
```python
|
||||
>>> from transformers import DeepseekV3Model, DeepseekV3Config
|
||||
>>> # Initializing a Deepseek-V3 style configuration
|
||||
>>> configuration = DeepseekV3Config()
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "deepseek_v3"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=129280,
|
||||
hidden_size=7168,
|
||||
intermediate_size=18432,
|
||||
moe_intermediate_size=2048,
|
||||
num_hidden_layers=61,
|
||||
num_nextn_predict_layers=1,
|
||||
num_attention_heads=128,
|
||||
num_key_value_heads=128,
|
||||
n_shared_experts=1,
|
||||
n_routed_experts=256,
|
||||
ep_size=1,
|
||||
routed_scaling_factor=2.5,
|
||||
kv_lora_rank=512,
|
||||
q_lora_rank=1536,
|
||||
qk_rope_head_dim=64,
|
||||
v_head_dim=128,
|
||||
qk_nope_head_dim=128,
|
||||
topk_method='noaux_tc',
|
||||
n_group=8,
|
||||
topk_group=4,
|
||||
num_experts_per_tok=8,
|
||||
moe_layer_freq=1,
|
||||
first_k_dense_replace=3,
|
||||
norm_topk_prob=True,
|
||||
scoring_func='sigmoid',
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=4096,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=0,
|
||||
eos_token_id=1,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_nextn_predict_layers = num_nextn_predict_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.n_shared_experts = n_shared_experts
|
||||
self.n_routed_experts = n_routed_experts
|
||||
self.ep_size = ep_size
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.topk_method = topk_method
|
||||
self.n_group = n_group
|
||||
self.topk_group = topk_group
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.moe_layer_freq = moe_layer_freq
|
||||
self.first_k_dense_replace = first_k_dense_replace
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.scoring_func = scoring_func
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
vllm.transformers_utils.configs.__all__.append("DeepseekV3Config")
|
||||
vllm.transformers_utils.configs.DeepseekV3Config = DeepseekV3Config
|
||||
config._CONFIG_REGISTRY["deepseek_v32"] = "DeepseekV3Config"
|
||||
@@ -15,8 +15,18 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
import vllm_ascend.patch.worker.patch_common.patch_triton
|
||||
|
||||
# isort: off
|
||||
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_attentionspec # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_attention_layer # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_linear # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_logits # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_lora_embedding # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_weight_loader # noqa
|
||||
|
||||
# TODO: revert me when triton import is fixed
|
||||
# import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
|
||||
|
||||
202
vllm_ascend/patch/worker/patch_common/patch_attention_layer.py
Normal file
202
vllm_ascend/patch/worker/patch_common/patch_attention_layer.py
Normal file
@@ -0,0 +1,202 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import vllm
|
||||
import vllm.envs as envs
|
||||
from torch import nn
|
||||
from vllm.attention import Attention, AttentionType, get_attn_backend
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.selector import backend_name_to_enum
|
||||
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
|
||||
class AscendAttention(Attention, nn.Module, AttentionLayerBase):
|
||||
"""Attention layer.
|
||||
|
||||
This class takes query, key, and value tensors as input. The input tensors
|
||||
can either contain prompt tokens or generation tokens.
|
||||
The class does the following:
|
||||
|
||||
1. Store the input key and value tensors in the KV cache.
|
||||
2. Perform (multi-head/multi-query/grouped-query) attention.
|
||||
3. Return the output tensor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
alibi_slopes: Optional[List[float]] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
per_layer_sliding_window: Optional[int] = None,
|
||||
use_mla: bool = False,
|
||||
use_sfa: bool = False,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
attn_backend: Optional[type[AttentionBackend]] = None,
|
||||
**extra_impl_args,
|
||||
) -> None:
|
||||
"""
|
||||
The KV cache is stored inside this class and is accessed via
|
||||
`self.kv_cache`.
|
||||
"""
|
||||
nn.Module.__init__(self)
|
||||
AttentionLayerBase.__init__(self)
|
||||
|
||||
if per_layer_sliding_window is not None:
|
||||
# per-layer sliding window
|
||||
sliding_window = per_layer_sliding_window
|
||||
elif cache_config is not None:
|
||||
# model-level sliding window
|
||||
sliding_window = cache_config.sliding_window
|
||||
else:
|
||||
sliding_window = None
|
||||
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
is_attention_free = cache_config.is_attention_free
|
||||
calculate_kv_scales = cache_config.calculate_kv_scales
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
is_attention_free = False
|
||||
calculate_kv_scales = False
|
||||
if num_kv_heads is None:
|
||||
num_kv_heads = num_heads
|
||||
assert num_heads % num_kv_heads == 0, \
|
||||
f"num_heads ({num_heads}) is not " \
|
||||
f"divisible by num_kv_heads ({num_kv_heads})"
|
||||
|
||||
# The default k/v_scale is set to 1.0. This is ignored
|
||||
# when kv-cache is not fp8, and should be used with
|
||||
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
|
||||
# expect the pre-quantized k/v_scale to be loaded along
|
||||
# with the model weights.
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.calculate_kv_scales = calculate_kv_scales
|
||||
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
# FlashAttn doesn't support quantizing the kv-cache only
|
||||
# but requires q to be quantized as well.
|
||||
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
|
||||
# We also keep q/k/v_scale on host (cpu) memory for attention
|
||||
# backends that require the scales to be on host instead of on device.
|
||||
# e.g. Flashinfer
|
||||
self._q_scale_float = 1.0
|
||||
self._k_scale_float = 1.0
|
||||
self._v_scale_float = 1.0
|
||||
|
||||
# The output scale on host memory. This should be the input scale of
|
||||
# the quant op after this attention layer.
|
||||
self._o_scale_float: Optional[float] = None
|
||||
|
||||
self.use_mla = use_mla
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.has_sink = extra_impl_args.get("sinks") is not None
|
||||
|
||||
quant_method = quant_config.get_quant_method(
|
||||
self, prefix=prefix) if quant_config else None
|
||||
if quant_method is not None and not isinstance(
|
||||
quant_method, UnquantizedLinearMethod):
|
||||
assert isinstance(quant_method, BaseKVCacheMethod)
|
||||
# TODO (mgoin): kv cache dtype should be specified in the FP8
|
||||
# checkpoint config and become the "auto" behavior
|
||||
if self.kv_cache_dtype == "fp8_e5m2":
|
||||
raise ValueError("fp8_e5m2 kv-cache is not supported with "
|
||||
"fp8 checkpoints.")
|
||||
# If quantization is enabled, we make "k_scale" and "v_scale"
|
||||
# parameters so that it can be loaded from the model checkpoint.
|
||||
# The k/v_scale will then be converted back to native float32
|
||||
# values after weight loading.
|
||||
self.quant_method = quant_method
|
||||
self.quant_method.create_weights(self)
|
||||
|
||||
# During model initialization, the default dtype is set as the model
|
||||
# weight and activation dtype.
|
||||
dtype = torch.get_default_dtype()
|
||||
if attn_backend is None:
|
||||
if vllm_version_is("0.10.2"):
|
||||
self.attn_backend = get_attn_backend(head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
is_attention_free,
|
||||
use_mla=use_mla,
|
||||
use_sfa=use_sfa,
|
||||
has_sink=self.has_sink)
|
||||
else:
|
||||
self.attn_backend = get_attn_backend(head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_mla=use_mla,
|
||||
use_sfa=use_sfa,
|
||||
has_sink=self.has_sink)
|
||||
else:
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
impl_cls = self.attn_backend.get_impl_cls()
|
||||
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **extra_impl_args)
|
||||
self.backend = backend_name_to_enum(self.attn_backend.get_name())
|
||||
self.dtype = dtype
|
||||
|
||||
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
|
||||
# torch.compile works by registering the attention as one giant
|
||||
# opaque custom op. For other platforms, we directly call them
|
||||
# and let torch.compile handle them.
|
||||
self.use_direct_call = not current_platform.opaque_attention_op()
|
||||
|
||||
self.use_output = self.attn_backend.accept_output_buffer
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
self.layer_name = prefix
|
||||
self.attn_type = attn_type
|
||||
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
validate_kv_sharing_target(
|
||||
prefix,
|
||||
kv_sharing_target_layer_name,
|
||||
compilation_config.static_forward_context,
|
||||
)
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
# use a placeholder kv cache tensor during init, which will be replaced
|
||||
# by bind_kv_cache
|
||||
# this variable will not be accessed if use_direct_call is True
|
||||
self.kv_cache = [
|
||||
torch.tensor([]) for _ in range(get_current_vllm_config(
|
||||
).parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.query_quant = None
|
||||
|
||||
|
||||
vllm.attention.Attention = AscendAttention
|
||||
@@ -0,0 +1,181 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# mypy: ignore-errors
|
||||
from functools import cache
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import vllm
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.selector import (backend_name_to_enum,
|
||||
get_global_forced_attn_backend)
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from vllm.utils import resolve_obj_by_qualname
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
if vllm_version_is("0.10.2"):
|
||||
|
||||
def get_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool = False,
|
||||
use_mla: bool = False,
|
||||
use_sfa: bool = False,
|
||||
has_sink: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
|
||||
# value to be returned from the cache if the value changes between calls.
|
||||
# To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
|
||||
# private function.
|
||||
return _cached_get_attn_backend(
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
block_size=block_size,
|
||||
is_attention_free=is_attention_free,
|
||||
use_v1=envs.VLLM_USE_V1,
|
||||
use_mla=use_mla,
|
||||
use_sfa=use_sfa,
|
||||
has_sink=has_sink,
|
||||
)
|
||||
|
||||
@cache
|
||||
def _cached_get_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool,
|
||||
use_v1: bool = False,
|
||||
use_mla: bool = False,
|
||||
use_sfa: bool = False,
|
||||
has_sink: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
# If there are no attention layers (e.g. we are running Mamba),
|
||||
# use the placeholder NO_ATTENTION
|
||||
if is_attention_free:
|
||||
from vllm.attention.backends.placeholder_attn import \
|
||||
PlaceholderAttentionBackend
|
||||
return PlaceholderAttentionBackend
|
||||
|
||||
# Check whether a particular choice of backend was
|
||||
# previously forced.
|
||||
#
|
||||
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
|
||||
# ENVIRONMENT VARIABLE.
|
||||
selected_backend = None
|
||||
backend_by_global_setting: Optional[_Backend] = (
|
||||
get_global_forced_attn_backend())
|
||||
if backend_by_global_setting is not None:
|
||||
selected_backend = backend_by_global_setting
|
||||
else:
|
||||
# Check the environment variable and override if specified
|
||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||
if backend_by_env_var is not None:
|
||||
selected_backend = backend_name_to_enum(backend_by_env_var)
|
||||
if selected_backend is None:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend: '{backend_by_env_var}'. "
|
||||
f"Valid backends are: {list(_Backend.__members__.keys())}"
|
||||
)
|
||||
|
||||
# get device-specific attn_backend
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
selected_backend, head_size, dtype, kv_cache_dtype, block_size,
|
||||
use_v1, use_mla, use_sfa, has_sink)
|
||||
if not attention_cls:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for {current_platform.device_name}"
|
||||
)
|
||||
return resolve_obj_by_qualname(attention_cls)
|
||||
else:
|
||||
|
||||
def get_attn_backend( # type: ignore[misc]
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
use_mla: bool = False,
|
||||
use_sfa: bool = False,
|
||||
has_sink: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
|
||||
# value to be returned from the cache if the value changes between calls.
|
||||
# To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
|
||||
# private function.
|
||||
return _cached_get_attn_backend(
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
block_size=block_size,
|
||||
use_v1=envs.VLLM_USE_V1,
|
||||
use_mla=use_mla,
|
||||
use_sfa=use_sfa,
|
||||
has_sink=has_sink,
|
||||
)
|
||||
|
||||
@cache
|
||||
def _cached_get_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
use_v1: bool = False,
|
||||
use_mla: bool = False,
|
||||
use_sfa: bool = False,
|
||||
has_sink: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
# Check whether a particular choice of backend was
|
||||
# previously forced.
|
||||
#
|
||||
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
|
||||
# ENVIRONMENT VARIABLE.
|
||||
selected_backend = None
|
||||
backend_by_global_setting: Optional[_Backend] = (
|
||||
get_global_forced_attn_backend())
|
||||
if backend_by_global_setting is not None:
|
||||
selected_backend = backend_by_global_setting
|
||||
else:
|
||||
# Check the environment variable and override if specified
|
||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||
if backend_by_env_var is not None:
|
||||
selected_backend = backend_name_to_enum(backend_by_env_var)
|
||||
if selected_backend is None:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend: '{backend_by_env_var}'. "
|
||||
f"Valid backends are: {list(_Backend.__members__.keys())}"
|
||||
)
|
||||
|
||||
# get device-specific attn_backend
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
selected_backend, head_size, dtype, kv_cache_dtype, block_size,
|
||||
use_v1, use_mla, use_sfa, has_sink)
|
||||
if not attention_cls:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for {current_platform.device_name}"
|
||||
)
|
||||
return resolve_obj_by_qualname(attention_cls)
|
||||
|
||||
|
||||
vllm.attention.get_attn_backend = get_attn_backend
|
||||
vllm.attention.selector._cached_get_attn_backend = _cached_get_attn_backend
|
||||
110
vllm_ascend/patch/worker/patch_common/patch_attentionspec.py
Normal file
110
vllm_ascend/patch/worker/patch_common/patch_attentionspec.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import vllm
|
||||
from typing_extensions import Self
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import cdiv, get_dtype_size
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (FullAttentionManager,
|
||||
spec_manager_map)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AttentionSpec(KVCacheSpec):
|
||||
num_kv_heads: int
|
||||
head_size: int
|
||||
dtype: torch.dtype
|
||||
use_mla: bool
|
||||
use_sfa: bool
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
# For MLA we only store a single latent vector
|
||||
coef = 1 if self.use_mla else 2
|
||||
sfa_bytes = 128 * self.block_size * get_dtype_size(
|
||||
self.dtype) if self.use_sfa else 0
|
||||
|
||||
return coef * self.block_size * self.num_kv_heads * self.head_size \
|
||||
* get_dtype_size(self.dtype) + sfa_bytes
|
||||
|
||||
|
||||
vllm.v1.kv_cache_interface.AttentionSpec = AttentionSpec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AscendFullAttentionSpec(FullAttentionSpec, AttentionSpec):
|
||||
sliding_window: Optional[int] = None
|
||||
attention_chunk_size: Optional[int] = None
|
||||
"""
|
||||
When hybrid allocator is disabled and the model contains both full
|
||||
attention layers and sliding window attention layers, sliding
|
||||
window attention are regarded as full attention in KV cache manager
|
||||
(blocks are allocated for all tokens), while computed as sliding window
|
||||
attention in model runner.
|
||||
In this case, we use FullAttentionSpec and record the sliding window size.
|
||||
Default to None for not using sliding window attention.
|
||||
"""
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
dcp_world_size = \
|
||||
vllm_config.parallel_config.decode_context_parallel_size
|
||||
# Note(hc): each dcp rank only need save
|
||||
# (max_model_len//dcp_world_size) tokens locally.
|
||||
if dcp_world_size > 1:
|
||||
max_model_len = cdiv(max_model_len, dcp_world_size)
|
||||
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
|
||||
|
||||
@classmethod
|
||||
def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
|
||||
if len(window_sizes) == 0:
|
||||
return None
|
||||
elif len(window_sizes) == 1:
|
||||
return window_sizes.pop()
|
||||
else:
|
||||
raise ValueError(
|
||||
"All attention layers in the same KV cache group must have the "
|
||||
"same window size.")
|
||||
|
||||
@classmethod
|
||||
def merge(cls, specs: list[Self]) -> Self:
|
||||
"""
|
||||
Merge a list of FullAttentionSpec objects into a single
|
||||
FullAttentionSpec object.
|
||||
"""
|
||||
assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
|
||||
"All attention layers in the same KV cache group must be "
|
||||
"FullAttentionSpec.")
|
||||
|
||||
sliding_window = set(spec.sliding_window for spec in specs
|
||||
if spec.sliding_window is not None)
|
||||
attention_chunk_size = set(spec.attention_chunk_size for spec in specs
|
||||
if spec.attention_chunk_size is not None)
|
||||
merged_spec = cls(
|
||||
block_size=specs[0].block_size,
|
||||
num_kv_heads=specs[0].num_kv_heads,
|
||||
head_size=specs[0].head_size,
|
||||
dtype=specs[0].dtype,
|
||||
use_mla=specs[0].use_mla,
|
||||
use_sfa=specs[0].use_sfa,
|
||||
sliding_window=cls.merge_window_sizes(sliding_window),
|
||||
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
|
||||
)
|
||||
for spec in specs:
|
||||
for f in fields(AttentionSpec):
|
||||
assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
|
||||
"All attention layers in the same KV cache group must have "
|
||||
"the same attention spec.")
|
||||
assert (
|
||||
(merged_spec.sliding_window is not None) +
|
||||
(merged_spec.attention_chunk_size is not None) <= 1
|
||||
), ("Model with both sliding window layers and chunked local attention "
|
||||
"layers is not supported.")
|
||||
return merged_spec
|
||||
|
||||
|
||||
spec_manager_map.update({AscendFullAttentionSpec: FullAttentionManager})
|
||||
|
||||
vllm.v1.kv_cache_interface.FullAttentionSpec = AscendFullAttentionSpec
|
||||
@@ -1,147 +0,0 @@
|
||||
"""
|
||||
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
This file is a part of the vllm-ascend project.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
import vllm
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
split_tensor_along_last_dim)
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
|
||||
_HCOMM_INFO = None
|
||||
|
||||
|
||||
class AscendRowParallelLinear(RowParallelLinear):
|
||||
"""
|
||||
AscendRowParallelLinear is a custom implementation of RowParallelLinear
|
||||
that overrides the forward method to handle Ascend-specific operations.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize the AscendRowParallelLinear layer.
|
||||
|
||||
Args:
|
||||
*args: Variable length argument list.
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
"""
|
||||
tp_group = get_tp_group().device_group
|
||||
hcomm_info = self.get_hcomm_info(tp_group)
|
||||
self.hcomm_info = hcomm_info
|
||||
super().__init__(*args, **kwargs)
|
||||
self.weight_t = self.weight.t()
|
||||
|
||||
@staticmethod
|
||||
def get_hcomm_info(group: ProcessGroup) -> str:
|
||||
"""Get the HCCL communication information for the given group.
|
||||
|
||||
Args:
|
||||
group (ProcessGroup): The process group for which to get the HCCL communication info.
|
||||
|
||||
Returns:
|
||||
str: The HCCL communication name for the given group.
|
||||
"""
|
||||
global _HCOMM_INFO
|
||||
if _HCOMM_INFO is not None:
|
||||
return _HCOMM_INFO
|
||||
|
||||
rank = torch.distributed.get_rank(group)
|
||||
if torch.__version__ > "2.0":
|
||||
global_rank = torch.distributed.get_global_rank(group, rank)
|
||||
_HCOMM_INFO = group._get_backend(
|
||||
torch.device("npu")).get_hccl_comm_name(global_rank)
|
||||
|
||||
else:
|
||||
_HCOMM_INFO = group.get_hccl_comm_name(rank)
|
||||
return _HCOMM_INFO
|
||||
|
||||
def forward(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
"""Forward pass for the AscendRowParallelLinear layer.
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor): the input tensor to the layer.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
The output tensor after applying the linear transformation,
|
||||
and optionally the bias if `return_bias` is True.
|
||||
"""
|
||||
input_parallel = self.calc_input(input_)
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
output = self.calc_output(input_parallel)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
def calc_input(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculate the input tensor for parallel processing.
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor): the input tensor to be processed.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The input tensor split along the last dimension
|
||||
for tensor model parallelism, or the original input if not parallel.
|
||||
"""
|
||||
if self.input_is_parallel:
|
||||
return input_
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
return splitted_input[tp_rank].contiguous()
|
||||
|
||||
def calc_output(self, input_parallel: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculate the output tensor of forward by considering
|
||||
fusing communication and computation.
|
||||
|
||||
Args:
|
||||
input_parallel (_type_): the input tensor to be processed in parallel.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the output tensor after applying the linear transformation
|
||||
and optionally handle communication between tensor model parallel ranks.
|
||||
"""
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output = torch_npu.npu_mm_all_reduce_base(input_parallel,
|
||||
self.weight_t,
|
||||
self.hcomm_info,
|
||||
bias=bias_)
|
||||
else:
|
||||
output = self.quant_method.apply(self, input_parallel, bias=bias_)
|
||||
return output
|
||||
|
||||
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE:
|
||||
logger.info("AscendRowParallelLinear: Matmul all-reduce is enabled. ")
|
||||
vllm.model_executor.layers.linear.RowParallelLinear = AscendRowParallelLinear
|
||||
@@ -1,29 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import vllm
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.lora.layers import VocabParallelEmbeddingWithLoRA
|
||||
from vllm.lora.utils import _all_lora_classes
|
||||
|
||||
from vllm_ascend.ops.vocab_parallel_embedding import \
|
||||
AscendVocabParallelEmbedding
|
||||
|
||||
|
||||
class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendVocabParallelEmbedding
|
||||
|
||||
|
||||
# Patch for lora register_model issue after overriding VocabParallelEmbedding class (#2515)
|
||||
_all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes = _all_lora_classes
|
||||
16
vllm_ascend/patch/worker/patch_common/patch_triton.py
Normal file
16
vllm_ascend/patch/worker/patch_common/patch_triton.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import vllm.model_executor.layers.fla.ops.chunk
|
||||
import vllm.model_executor.layers.fla.ops.fused_recurrent
|
||||
import vllm.model_executor.layers.fla.ops.layernorm_guard
|
||||
import vllm.model_executor.layers.mamba.ops.causal_conv1d
|
||||
|
||||
from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn,
|
||||
causal_conv1d_update_npu)
|
||||
from vllm_ascend.ops.fla import LayerNormFn, torch_chunk_gated_delta_rule
|
||||
from vllm_ascend.ops.sigmoid_gating import \
|
||||
fused_recurrent_gated_delta_rule_fwd_kernel
|
||||
|
||||
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu
|
||||
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn
|
||||
vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel
|
||||
vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn
|
||||
vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = torch_chunk_gated_delta_rule
|
||||
44
vllm_ascend/patch/worker/patch_common/patch_weight_loader.py
Normal file
44
vllm_ascend/patch/worker/patch_common/patch_weight_loader.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.utils import GiB_bytes
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
# This method creates unquantized linear weights.
|
||||
# The weights are not quantized, and they are not sharded.
|
||||
# The amount of memory allocated for the weights is
|
||||
# sum(output_partition_sizes) * input_size_per_partition.
|
||||
try:
|
||||
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
except torch.cuda.OutOfMemoryError as e:
|
||||
logger.error("Failed to create unquantized linear weights: %s", e)
|
||||
if torch.cuda.is_available():
|
||||
logger.debug("CUDA device: %s", torch.cuda.current_device())
|
||||
logger.debug("Allocated: %.2f GiB",
|
||||
torch.cuda.memory_allocated() / GiB_bytes)
|
||||
logger.debug("Reserved: %.2f GiB",
|
||||
torch.cuda.memory_reserved() / GiB_bytes)
|
||||
raise RuntimeError(
|
||||
"Failed to create unquantized linear weights. "
|
||||
"This may be caused by insufficient memory to allocate "
|
||||
"the weight.") from e
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
|
||||
if not vllm_version_is("0.10.2"):
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
UnquantizedLinearMethod.create_weights = create_weights
|
||||
@@ -16,6 +16,7 @@
|
||||
#
|
||||
|
||||
import gc
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
@@ -31,7 +32,7 @@ from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
|
||||
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
|
||||
delete_torchair_cache_file)
|
||||
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, is_310p,
|
||||
update_aclgraph_sizes)
|
||||
update_aclgraph_sizes, vllm_version_is)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
@@ -128,11 +129,43 @@ class NPUPlatform(Platform):
|
||||
model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
cache_config = vllm_config.cache_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
ascend_scheduler_config = ascend_config.ascend_scheduler_config
|
||||
if vllm_version_is("0.10.2"):
|
||||
structured_outputs_config = vllm_config.decoding_config
|
||||
else:
|
||||
structured_outputs_config = vllm_config.structured_outputs_config
|
||||
|
||||
if (model_config is not None and not model_config.use_mla
|
||||
and not scheduler_config.async_scheduling):
|
||||
logger.info(
|
||||
"Non-MLA LLMs forcibly disable the chunked prefill feature,"
|
||||
"as the performance of operators supporting this feature "
|
||||
"functionality is currently suboptimal.")
|
||||
if not model_config.is_multimodal_model and \
|
||||
structured_outputs_config.backend == "auto" and \
|
||||
not getattr(scheduler_config, "scheduler_delay_factor", 0) > 0 and \
|
||||
not scheduler_config.send_delta_data and \
|
||||
scheduler_config.policy == "fcfs":
|
||||
ascend_scheduler_config.enabled = True
|
||||
chunked_prefill_enabled_in_ascend_scheduler = getattr(
|
||||
ascend_scheduler_config, "enable_chunked_prefill", False)
|
||||
if chunked_prefill_enabled_in_ascend_scheduler:
|
||||
logger.warning(
|
||||
"Chunked prefill feature is enabled in ascend_scheduler,"
|
||||
"but note that the operator supporting this feature "
|
||||
"would lead to performance degradation.")
|
||||
# In this situation, max_num_batched_tokens would have been rewritten.
|
||||
# So we must make sure max_num_batched_tokens is not smaller than max_model_len.
|
||||
if (scheduler_config.max_num_batched_tokens
|
||||
< scheduler_config.max_model_len
|
||||
and not chunked_prefill_enabled_in_ascend_scheduler):
|
||||
scheduler_config.max_num_batched_tokens = scheduler_config.max_model_len
|
||||
|
||||
kv_cache_dtype = vllm_config.additional_config.get(
|
||||
"kv_cache_dtype", None)
|
||||
if kv_cache_dtype is not None:
|
||||
vllm_config.cache_config.cache_dtype = kv_cache_dtype
|
||||
|
||||
if model_config is None:
|
||||
logger.warning("Model config is missing. This may indicate "
|
||||
"that we are running a test case")
|
||||
@@ -148,23 +181,13 @@ class NPUPlatform(Platform):
|
||||
|
||||
compilation_config.cudagraph_num_of_warmups = 1
|
||||
|
||||
# TODO: make vllm support oot platform to set `compilation_config.cudagraph_mode`
|
||||
# if cudagraph_mode is not explicitly set by users, set default value
|
||||
if compilation_config.level == CompilationLevel.PIECEWISE:
|
||||
compilation_config.cudagraph_mode = \
|
||||
CUDAGraphMode.PIECEWISE
|
||||
elif compilation_config.level not in [
|
||||
if compilation_config.level not in [
|
||||
CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE
|
||||
]:
|
||||
logger.warning(
|
||||
"NPU does not support %s compilation level. Setting CUDAGraphMode to NONE",
|
||||
compilation_config.level)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
else:
|
||||
logger.warning(
|
||||
"compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE"
|
||||
)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
# set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is.
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
@@ -185,18 +208,22 @@ class NPUPlatform(Platform):
|
||||
"and use_cached_kv_cache_bytes in torchair_graph_config.")
|
||||
delete_torchair_cache_file()
|
||||
|
||||
if parallel_config.distributed_executor_backend == "ray":
|
||||
logger.warning(
|
||||
"Ray distributed executor backend is not compatible with ACL Graph mode "
|
||||
"right now. Setting CUDAGraphMode to NONE")
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
# set cudaprah sizes before extending `compilation_config.splitting_ops`
|
||||
vllm_config._set_cudagraph_sizes()
|
||||
|
||||
# TODO: Full graph is fully supported later, and the default value will be set to full graph.
|
||||
if not vllm_version_is("0.10.2"):
|
||||
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
|
||||
compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
|
||||
# TODO: Currently MLA does not support FULL_DECODE_ONLY, remove the second condition
|
||||
# after MLA being supported
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or (
|
||||
compilation_config.cudagraph_mode
|
||||
== CUDAGraphMode.FULL_DECODE_ONLY and model_config is not None
|
||||
and model_config.use_mla):
|
||||
logger.info(
|
||||
"PIECEWISE compilation enabled on NPU. use_inductor not supported - "
|
||||
"using only ACL Graph mode")
|
||||
@@ -204,9 +231,28 @@ class NPUPlatform(Platform):
|
||||
"When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE"
|
||||
compilation_config.set_splitting_ops_for_v1()
|
||||
compilation_config.use_inductor = False
|
||||
compilation_config.splitting_ops.extend(
|
||||
["vllm.unified_ascend_attention_with_output"])
|
||||
compilation_config.splitting_ops.extend([
|
||||
"vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
|
||||
])
|
||||
update_aclgraph_sizes(vllm_config)
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
||||
logger.info(
|
||||
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
|
||||
"using only ACL Graph mode")
|
||||
compilation_config.use_inductor = False
|
||||
warning_message = """\033[91m
|
||||
**********************************************************************************
|
||||
* WARNING: You have enabled the *full graph* feature.
|
||||
* This is an early experimental stage and may involve various unknown issues.
|
||||
* A known problem is that capturing too many batch sizes can lead to OOM
|
||||
* (Out of Memory) errors or inference hangs. If you encounter such issues,
|
||||
* consider reducing `gpu_memory_utilization` or manually specifying a smaller
|
||||
* batch size for graph capture.
|
||||
* For more details, please refer to:
|
||||
* https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs
|
||||
**********************************************************************************\033[0m
|
||||
"""
|
||||
logger.warning(warning_message)
|
||||
else:
|
||||
logger.info(
|
||||
"%s cudagraph_mode is not support on NPU. falling back to NONE",
|
||||
@@ -215,7 +261,9 @@ class NPUPlatform(Platform):
|
||||
compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
|
||||
if parallel_config and parallel_config.worker_cls == "auto":
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
# TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
|
||||
os.environ["VLLM_ALL2ALL_BACKEND"] = "flashinfer_all2allv"
|
||||
if ascend_config.torchair_graph_config.enabled or ascend_config.enable_shared_expert_dp:
|
||||
parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
|
||||
@@ -223,6 +271,7 @@ class NPUPlatform(Platform):
|
||||
if cache_config:
|
||||
if cache_config.block_size is None:
|
||||
cache_config.block_size = 128
|
||||
|
||||
if cache_config.enable_prefix_caching and cache_config.block_size != 128:
|
||||
logger.warning(
|
||||
"If prefix caching is enabled, block size must be set to 128."
|
||||
@@ -242,12 +291,6 @@ class NPUPlatform(Platform):
|
||||
ascend_config.ascend_scheduler_config)
|
||||
vllm_config.scheduler_config = ascend_scheduler_config
|
||||
|
||||
if compilation_config.pass_config.enable_sequence_parallelism:
|
||||
if not parallel_config.enable_expert_parallel or vllm_config.model_config.hf_config.model_type != "qwen3_moe":
|
||||
raise NotImplementedError(
|
||||
"For better performance in Qwen3 MoE, SP only works exclusively with MC2, AllToAll, and AllToAllV."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls,
|
||||
selected_backend,
|
||||
@@ -257,27 +300,40 @@ class NPUPlatform(Platform):
|
||||
block_size,
|
||||
use_v1,
|
||||
use_mla,
|
||||
use_sfa,
|
||||
has_sink=False):
|
||||
if not use_v1:
|
||||
raise ValueError("vLLM Ascend does not support V0 engine.")
|
||||
|
||||
use_torchair = get_ascend_config().torchair_graph_config.enabled
|
||||
ascend_config = get_ascend_config()
|
||||
|
||||
if use_mla and ascend_config.enable_shared_expert_dp:
|
||||
if use_mla and not use_sfa:
|
||||
return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
|
||||
if use_mla and use_sfa:
|
||||
return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"
|
||||
|
||||
use_torchair = ascend_config.torchair_graph_config.enabled
|
||||
# choose attention backend based on use_mla and use_torchair
|
||||
backend_map = {
|
||||
(True, True):
|
||||
(True, False, True):
|
||||
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend",
|
||||
(True, False):
|
||||
(True, False, False):
|
||||
"vllm_ascend.attention.mla_v1.AscendMLABackend",
|
||||
(False, True):
|
||||
(False, False, True):
|
||||
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend",
|
||||
(False, False):
|
||||
"vllm_ascend.attention.attention_v1.AscendAttentionBackend"
|
||||
(False, False, False):
|
||||
"vllm_ascend.attention.attention_v1.AscendAttentionBackend",
|
||||
(True, True, False):
|
||||
"vllm_ascend.attention.sfa_v1.AscendSFABackend",
|
||||
(True, True, True):
|
||||
"vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend",
|
||||
}
|
||||
return backend_map[(use_mla, use_torchair)]
|
||||
return backend_map[(use_mla, use_sfa, use_torchair)]
|
||||
|
||||
@classmethod
|
||||
def get_punica_wrapper(cls) -> str:
|
||||
return "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU"
|
||||
return "vllm_ascend.lora.punica_npu.PunicaWrapperNPU"
|
||||
|
||||
@classmethod
|
||||
def get_current_memory_usage(cls,
|
||||
@@ -343,3 +399,11 @@ class NPUPlatform(Platform):
|
||||
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
return pg
|
||||
|
||||
@classmethod
|
||||
def support_hybrid_kv_cache(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def support_static_graph_mode(cls) -> bool:
|
||||
return True
|
||||
|
||||
@@ -1,184 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, QuantizationConfig)
|
||||
|
||||
|
||||
# func refers to vocabParallelEmbedding.__init__
|
||||
def wrapper_vocab_parallel_embedding_init(func):
|
||||
|
||||
def init(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
org_num_embeddings: Optional[int] = None,
|
||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
func(
|
||||
self,
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
params_dtype,
|
||||
org_num_embeddings,
|
||||
padding_size,
|
||||
quant_config,
|
||||
prefix,
|
||||
)
|
||||
# TODO: Contact vLLM maintainers to add a `params_dtype` attribute to the `VocabParallelEmbedding` class.
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
return init
|
||||
|
||||
|
||||
# func refers to RMSNorm.__init__
|
||||
def wrapper_rmsnorm_init(func):
|
||||
|
||||
def init(self, hidden_size: int, **extra_args) -> None:
|
||||
func(self, hidden_size, **extra_args)
|
||||
self.ignore_anti = True
|
||||
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
|
||||
requires_grad=False)
|
||||
|
||||
return init
|
||||
|
||||
|
||||
# func refers to RMSNorm.forward_oot
|
||||
def wrapper_rmsnorm_forward_oot(func):
|
||||
|
||||
def _rmsnorm_forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if not self.ignore_anti:
|
||||
if residual is not None:
|
||||
residual += x
|
||||
out = torch_npu._npu_quant_rms_norm(
|
||||
residual,
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.input_scale,
|
||||
self.input_offset,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return out, residual
|
||||
out = torch_npu._npu_quant_rms_norm(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.input_scale,
|
||||
self.input_offset,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return out
|
||||
|
||||
if residual is not None:
|
||||
x, residual = func(self, x, residual)
|
||||
return x.add_(self.bias), residual
|
||||
|
||||
return func(self, x).add_(self.bias)
|
||||
|
||||
return _rmsnorm_forward_oot
|
||||
|
||||
|
||||
MODEL_LAYER_MAPPING = {
|
||||
"LlamaModel": {
|
||||
"attn": {
|
||||
"layer_attr": "self_attn",
|
||||
"proj_attr": "qkv_proj",
|
||||
"norm_attr": "input_layernorm",
|
||||
"unquantized_type": UnquantizedLinearMethod,
|
||||
},
|
||||
"mlp": {
|
||||
"layer_attr": "mlp",
|
||||
"proj_attr": "gate_up_proj",
|
||||
"norm_attr": "post_attention_layernorm",
|
||||
"unquantized_type": UnquantizedLinearMethod,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def wrapper_load_model(func):
|
||||
|
||||
def postprocess_loading(self) -> None:
|
||||
func(self)
|
||||
|
||||
def process_layer(layer, idx, mapping):
|
||||
|
||||
def process_module(module_cfg, layer_obj):
|
||||
if module_cfg is None:
|
||||
return
|
||||
|
||||
module_obj = getattr(layer_obj, module_cfg["layer_attr"], None)
|
||||
if module_obj is None:
|
||||
return
|
||||
|
||||
proj_attr = module_cfg["proj_attr"]
|
||||
if callable(proj_attr):
|
||||
proj = proj_attr(module_obj, idx)
|
||||
else:
|
||||
proj = getattr(module_obj, proj_attr, None)
|
||||
|
||||
norm = getattr(layer_obj, module_cfg["norm_attr"], None)
|
||||
|
||||
if proj is None or norm is None:
|
||||
return
|
||||
|
||||
norm.ignore_anti = isinstance(proj.quant_method,
|
||||
module_cfg["unquantized_type"])
|
||||
if not norm.ignore_anti:
|
||||
for param_name in ["input_scale", "input_offset"]:
|
||||
if hasattr(proj, param_name):
|
||||
param = getattr(proj, param_name)
|
||||
norm.register_parameter(
|
||||
param_name,
|
||||
torch.nn.Parameter(param.clone(),
|
||||
requires_grad=False))
|
||||
|
||||
process_module(mapping.get("attn"), layer)
|
||||
process_module(mapping.get("mlp"), layer)
|
||||
|
||||
model_type = self.model.model.__class__.__name__
|
||||
mapping = MODEL_LAYER_MAPPING.get(model_type)
|
||||
|
||||
if not mapping:
|
||||
logger.info(
|
||||
f"Warning: Model type '{model_type}' not found in MODEL_LAYER_MAPPING. Skipping layer mapping."
|
||||
)
|
||||
return
|
||||
|
||||
for idx, layer in enumerate(self.model.model.layers):
|
||||
process_layer(layer, idx, mapping)
|
||||
|
||||
if isinstance(self.model.model.norm, RMSNorm):
|
||||
self.model.model.norm.ignore_anti = True
|
||||
|
||||
return postprocess_loading
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user