[Model] Support DeepSeek-V4

This commit is contained in:
chenxb002
2026-04-24 09:50:34 +08:00
commit b9925203b8
172 changed files with 44780 additions and 0 deletions

@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm import ModelRegistry
def register_model():
from .deepseek_v4 import MLUDeepseekV4ForCausalLM # noqa: F401
ModelRegistry.register_model(
"DeepseekV4ForCausalLM",
"vllm_mlu.model_executor.models.deepseek_v4:MLUDeepseekV4ForCausalLM")

@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from math import lcm
from typing import TYPE_CHECKING
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor.models.config import (HybridAttentionMambaModelConfig,
MambaModelConfig)
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
@classmethod
def vllm__model_executor__models__config__HybridAttentionMambaModelConfig__verify_and_update_config(
cls,
vllm_config: "VllmConfig"
) -> None:
"""
Ensure that the page size of the attention layers is greater than or
equal to that of the mamba layers. If not, automatically set the
attention block size to ensure that it is. If the attention page size
is strictly greater than the mamba page size, we pad the mamba page
size to make them equal.
Args:
vllm_config: vLLM Config
"""
# Save the user input before it gets modified by MambaModelConfig
mamba_block_size = vllm_config.cache_config.mamba_block_size
# Enable FULL_AND_PIECEWISE by default
MambaModelConfig.verify_and_update_config(vllm_config)
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
if cache_config.cache_dtype == "auto":
kv_cache_dtype = model_config.dtype
else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# get attention page size (for 1 token)
# Attention backend constraints:
# - FlashAttention (FA) requires block size to be multiple of 16
# - MLA (Multi-head Latent Attention) requires larger alignment:
# * CUTLASS_MLA backend: kernel_block_size 128 alignment
# * Other MLA backends: kernel_block_size 64 alignment
if model_config.use_mla:
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
attn_page_size_1_token = MLAAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
else:
kernel_block_alignment_size = 16
if (
current_platform.is_device_capability(100)
and model_config.get_head_size() == 256
and (
envs.VLLM_ATTENTION_BACKEND is None
or envs.VLLM_ATTENTION_BACKEND == "FLASHINFER"
)
):
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
# head size 256 and block size 16 is not supported on Blackwell.
kernel_block_alignment_size = 32
attn_page_size_1_token = FullAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
model_cls, _ = ModelRegistry.resolve_model_cls(
model_config.architecture,
model_config=model_config,
)
# get mamba page size
mamba_page_size = MambaSpec(
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
block_size=model_config.max_model_len,
).page_size_bytes
# The model may be marked as is_hybrid,
# but mamba may be skipped via config;
# in that case, return directly.
if mamba_page_size == 0:
return
if cache_config.enable_prefix_caching:
# With prefix caching, select attention block size to
# optimize for mamba kernel performance
# Mamba2 SSD kernel uses a chunk_size, e.g. 256
# Align the block to the kernel: use lowest multiple of chunk_size
# of attention tokens that would fit mamba_page_size:
# e.g. for mamba page size = 788kB
# attn_1_token = 2kB -> fits ~394 tokens
# then round up to a multiple of 256 -> 512 tokens
# End result:
# attn_block_size = 512
# mamba_block_size = 512 (aligned to a multiple of chunk_size)
# TODO(tdoublep): this constraint can be relaxed fairly
# easily by changing the way we layout chunks in the
# mamba2 kernels.
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
cache_config.mamba_block_size = attn_block_size
else:
# Without prefix caching, select minimum valid attention block size
# to minimize mamba state padding
# Calculate minimum attention block size that satisfies both:
# 1. Backend alignment requirements (kernel_block_alignment_size)
# 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
attn_block_size = kernel_block_alignment_size * cdiv(
mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
)
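# A worked sketch of the two branches above (illustrative numbers, assuming the
# 64-token MLA kernel alignment and the chunk_size 256 from the comment):
#
#   mamba_page_size        = 788 * 1024   # 788 KiB
#   attn_page_size_1_token = 2 * 1024     # 2 KiB per token
#   # with prefix caching:
#   #   attn_tokens_per_mamba_state = cdiv(788 * 1024, 2 * 1024)   # 394
#   #   chunk_size = lcm(256, 64)                                   # 256
#   #   attn_block_size = 256 * cdiv(394, 256)                      # 512
#   # without prefix caching:
#   #   attn_block_size = 64 * cdiv(788 * 1024, 64 * 2 * 1024)      # 64 * 7 = 448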
'''
=============================
Modify by vllm_mlu
=============================
@brief: support qwen3-next
'''
if (vllm_config.mlu_config.enable_mamba_split_page_size):
vllm_config.mlu_config.mamba_to_attn_block_ratio = cdiv(attn_block_size, cache_config.block_size)
cache_config.mamba_page_size_padded = cache_config.block_size * attn_page_size_1_token
return
'''
==================
End of MLU Hijack
==================
'''
# override attention block size if either (a) the
# user has not set it or (b) the user has set it
# too small.
if cache_config.block_size is None or cache_config.block_size < attn_block_size:
cache_config.block_size = attn_block_size
logger.info(
"Setting attention block size to %d tokens "
"to ensure that attention page size is >= mamba page size.",
attn_block_size,
)
# compute new attention page size
attn_page_size = cache_config.block_size * attn_page_size_1_token
assert attn_page_size >= mamba_page_size
if attn_page_size == mamba_page_size:
# don't need to pad mamba page size
return
# pad mamba page size to exactly match attention
if (
cache_config.mamba_page_size_padded is None
or cache_config.mamba_page_size_padded != attn_page_size
):
cache_config.mamba_page_size_padded = attn_page_size
mamba_padding_pct = (
100 * (attn_page_size - mamba_page_size) / mamba_page_size
)
logger.info(
"Padding mamba page size by %.2f%% to ensure "
"that mamba page size and attention page size are "
"exactly equal.",
mamba_padding_pct,
)
MluHijackObject.apply_hijack(HybridAttentionMambaModelConfig,
HybridAttentionMambaModelConfig.verify_and_update_config,
vllm__model_executor__models__config__HybridAttentionMambaModelConfig__verify_and_update_config)

File diff suppressed because it is too large.

@@ -0,0 +1,607 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import (
Any, List, Tuple, Optional, Dict, Union, ClassVar, Literal,
Protocol, overload, runtime_checkable)
from typing_extensions import TypeIs
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.config import VllmConfig
from vllm.distributed.communication_op import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_gather_into_list,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter,
)
from vllm.distributed import (
get_tp_group,
get_pp_group,
get_dp_group,
get_data_parallel_group_rank,
get_data_parallel_group_world_size,
get_dense_mlp_tp_world_size,
get_tp_world_world_size,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
get_logits_tp_world_size,
get_parallel_rank_with_group,
get_tp_world_group,
get_tp_world_rank,
GroupCoordinator,
)
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm_mlu.mlu_forward_context import MLUDPMetadata
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
logger = init_logger(__name__)
# alias after refactor
DataParallelRuntimeParams = MLUDPMetadata
def enable_data_parallel():
return get_dp_group().world_size > 1
def enable_emb_logits_custom_parallel():
return get_logits_tp_world_size() != get_tensor_model_parallel_world_size()
def enable_dense_mlp_custom_parallel():
return get_dense_mlp_tp_world_size() != get_tp_world_world_size()
def get_runtime_infos_per_dp_group(
num_tokens: int, num_requests: int, all_prefill: bool, seq_lens: List[int],
device: torch.device, vllm_config: VllmConfig) -> Tuple[List[int], List[int], List[bool], List[int]]:
dp_tensor = torch.tensor([num_tokens, num_requests, int(all_prefill)]).to(device, non_blocking=True)
outputs = tensor_model_parallel_all_gather_into_list(dp_tensor, get_dp_group())
outputs = torch.cat(outputs).tolist() # d2h
dp_world_size = get_data_parallel_group_world_size()
dp_is_prefill, dp_query_lens, dp_group_bs, seq_len_per_batch = [], [], [], []
for i in range(0, 3 * dp_world_size, 3):
dp_query_lens.append(outputs[i])
dp_group_bs.append(outputs[i + 1])
dp_is_prefill.append(bool(outputs[i + 2]))
# Only run the communication if mcc is enabled and all DP ranks are in prefill.
if vllm_config.mlu_config.is_dpsk_mcc_enabled and all(dp_is_prefill):
assert len(seq_lens) == num_requests
seq_len_per_batch = [torch.empty([bs], dtype=dp_tensor.dtype, device=device) for bs in dp_group_bs]
seq_lens_tensor = torch.tensor(seq_lens, dtype=dp_tensor.dtype, device=device)
torch.distributed.all_gather(seq_len_per_batch, seq_lens_tensor, group=get_dp_group().device_group)
seq_len_per_batch = torch.cat(seq_len_per_batch).tolist()
else:
seq_len_per_batch = [0] * sum(dp_group_bs)
return dp_query_lens, dp_group_bs, dp_is_prefill, seq_len_per_batch
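# Layout note (illustrative): each DP rank contributes [num_tokens, num_requests,
# int(all_prefill)], so after the all-gather `outputs` is flattened as
#   [tokens_0, reqs_0, prefill_0, tokens_1, reqs_1, prefill_1, ...]
# which the loop above unpacks in strides of 3 per rank.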
def get_deepseek_layer_split_list(
dp_query_lens: List[int], dp_group_bs: List[int]
) -> Tuple[Optional[List[int]], Optional[List[int]], Optional[List[int]]]:
if len(dp_query_lens) != len(dp_group_bs) or len(dp_query_lens) != get_data_parallel_group_world_size():
logger.warning(f"dp_query_lens length: {len(dp_query_lens)} != dp_group_bs length: {len(dp_group_bs)}, "
f"disable deepseek layer split")
return None, None, None
emb_query_lens, logits_batch_sizes, dense_attn_token_split_list = None, None, None
all_dp_query_lens, all_dp_group_bs = [], []
for i in range(len(dp_query_lens)):
all_dp_query_lens.extend([dp_query_lens[i]] * get_tensor_model_parallel_world_size())
all_dp_group_bs.extend([dp_group_bs[i]] * get_tensor_model_parallel_world_size())
if get_logits_tp_world_size() != get_tensor_model_parallel_world_size():
slice_start = get_tp_world_rank() // get_logits_tp_world_size() * get_logits_tp_world_size()
slice_end = slice_start + get_logits_tp_world_size()
emb_query_lens = all_dp_query_lens[slice_start:slice_end]
logits_batch_sizes = all_dp_group_bs[slice_start:slice_end]
if get_dense_mlp_tp_world_size() != get_tp_world_world_size():
slice_start = get_tp_world_rank() // get_dense_mlp_tp_world_size() * get_dense_mlp_tp_world_size()
slice_end = slice_start + get_dense_mlp_tp_world_size()
dense_attn_token_split_list = all_dp_query_lens[slice_start:slice_end]
return emb_query_lens, logits_batch_sizes, dense_attn_token_split_list
def get_dp_metadata(
num_tokens: int,
data_parallel_size: int,
data_parallel_rank: int,
tensor_parallel_size: int,
prefill_dispatch_use_RS_AG: bool,
) -> DataParallelRuntimeParams:
"""
Get DP params for dummy runs or model graph capture. These two cases do not have
dp_params at forward time, because we do not want to hijack too much.
"""
dp_query_lens = [num_tokens] * data_parallel_size
in_prefill = get_forward_context().attn_metadata is None # dummy run
dp_is_prefill = [in_prefill] * data_parallel_size
emb_query_lens, logits_batch_sizes, dense_attn_token_split_list = None, None, None
if get_logits_tp_world_size() != get_tensor_model_parallel_world_size():
emb_query_lens = [num_tokens] * get_logits_tp_world_size()
logits_batch_sizes = None # dummy run and capture model does not contain logits
if get_dense_mlp_tp_world_size() != get_tp_world_world_size():
dense_attn_token_split_list = [num_tokens] * get_dense_mlp_tp_world_size()
return MLUDPMetadata.make_oot(data_parallel_rank,
data_parallel_size,
tensor_parallel_size,
dp_query_lens,
dp_is_prefill,
prefill_dispatch_use_RS_AG,
emb_query_lens=emb_query_lens,
logits_batch_sizes=logits_batch_sizes,
dense_attn_token_split_list=dense_attn_token_split_list)
def remove_paddings_after_all_gather(
hidden_states: torch.Tensor,
padding_to_token_num: int,
token_num_list: List[int],
) -> torch.Tensor:
dp_group_tensors = []
offset = 0
for token_num in token_num_list:
if token_num != 0:
dp_group_tensors.append(hidden_states[offset:offset+token_num])
offset += padding_to_token_num
if len(dp_group_tensors) == 1:
hidden_states = dp_group_tensors[0]
else:
hidden_states = torch.cat(dp_group_tensors)
return hidden_states
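# Illustrative example (assumed shapes): with padding_to_token_num=8 and
# token_num_list=[4, 7, 5, 8], a gathered tensor of shape [32, H] is sliced
# back to rows 0:4, 8:15, 16:21 and 24:32, then concatenated into [24, H].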
def tensor_model_parallel_all_gather_dp(
group_num_tokens: List[int],
rank: int,
hidden_states: Optional[torch.Tensor],
group: GroupCoordinator,
hidden_size: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None) -> torch.Tensor:
"""
All gather in the group.
Inputs are 2-D tensors whose first-dim sizes may differ across ranks,
for example [4, 7, 5, 8] or [2, 5, 4, 0].
"""
num_tokens_equal = all(x == group_num_tokens[0] for x in group_num_tokens)
if num_tokens_equal:
hidden_states = tensor_model_parallel_all_gather(
input_=hidden_states, dim=0, tp_group=group)
else:
max_num_tokens = max(group_num_tokens)
num_padding = max_num_tokens - group_num_tokens[rank]
if num_padding > 0:
if hidden_states is None:
hidden_states = torch.empty((max_num_tokens, hidden_size),
dtype=dtype, device=device)
else:
hidden_states = F.pad(hidden_states, (0, 0, 0, num_padding))
hidden_states = tensor_model_parallel_all_gather(
input_=hidden_states, dim=0, tp_group=group)
hidden_states = remove_paddings_after_all_gather(
hidden_states, max_num_tokens, group_num_tokens)
return hidden_states
def tensor_model_parallel_all_gather_op_v2(
input_: torch.Tensor,
dim_size_list: List[int],
group_coordinator: GroupCoordinator,
non_leading_dim_size: int,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor:
"""
All gather the input tensor across the model parallel group using only communication ops.
Note: compared to `tensor_model_parallel_all_gather_dp`, this method supports different
sizes in the first dim and does not involve a padding operation.
"""
all_size_equal = all([dim_size == dim_size_list[0] for dim_size in dim_size_list])
output_shape = (sum(dim_size_list), non_leading_dim_size)
output = torch.empty(output_shape, device=device, dtype=dtype)
if input_ is None:
input_ = torch.empty((0, non_leading_dim_size), device=device, dtype=dtype)
if all_size_equal:
torch.distributed.all_gather_into_tensor(
output, input_, group=group_coordinator.device_group)
else:
# Note: torch.split splits the tensor into chunks. And each chunk
# is a view of the original tensor.
tensor_list = torch.split(output, dim_size_list, dim=0)
torch.distributed.all_gather(
list(tensor_list), input_, group=group_coordinator.device_group)
return output
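# Design note (illustrative): when the per-rank sizes differ, `torch.split` on the
# preallocated output yields views of shape [dim_size_list[r], non_leading_dim_size],
# and `torch.distributed.all_gather` writes each rank's contribution directly into
# its view, so no padding or post-hoc slicing is needed.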
def process_post_attention_communication(
hidden_states: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
hidden_size: int,
dtype: torch.dtype,
device: torch.device,
tp_group: Any = None,
):
"""
Processes distributed communication operations after attention computation.
This function performs necessary communication operations after attention computation
to ensure data synchronization across different parallel groups.
Supports two modes:
1. Tensor parallel mode: Uses tp_group for all-reduce and all-gather operations
2. Data parallel mode: Uses reduce-scatter and all-gather for global synchronization
Args:
hidden_states: Hidden states tensor after attention computation, can be None
dp_params: Data parallel runtime parameters containing token distribution and padding info
hidden_size: Dimension size of hidden states
dtype: Data type of the tensor
device: Device where the tensor is located
tp_group: Tensor parallel group, if None uses data parallel mode
Returns:
Hidden states tensor after communication synchronization processing
Note:
- When prefill_pad_to_token_num != -1, padding and unpadding operations will be performed
- Function selects optimal communication path based on token count and parallel strategy
"""
if tp_group is not None:
if dp_params.token_num != 0:
hidden_states = tensor_model_parallel_all_reduce(
hidden_states)
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=dp_params.dense_attn_token_split_list,
rank=get_parallel_rank_with_group(tp_group),
hidden_states=hidden_states,
group=tp_group,
)
else:
if dp_params.prefill_pad_to_token_num != -1:
# pad hidden_states to use reduce_scatter and global all gather
pad_num = dp_params.prefill_pad_to_token_num - dp_params.token_num
if pad_num != 0:
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_num))
hidden_states = tensor_model_parallel_reduce_scatter(
hidden_states, dim=0)
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=dp_params.attn_token_split_list_reduce_scatter,
rank=get_tp_world_rank(),
hidden_states=hidden_states,
group=get_tp_world_group(),
)
# get origin hidden_states for moe compute
hidden_states = remove_paddings_after_all_gather(
hidden_states, dp_params.prefill_pad_to_token_num,
dp_params.token_split_list)
else:
hidden_states = tensor_model_parallel_all_reduce(
hidden_states)
all_gather_group = get_dp_group()
all_gather_rank = get_data_parallel_group_rank()
hidden_states = tensor_model_parallel_all_gather_dp(
dp_params.token_split_list, all_gather_rank, hidden_states,
all_gather_group, hidden_size, dtype, device)
return hidden_states
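# Mode summary (illustrative): with tp_group set, the path is all-reduce within
# the TP group followed by an all-gather over dense_attn_token_split_list; with
# tp_group=None and prefill_pad_to_token_num != -1, the path is pad ->
# reduce-scatter -> all-gather -> unpad; otherwise it falls back to all-reduce
# plus an all-gather across the DP group.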
def dp_model_forward(
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
embedding_layer: nn.Module,
model_norm_layer: nn.Module,
start_layer: int,
end_layer: int,
layers: List[nn.Module],
layer_input_norm_name: str,
prefill_dispatch_use_RS_AG: bool,
streams: Optional[Dict[str, torch.mlu.Stream]] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
"""run model with dp."""
if dp_params is None:
dp_params = get_dp_metadata(positions.numel(),
get_data_parallel_group_world_size(),
get_data_parallel_group_rank(),
get_tensor_model_parallel_world_size(),
prefill_dispatch_use_RS_AG)
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
if embedding_layer.__class__.__name__ == "DPVocabParallelEmbedding":
hidden_states = embedding_layer(input_ids, dp_params=dp_params)
else:
hidden_states = embedding_layer(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(start_layer, end_layer):
is_first_layer = (i == start_layer)
is_last_layer = (i == end_layer - 1)
next_input_layernorm = None
if not is_last_layer:
next_input_layernorm = getattr(layers[i+1], layer_input_norm_name)
hidden_states, residual = layers[i](
positions=positions,
hidden_states=hidden_states,
residual=residual,
dp_params=dp_params,
is_first_layer=is_first_layer,
is_last_layer=is_last_layer,
streams=streams,
next_input_layernorm=next_input_layernorm,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = model_norm_layer(hidden_states)
return hidden_states
def dp_layer_forward(
input_norm: nn.Module,
self_attn: nn.Module,
post_norm: nn.Module,
mlp: nn.Module,
mlp_kwargs: List[Dict[str, Any]],
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
hidden_size: int,
hidden_states_dtype: torch.dtype,
is_first_layer: bool = False,
is_last_layer: bool = False,
next_input_layernorm: Optional[nn.Module] = None,
enable_all2all: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Run a layer with DP, dispatching to the all2all, RS+AG, or common path.
Because the all2all forward args often differ from the common mlp args,
mlp_kwargs[-1] is always the all2all kwargs. For example:
if Deepseek enables all2all, mlp_kwargs will be [{mlp common forward kwargs}, {mlp all2all kwargs}];
if Deepseek does not enable all2all, mlp_kwargs will be [{mlp common forward kwargs}].
"""
if dp_params.layer_use_reduce_scatter:
common_metadata = get_common_metadata()
is_decode_only = common_metadata is not None and common_metadata.is_decode_only
use_all2all = enable_all2all and is_decode_only and isinstance(mlp, SparseMoeMlp)
forward_func = _dp_forward_layer_all2all if use_all2all else _dp_forward_layer_rs_ag
hidden_states, residual = forward_func(input_norm,
self_attn,
post_norm,
mlp,
mlp_kwargs,
positions,
hidden_states,
residual,
dp_params,
is_first_layer,
is_last_layer,
next_input_layernorm)
else:
hidden_states, residual = _dp_forward_layer_common(input_norm,
self_attn,
post_norm,
mlp,
mlp_kwargs,
positions,
hidden_states,
residual,
dp_params,
hidden_size,
hidden_states_dtype)
return hidden_states, residual
def _dp_forward_layer_rs_ag(
input_norm: nn.Module,
self_attn: nn.Module,
post_norm: nn.Module,
mlp: nn.Module,
mlp_kwargs: List[Dict[str, Any]],
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
is_first_layer: bool,
is_last_layer: bool,
next_input_layernorm: Optional[nn.Module],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""run layer with rs+ag."""
if residual is None:
residual = hidden_states
# We move the input_layernorm of i+1 layer to the end of i layer.
# But for the first layer, we need to do input_layernorm first.
if is_first_layer:
hidden_states = input_norm(hidden_states)
# Self Attention
hidden_states = self_attn(
positions=positions,
hidden_states=hidden_states,
)
# add residual here for the first layer
if is_first_layer and get_tensor_model_parallel_rank() == 0:
hidden_states = hidden_states + residual
hidden_states = tensor_model_parallel_reduce_scatter(
hidden_states, dim=0)
# move norm between rs and ag
if is_first_layer:
residual = hidden_states
hidden_states = post_norm(hidden_states)
else:
hidden_states, residual = post_norm(hidden_states, residual)
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=dp_params.attn_token_split_list_reduce_scatter,
rank=get_tp_world_rank(),
hidden_states=hidden_states,
group=get_tp_world_group(),
)
# mlp, use all cards
hidden_states = mlp(hidden_states, **mlp_kwargs[0])
hidden_states = tensor_model_parallel_reduce_scatter(
hidden_states, dim=0, tp_group=get_tp_world_group())
if is_last_layer:
hidden_states = hidden_states + residual
residual = None
else:
# To reduce layernorm computation, we move the layernorm of i+1 layer to
# the end of i layer. Besides, we fuse residual addition into layernorm.
assert next_input_layernorm is not None
hidden_states, residual = next_input_layernorm(hidden_states, residual)
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=dp_params.moe_token_split_list_reduce_scatter,
rank=get_tensor_model_parallel_rank(),
hidden_states=hidden_states,
group=get_tp_group(),
)
return hidden_states, residual
def _dp_forward_layer_all2all(
input_norm: nn.Module,
self_attn: nn.Module,
post_norm: nn.Module,
mlp: nn.Module,
mlp_kwargs: List[Dict[str, Any]],
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
is_first_layer: bool,
is_last_layer: bool,
next_input_layernorm: Optional[nn.Module],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""run layer with all2all."""
if residual is None:
residual = hidden_states
# We move the input_layernorm of i+1 layer to the end of i layer.
# But for the first layer, we need to do input_layernorm first.
if is_first_layer:
hidden_states = input_norm(hidden_states)
# Self Attention
hidden_states = self_attn(
positions=positions,
hidden_states=hidden_states,
)
# add residual here for the first layer
if is_first_layer and get_tensor_model_parallel_rank() == 0:
hidden_states = hidden_states + residual
hidden_states = tensor_model_parallel_reduce_scatter(
hidden_states, dim=0)
# move norm between rs and ag
if is_first_layer:
residual = hidden_states
hidden_states = post_norm(hidden_states)
else:
# add residual in norm for other layers
hidden_states, residual = post_norm(hidden_states, residual)
hidden_states = mlp.forward_all2all(hidden_states, **mlp_kwargs[-1])
if is_last_layer:
hidden_states = hidden_states + residual
residual = None
else:
# To reduce layernorm computation, we move the layernorm of i+1 layer to
# the end of i layer. Besides, we fuse residual addition into layernorm.
assert next_input_layernorm is not None
hidden_states, residual = next_input_layernorm(hidden_states, residual)
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=dp_params.moe_token_split_list_reduce_scatter,
rank=get_tensor_model_parallel_rank(),
hidden_states=hidden_states,
group=get_tp_group(),
)
return hidden_states, residual
def _dp_forward_layer_common(
input_norm: nn.Module,
self_attn: nn.Module,
post_norm: nn.Module,
mlp: nn.Module,
mlp_kwargs: List[Dict[str, Any]],
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
dp_params: DataParallelRuntimeParams,
hidden_size: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""run layer with common."""
if residual is None:
residual = hidden_states
hidden_states = input_norm(hidden_states)
hidden_states = self_attn(
positions=positions,
hidden_states=hidden_states,
)
# add residual here
if get_tensor_model_parallel_rank() == 0:
hidden_states = hidden_states + residual
hidden_states = process_post_attention_communication(
hidden_states, dp_params, hidden_size, dtype, positions.device, None
)
residual = hidden_states[dp_params.token_num_offset:
dp_params.token_num_offset + dp_params.token_num]
hidden_states = post_norm(hidden_states)
hidden_states = mlp(hidden_states, **mlp_kwargs[0])
hidden_states = tensor_model_parallel_all_reduce(
hidden_states, tp_group=get_tp_world_group())
# add residual here
hidden_states = hidden_states[dp_params.token_num_offset:
dp_params.token_num_offset+dp_params.token_num]
hidden_states = hidden_states + residual
residual = hidden_states
return hidden_states, residual

@@ -0,0 +1,245 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from typing import Callable, Optional, List, Union, Tuple
from vllm_mlu import _mlu_ops as mlu_ops
from vllm.attention import AttentionMetadata
from vllm.sequence import IntermediateTensors
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from transformers import PretrainedConfig
def hunyuan_decoder_layer_forward_base(
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_layernorm: Callable,
self_attn: Callable,
post_layernorm: Callable,
mlp: Callable,
kv_states: Optional[Tuple[torch.Tensor]] = None,
apply_residual_connection_post_layernorm: bool = False,
position_name: str = 'positions',
input_norm_fuse_en: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
smooth_quant_scale = None
if input_norm_fuse_en:
layernorm_output, smooth_quant_scale = input_layernorm(hidden_states)
else:
layernorm_output = input_layernorm(hidden_states)
smooth_quant_scale = None
if apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self Attention
attention_output, ori_kv_states = self_attn(
**{position_name: positions},
hidden_states=layernorm_output,
residual=residual,
kv_states=kv_states,
smooth_quant_scale=smooth_quant_scale,
)
layernorm_output = post_layernorm(attention_output)
if apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# Fully Connected
hidden_states = mlp(layernorm_output, residual)
return hidden_states, ori_kv_states
def decoder_layer_forward_base(
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_layernorm: Callable,
self_attn: Callable,
post_layernorm: Callable,
mlp: Callable,
apply_residual_connection_post_layernorm: bool = False,
position_name: str = 'positions',
input_norm_fuse_en: bool = False,
post_norm_fuse_en: bool = False,
) -> torch.Tensor:
if input_norm_fuse_en:
layernorm_output, smooth_quant_scale = input_layernorm(hidden_states)
else:
layernorm_output = input_layernorm(hidden_states)
smooth_quant_scale = None
if apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self Attention
attention_output = self_attn(
**{position_name: positions},
hidden_states=layernorm_output,
residual=residual,
smooth_quant_scale=smooth_quant_scale,
)
if post_norm_fuse_en:
layernorm_output, smooth_quant_scale = post_layernorm(attention_output)
else:
layernorm_output = post_layernorm(attention_output)
smooth_quant_scale = None
if apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# Fully Connected
kwargs = dict()
if post_norm_fuse_en:
kwargs['smooth_quant_scale'] = smooth_quant_scale
hidden_states = mlp(layernorm_output, residual, **kwargs)
return hidden_states
def decoder_model_forward_base(
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
layers: torch.nn.ModuleList,
embed_input_ids: Callable,
norm: Callable
) -> torch.Tensor:
hidden_states = embed_input_ids(input_ids)
for i in range(len(layers)):
layer = layers[i]
hidden_states = layer(
positions,
hidden_states,
)
hidden_states = norm(hidden_states)
return hidden_states
def hunyuan_decoder_model_forward_base_pp(
config: PretrainedConfig,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
layers: torch.nn.ModuleList,
start_layer: int,
end_layer: int,
embed_input_ids: Callable,
norm: Callable,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = embed_input_ids(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
cla_factor = getattr(config, "cla_share_factor", 1)
prev_kv_states = None
for i in range(start_layer, end_layer):
layer = layers[i]
hidden_states, kv_states = layer(
positions,
hidden_states,
prev_kv_states,
)
if (i - start_layer) % cla_factor == 0:
prev_kv_states = kv_states
else:
prev_kv_states = None
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
})
hidden_states = norm(hidden_states)
return hidden_states
def decoder_model_forward_base_pp(
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
layers: torch.nn.ModuleList,
start_layer: int,
end_layer: int,
embed_input_ids: Callable,
norm: Callable,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = embed_input_ids(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(start_layer, end_layer):
layer = layers[i]
hidden_states = layer(
positions,
hidden_states,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
})
hidden_states = norm(hidden_states)
return hidden_states
def is_smoothquant(quant_config: QuantizationConfig) -> bool:
return (quant_config is not None
and quant_config.get_name() == "SmoothQuant")
def is_per_token_smoothquant(quant_config: QuantizationConfig) -> bool:
return (is_smoothquant(quant_config)
and quant_config.input_quant_method == "per_token")
def compute_in_loop(func: Callable,
input: torch.Tensor,
chunk_size: int,
feature_size: Optional[int] = None,
**kwargs):
"""
Divides the input into chunks along the leading dimension (dimension 0) and
computes the chunks in a loop, instead of in one batch at once.
Args:
feature_size: size of the output feature dimension. Provide it when the
output's feature dimension would differ from the input's
feature dimension.
"""
total = input.shape[0]
# directly compute if there is only one chunk
if chunk_size >= total:
return func(input, **kwargs)
feature_size = feature_size or input.shape[1]
output = input.new_empty(total, feature_size)
num_chunks = (total + chunk_size - 1) // chunk_size
for i in range(num_chunks):
start = i * chunk_size
end = min((i + 1) * chunk_size, total)
output[start : end] = func(input[start : end], **kwargs)
return output
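# Usage sketch (hedged, not part of the original commit): chunked application of
# a function, verified with a tiny doctest-style example.
#
#   >>> import torch
#   >>> x = torch.arange(10., dtype=torch.float32).reshape(5, 2)
#   >>> torch.equal(compute_in_loop(lambda t: t * 2, x, chunk_size=2), x * 2)
#   True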

@@ -0,0 +1,507 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import itertools
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm_mlu.mlu_forward_context import MLUDPMetadata
from vllm_mlu.model_executor.models.dp_utils import DataParallelRuntimeParams
from vllm_mlu.v1.attention.backends.mla.flashmla import (
FlashMLAPrefillMetadata, FlashMLAMetadata, MLACommonMetadata
)
from vllm_mlu.v1.attention.backends.utils import (
COMMON_METADATA_STR,
MLUCommonAttentionMetadata,
)
SEQUENCE_DIM_PARTITION_THRESHOLD = 1024
def get_common_and_layer_metadata(
attn_metadata: Optional[dict],
) -> Tuple[Optional[MLUCommonAttentionMetadata], Optional[AttentionMetadata]]:
"""
Returns the common metadata and layer metadata from the given attention metadata.
"""
if attn_metadata is None:
return None, None
if isinstance(attn_metadata, dict):
assert COMMON_METADATA_STR in attn_metadata, (
f"attn_metadata must contain {COMMON_METADATA_STR} key"
)
assert len({id(v) for v in attn_metadata.values()}) == 2, (
f"attn_metadata should be a dict with two values, one for {COMMON_METADATA_STR} and "
f"the other for layers."
)
common_metadata = attn_metadata[COMMON_METADATA_STR]
layer_metadata = next((v for k, v in attn_metadata.items() if k != COMMON_METADATA_STR), None)
return common_metadata, layer_metadata
return None, attn_metadata
def should_skip_partition(layer_metadata, common_metadata) -> bool:
"""Helper function to simplify partition condition check"""
is_layer_metadata_invalid = (layer_metadata is None
or layer_metadata.prefill is None
or layer_metadata.query_start_loc is None
or layer_metadata.query_start_loc.numel() == 0)
is_common_metadata_invalid = common_metadata is None or not common_metadata.is_prefill_only
return is_layer_metadata_invalid or is_common_metadata_invalid
def attn_mcc_plan(
attn_metadata: Any,
dp_params: DataParallelRuntimeParams,
parts_to_split: int,
) -> Tuple[int, int]:
"""
Returns the number of parts for batch size dimension and the number of parts for sequence length dimension.
"""
# During a dummy run, attn_metadata is an instance of MLACommonMetadata
if not isinstance(attn_metadata, (dict, MLACommonMetadata, type(None))):
raise TypeError(f"attn_metadata must be dict or MLACommonMetadata, got {type(attn_metadata)}")
if isinstance(attn_metadata, dict):
common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
else:
common_metadata, layer_metadata = None, attn_metadata
if dp_params is None:
# We don't support mcc with decode yet.
if should_skip_partition(layer_metadata, common_metadata):
return 1, 1
# Splitting along the batch size dimension takes priority over the sequence length dimension.
# Without DP, we also ensure each subtask is non-empty.
num_prefills = layer_metadata.query_start_loc.numel() - 1
if num_prefills > 1:
return min(parts_to_split, num_prefills), 1
try:
max_query_len = torch.diff(layer_metadata.query_start_loc).max().item()
except RuntimeError:
return 1, 1
if max_query_len < SEQUENCE_DIM_PARTITION_THRESHOLD:
return 1, 1
return 1, min(parts_to_split, max_query_len)
else:
if not all(is_prefill for is_prefill in dp_params.dp_is_prefill):
return 1, 1
max_bs = max(dp_params.batch_sizes)
if max_bs > 1:
# Ensure parts_to_split does not exceed max_bs to avoid unnecessary splits
if max(dp_params.token_split_list) < SEQUENCE_DIM_PARTITION_THRESHOLD:
return 1, 1
return min(parts_to_split, max_bs), 1
else:
if max(dp_params.token_split_list) < SEQUENCE_DIM_PARTITION_THRESHOLD:
return 1, 1
return 1, parts_to_split
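# Illustrative outcomes (no DP, parts_to_split=4, threshold=1024 tokens):
#   3 prefill requests                -> (3, 1)  split on the batch dimension
#   1 prefill request of 4096 tokens  -> (1, 4)  split on the sequence dimension
#   1 prefill request of 512 tokens   -> (1, 1)  no split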
def get_data_num_and_offset(total_size, parts_to_split):
"""
Get data size and offset for each.
For example, total batch 11, parallel_num 4, result is [3, 3, 3, 2], offsets is [0, 3, 6, 9]
total batch 8, parallel_num 4, result is [2, 2, 2, 2], offsets is [0, 2, 4, 6]
"""
# Calculate the quotient and remainder of total_size divided by parts_to_split
quotient = total_size // parts_to_split
remainder = total_size % parts_to_split
data_num_list = [quotient + 1] * remainder + [quotient] * (parts_to_split - remainder)
offset_list = [0] + list(itertools.accumulate(data_num_list))
return data_num_list, offset_list[:-1]
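# Doctest-style check of the examples above (assumes this module is importable):
#
#   >>> get_data_num_and_offset(11, 4)
#   ([3, 3, 3, 2], [0, 3, 6, 9])
#   >>> get_data_num_and_offset(8, 4)
#   ([2, 2, 2, 2], [0, 2, 4, 6])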
def split_dp_params(
dp_params: DataParallelRuntimeParams,
bs_parts_to_split: int,
seq_parts_to_split: int,
attn_data_parallel_size: int,
attn_tensor_parallel_size: int,
prefill_dispatch_use_RS_AG: bool,
dp_rank_: int,
) -> List[DataParallelRuntimeParams]:
assert bs_parts_to_split == 1 or seq_parts_to_split == 1, \
"We don't support split batch and sequence dimensions concurrently."
if dp_params is None:
return [None] * bs_parts_to_split * seq_parts_to_split
if bs_parts_to_split * seq_parts_to_split == 1:
return list([dp_params])
if bs_parts_to_split == 1:
results : List[DataParallelRuntimeParams] = []
dp_seq_lens = []
for seq_len in dp_params.seq_lens:
tokens, _ = get_data_num_and_offset(seq_len, seq_parts_to_split)
dp_seq_lens.append(tokens)
query_lens_per_dp_rank = []
# For each dp rank, the batch size is 0 or 1.
bs_offset = 0
for i in range(attn_data_parallel_size):
if dp_params.batch_sizes[i] > 0:
seq_len = dp_params.seq_lens[bs_offset]
tokens, _ = get_data_num_and_offset(seq_len, seq_parts_to_split)
query_lens_per_dp_rank.append(tokens)
bs_offset += dp_params.batch_sizes[i]
else:
query_lens_per_dp_rank.append([0] * seq_parts_to_split)
for i in range(seq_parts_to_split):
dp_is_prefill = []
for dp_rank in range(attn_data_parallel_size):
dp_is_prefill.append(True)
results.append(MLUDPMetadata.make_oot(
data_parallel_rank=dp_rank_,
data_parallel_size=attn_data_parallel_size,
tensor_parallel_size=attn_tensor_parallel_size,
dp_token_nums=[query_lens_per_dp_rank[j][i] for j in range(attn_data_parallel_size)],
dp_is_prefill=dp_is_prefill,
prefill_dispatch_use_RS_AG=prefill_dispatch_use_RS_AG,
seq_lens=[seq_lens[i] for seq_lens in dp_seq_lens],
batch_sizes=dp_params.batch_sizes,
))
return results
bs_per_dp = dp_params.batch_sizes # [bs_rank_0, bs_rank_1, ...]
seq_lens_per_dp = dp_params.seq_lens # [seq_len_bs_0, seq_len_bs_1,...]
# [[bs_rank_0_part_0, bs_rank_0_part_1,...], [bs_rank_1_part_0, bs_rank_1_part_1,...], ...]
split_bs_per_dp = []
# [[
# [bs0_part_0_rank_0, bs1_part_0_rank_0, ...],
# [bs0_part_1_rank_0, bs1_part_1_rank_0, ...],
# ...
# ],
# [
# [bs0_part_0_rank_1, bs1_part_0_rank_1, ...],
# [bs0_part_1_rank_1, bs1_part_1_rank_1, ...],
# ...
# ],
# ]
split_query_lens_per_dp = []
for dp_rank in range(attn_data_parallel_size):
_bs, _offset = get_data_num_and_offset(bs_per_dp[dp_rank], bs_parts_to_split)
split_bs_per_dp.append(_bs)
split_query_lens_per_dp.append([])
for i in range(bs_parts_to_split):
start = sum(bs_per_dp[:dp_rank]) + _offset[i]
end = start + _bs[i]
split_query_lens_per_dp[-1].append(dp_params.seq_lens[start:end])
results : List[DataParallelRuntimeParams] = []
for i in range(bs_parts_to_split):
dp_query_lens = [sum(split_query_lens_per_dp[dp_rank][i]) for dp_rank in range(attn_data_parallel_size)]
seq_lens = []
for dp_rank in range(attn_data_parallel_size):
seq_lens += split_query_lens_per_dp[dp_rank][i]
batch_sizes = []
for dp_rank in range(attn_data_parallel_size):
batch_sizes.append(split_bs_per_dp[dp_rank][i])
dp_is_prefill = []
for dp_rank in range(attn_data_parallel_size):
dp_is_prefill.append(True)
results.append(MLUDPMetadata.make_oot(
data_parallel_rank=dp_rank_,
data_parallel_size=attn_data_parallel_size,
tensor_parallel_size=attn_tensor_parallel_size,
dp_token_nums=dp_query_lens,
dp_is_prefill=dp_is_prefill,
prefill_dispatch_use_RS_AG=prefill_dispatch_use_RS_AG,
seq_lens=seq_lens,
batch_sizes=batch_sizes,
))
return results
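# Illustrative batch-dimension split (assumed values): with batch_sizes=[3, 2],
# seq_lens=[a, b, c, d, e] and bs_parts_to_split=2, rank 0 splits its 3 requests
# into [2, 1] and rank 1 into [1, 1], so part 0 covers [a, b] + [d] and part 1
# covers [c] + [e].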
def split_input(
input: torch.Tensor,
bs_parts_to_split: int,
seq_parts_to_split: int,
attn_metadata_list: List[AttentionMetadata],
) -> List[torch.Tensor]:
assert seq_parts_to_split == 1 or bs_parts_to_split == 1, \
"We don't support split batch and sequence dimensions concurrently."
if input is None:
return [None] * bs_parts_to_split * seq_parts_to_split
if bs_parts_to_split * seq_parts_to_split == 1:
return list([input])
token_num_list = [0] * len(attn_metadata_list)
for i, metadata in enumerate(attn_metadata_list):
common_metadata, layer_metadata = get_common_and_layer_metadata(metadata)
if layer_metadata is not None:
token_num_list[i] = layer_metadata.num_actual_tokens
# A special case for dummy run
if layer_metadata is None and i == 0:
token_num_list[i] = input.shape[0]
results = list()
for i in range(bs_parts_to_split * seq_parts_to_split):
start = sum(token_num_list[:i])
end = start + token_num_list[i]
results.append(input[start:end])
return results
def split_positions(
positions: torch.Tensor,
bs_parts_to_split: int,
seq_parts_to_split: int,
attn_metadata: AttentionMetadata,
) -> List[torch.Tensor]:
if seq_parts_to_split == 1:
return [positions] * bs_parts_to_split
common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
total_tokens = layer_metadata.num_actual_tokens if layer_metadata is not None else 0
tokens, offsets = get_data_num_and_offset(total_tokens, seq_parts_to_split)
positions_list = []
for i in range(seq_parts_to_split):
positions_list.append(positions[offsets[i]: offsets[i] + tokens[i]])
return positions_list
def split_attn_metadata(
attn_metadata: dict,
bs_parts_to_split: int,
seq_parts_to_split: int,
) -> List[Any]:
""" attn_metdata is a dict, which contains common and layer metadata."""
assert bs_parts_to_split == 1 or seq_parts_to_split == 1, \
"We don't support split batch and sequence dimensions concurrently."
if bs_parts_to_split == 1 and seq_parts_to_split == 1:
return list([attn_metadata])
if attn_metadata is None:
return [None] * bs_parts_to_split * seq_parts_to_split
if seq_parts_to_split > 1:
common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
if common_metadata is None or not hasattr(common_metadata, 'num_actual_tokens'):
raise ValueError("common_metadata is invalid or missing num_actual_tokens")
num_prefill_tokens = common_metadata.num_actual_tokens
tokens, offsets = get_data_num_and_offset(num_prefill_tokens, seq_parts_to_split)
device = common_metadata.seq_lens.device
sub_common_metadata, sub_layer_metadata = [], []
for i in range(seq_parts_to_split):
# query_start_loc tensor, which indexes positions in the input.
query_start_loc_tensor = torch.empty_like(common_metadata.query_start_loc)
query_start_loc_tensor[0] = 0
query_start_loc_tensor[1] = tokens[i]
# seq_lens tensor
seq_lens_tensor = torch.tensor(
[offsets[i] + tokens[i]],
dtype=common_metadata.seq_lens.dtype,
device=device
)
# seq_start_loc tensor, which indicates positions in the sequence (kv cache).
seq_start_loc_tensor = torch.empty_like(common_metadata.seq_start_loc)
seq_start_loc_tensor[0] = offsets[i]
seq_start_loc_tensor[1] = offsets[i] + tokens[i]
# max_query_len scalar
max_query_len = tokens[i]
# num_actual_tokens scalar
num_actual_tokens = tokens[i]
# num_input_tokens scalar
num_input_tokens = num_actual_tokens
# infer_mode
infer_mode = common_metadata.infer_mode
# update common metadata
sub_common_metadata.append(MLUCommonAttentionMetadata(
query_start_loc=query_start_loc_tensor,
query_start_loc_cpu=common_metadata.query_start_loc_cpu, # FIXME: split when used
seq_lens=seq_lens_tensor,
seq_lens_cpu=common_metadata.seq_lens_cpu, # FIXME: split when used
num_computed_tokens_cpu=common_metadata.num_computed_tokens_cpu, # FIXME: split when used
num_reqs=common_metadata.num_reqs, # FIXME: split when used
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
max_seq_len=max_query_len,
block_table_tensor=common_metadata.block_table_tensor, # FIXME: split when used
slot_mapping=common_metadata.slot_mapping, # FIXME: split when used
seq_start_loc=seq_start_loc_tensor,
num_input_tokens=num_input_tokens,
infer_mode=infer_mode,
num_prefill_query_tokens=tokens[i],
num_prefill_kv_tokens=offsets[i] + tokens[i],
))
# slot_mapping tensor
slot_mapping = layer_metadata.slot_mapping[offsets[i]:offsets[i] + tokens[i]]
# update layer metadata
REQUIRED_NUM_DECODES = 0
REQUIRED_NUM_DECODE_TOKENS = 0
REQUIRED_NUM_PREFILLS = 1
if not hasattr(layer_metadata, 'num_prefills') or \
layer_metadata.num_prefills is None:
raise ValueError("layer_metadata.num_prefills is required")
assert layer_metadata.num_decodes == REQUIRED_NUM_DECODES and \
layer_metadata.num_decode_tokens == REQUIRED_NUM_DECODE_TOKENS and \
layer_metadata.num_prefills == REQUIRED_NUM_PREFILLS, (
f"num_decodes, num_decode_tokens, num_prefills must be {REQUIRED_NUM_DECODES}, {REQUIRED_NUM_DECODE_TOKENS}, "
f"{REQUIRED_NUM_PREFILLS}, but got {layer_metadata.num_decodes}, {layer_metadata.num_decode_tokens}, "
f"{layer_metadata.num_prefills}."
)
assert layer_metadata.prefill.chunked_context is None, (
f"chunked_context is only available for prefill with chunked context, "
f"and it is not supported when enabling mcc."
)
prefill_metadata = FlashMLAPrefillMetadata(
block_table=layer_metadata.prefill.block_table,
query_start_loc=query_start_loc_tensor,
max_query_len=max_query_len,
chunked_context=None,
num_prefills=layer_metadata.prefill.num_prefills,
max_seq_len=layer_metadata.prefill.max_seq_len,
)
# Note: for sequence dimension partition, we provide the cu_seqlens_kv field to
# indicate the key/value size for the flash attention operator.
prefill_metadata.cu_seqlens_kv = torch.empty_like(prefill_metadata.query_start_loc)
prefill_metadata.cu_seqlens_kv[0] = 0
prefill_metadata.cu_seqlens_kv[1] = offsets[i] + tokens[i]
sub_layer_metadata.append(FlashMLAMetadata(
num_reqs=layer_metadata.num_reqs,
max_query_len=max_query_len,
max_seq_len=max_query_len,
num_actual_tokens=num_actual_tokens,
query_start_loc=query_start_loc_tensor,
slot_mapping=slot_mapping,
num_decodes=layer_metadata.num_decodes,
num_decode_tokens=layer_metadata.num_decode_tokens,
num_prefills=layer_metadata.num_prefills,
num_prefill_tokens=tokens[i],
head_dim=layer_metadata.head_dim,
decode=layer_metadata.decode,
prefill=prefill_metadata,
))
sub_attn_metadata_list = []
for common_meta, layer_meta in zip(sub_common_metadata, sub_layer_metadata):
sub_attn_metadata_dict = {}
for key, value in attn_metadata.items():
if key == COMMON_METADATA_STR:
sub_attn_metadata_dict[key] = common_meta
else:
sub_attn_metadata_dict[key] = layer_meta
sub_attn_metadata_list.append(sub_attn_metadata_dict)
return sub_attn_metadata_list
elif bs_parts_to_split > 1:
common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
if not hasattr(layer_metadata, 'num_prefills') or layer_metadata.num_prefills is None:
raise ValueError("layer_metadata.num_prefills is required")
total_batch = layer_metadata.num_prefills
batch_sizes, offsets = get_data_num_and_offset(total_batch, bs_parts_to_split)
sub_common_metadata, sub_layer_metadata = [], []
for i in range(bs_parts_to_split):
# query_start_loc tensor
start, end = offsets[i], offsets[i] + batch_sizes[i]
query_start_loc_tensor = common_metadata.query_start_loc[start:end+1].clone()
if i > 0:
query_start_loc_tensor -= common_metadata.query_start_loc[start]
# block_table
block_tables = torch.empty(
(batch_sizes[i], 0),
dtype=layer_metadata.prefill.block_table.dtype,
device=layer_metadata.prefill.block_table.device,
)
# seq_lens tensor
seq_lens_tensor = common_metadata.seq_lens[start:end].clone()
# seq_start_loc tensor
seq_start_loc_tensor = query_start_loc_tensor
# max_query_len scalar
max_query_len = seq_lens_tensor.max().item() if seq_lens_tensor.numel() > 0 else 0
# num_actual_tokens scalar
num_actual_tokens = seq_start_loc_tensor[-1].item()
# num_input_tokens scalar
num_input_tokens = num_actual_tokens
# infer_mode
infer_mode = common_metadata.infer_mode
# slot_mapping tensor
slot_mapping_start = 0
for data in sub_common_metadata:
slot_mapping_start += data.num_actual_tokens
slot_mapping_tensor = layer_metadata.slot_mapping[
slot_mapping_start:slot_mapping_start + num_actual_tokens
]
# update common metadata
sub_common_metadata.append(MLUCommonAttentionMetadata(
query_start_loc=query_start_loc_tensor,
query_start_loc_cpu=common_metadata.query_start_loc_cpu, # FIXME: split when used
seq_lens=seq_lens_tensor,
seq_lens_cpu=common_metadata.seq_lens_cpu, # FIXME: split when used
num_computed_tokens_cpu=common_metadata.num_computed_tokens_cpu, # FIXME: split when used
num_reqs=common_metadata.num_reqs, # FIXME: split when used
block_table_tensor=common_metadata.block_table_tensor, # FIXME: split when used
slot_mapping=common_metadata.slot_mapping, # FIXME: split when used
seq_start_loc=seq_start_loc_tensor,
max_query_len=max_query_len,
max_seq_len=max_query_len,
num_actual_tokens=num_actual_tokens,
num_input_tokens=num_input_tokens,
infer_mode=infer_mode,
num_prefill_query_tokens=num_actual_tokens,
num_prefill_kv_tokens=num_actual_tokens,
))
# update layer_metadata
prefill_metadata = FlashMLAPrefillMetadata(
block_table=block_tables,
query_start_loc=query_start_loc_tensor,
max_query_len=max_query_len,
chunked_context=None,
num_prefills=batch_sizes[i],
max_seq_len=max_query_len,
)
sub_layer_metadata.append(FlashMLAMetadata(
num_reqs=batch_sizes[i],
max_query_len=max_query_len,
max_seq_len=max_query_len,
num_actual_tokens=num_actual_tokens,
query_start_loc=query_start_loc_tensor,
slot_mapping=slot_mapping_tensor,
num_decodes=layer_metadata.num_decodes, # useless field
num_decode_tokens=0, # useless field
num_prefills=batch_sizes[i],
num_prefill_tokens=num_actual_tokens,
head_dim=layer_metadata.head_dim,
decode=layer_metadata.decode,
prefill=prefill_metadata,
))
sub_attn_metadata_list = []
for common_meta, layer_meta in zip(sub_common_metadata, sub_layer_metadata):
sub_attn_metadata_dict = {}
for key, value in attn_metadata.items():
if key == COMMON_METADATA_STR:
sub_attn_metadata_dict[key] = common_meta
else:
sub_attn_metadata_dict[key] = layer_meta
sub_attn_metadata_list.append(sub_attn_metadata_dict)
return sub_attn_metadata_list
def execute_with_updated_forward_context(
vllm_config: VllmConfig,
attn_metadata: AttentionMetadata,
func: Callable,
kwargs: Dict[str, Any],
):
with set_forward_context(attn_metadata, vllm_config):
return func(**kwargs)

@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Type, Union
import torch.nn as nn
from vllm.logger import init_logger
from vllm.model_executor.models.registry import (
_LazyRegisteredModel, _RegisteredModel, _ModelRegistry)
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
def vllm__model_executor__models__registry___ModelRegistry__register_model(
self,
model_arch: str,
model_cls: Union[type[nn.Module], str],
) -> None:
"""
Register an external model to be used in vLLM.
`model_cls` can be either:
- A [`torch.nn.Module`][] class directly referencing the model.
- A string in the format `<module>:<class>` which can be used to
lazily import the model. This is useful to avoid initializing CUDA
when importing the model and thus the related error
`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
"""
if not isinstance(model_arch, str):
msg = f"`model_arch` should be a string, not a {type(model_arch)}"
raise TypeError(msg)
'''
=============================
Modify by vllm_mlu
=============================
@brief: change mlu models register log level
'''
if model_arch in self.models:
if isinstance(model_cls, str) and "MLU" in model_cls:
logger.debug(
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch,
model_cls)
else:
logger.warning(
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch,
model_cls)
'''
==================
End of MLU Hijack
==================
'''
if isinstance(model_cls, str):
split_str = model_cls.split(":")
if len(split_str) != 2:
msg = "Expected a string in the format `<module>:<class>`"
raise ValueError(msg)
model = _LazyRegisteredModel(*split_str)
elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
model = _RegisteredModel.from_model_cls(model_cls)
else:
msg = ("`model_cls` should be a string or PyTorch model class, "
f"not a {type(model_arch)}")
raise TypeError(msg)
self.models[model_arch] = model
MluHijackObject.apply_hijack(
_ModelRegistry,
_ModelRegistry.register_model,
vllm__model_executor__models__registry___ModelRegistry__register_model
)
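# Illustrative effect (not an additional API): with this hijack in place,
# re-registering an MLU class such as
#   ModelRegistry.register_model(
#       "DeepseekV4ForCausalLM",
#       "vllm_mlu.model_executor.models.deepseek_v4:MLUDeepseekV4ForCausalLM")
# over an existing architecture logs the overwrite at DEBUG instead of WARNING.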

@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import json
import os
import torch
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.forward_context import get_forward_context
def set_attn_compute_dtype_v1(attn_metadata, dtype: torch.dtype):
'''
set attn compute_dtype for v1
'''
if isinstance(attn_metadata, dict):
for _, metadata in attn_metadata.items():
metadata.compute_dtype = dtype
else:
attn_metadata.compute_dtype = dtype
def set_attn_compute_dtype(dtype: torch.dtype):
'''
set attn compute_dtype.
TODO: FA may standardize on half-precision computation in the future;
set_attn_compute_dtype might then be deprecated and removed.
'''
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
set_attn_compute_dtype_v1(attn_metadata, dtype)
def is_tie_word_embeddings(
model_config: ModelConfig,
org_tie_word_embeddings: bool
) -> bool:
'''
The vLLM language model config for a multimodal model may have a wrong
tie_word_embeddings value, for example InternVL3.5-38B, InternVL3.5-30B-A3B, etc.
This function is a workaround.
'''
from vllm.lora.utils import get_adapter_absolute_path
if not model_config.is_multimodal_model:
return org_tie_word_embeddings
model_path = get_adapter_absolute_path(model_config.model)
config_path = os.path.join(model_path, "config.json")
if not os.path.exists(config_path):
return org_tie_word_embeddings
tie_word_embeddings = org_tie_word_embeddings
with open(config_path) as f:
config = json.load(f)
# first, we find if tie_word_embeddings config is in overall config
if config.get("tie_word_embeddings") is not None:
tie_word_embeddings = config["tie_word_embeddings"]
# then, we find if tie_word_embeddings config is in language model config
if (config.get("llm_config") is not None
and config["llm_config"].get("tie_word_embeddings") is not None):
tie_word_embeddings = config["llm_config"]["tie_word_embeddings"]
return tie_word_embeddings
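# Illustrative example (assumed config.json contents): given
#   {"tie_word_embeddings": true, "llm_config": {"tie_word_embeddings": false}}
# the nested llm_config value takes precedence, so for a multimodal model this
# helper returns False regardless of org_tie_word_embeddings.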