[Model] Support DeepSeek-V4

vllm_mlu/model_executor/models/__init__.py (new executable file, 12 lines)
@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from vllm import ModelRegistry


def register_model():
    from .deepseek_v4 import MLUDeepseekV4ForCausalLM  # noqa: F401

    ModelRegistry.register_model(
        "DeepseekV4ForCausalLM",
        "vllm_mlu.model_executor.models.deepseek_v4:MLUDeepseekV4ForCausalLM")
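For reference, a minimal sketch of how this out-of-tree registration is typically exercised (assuming register_model() is called during plugin initialization; the get_supported_archs() check is only illustrative, not part of this commit):

    from vllm import ModelRegistry
    from vllm_mlu.model_executor.models import register_model

    register_model()
    # The architecture string from the checkpoint's config.json
    # ("DeepseekV4ForCausalLM") should now resolve to the MLU implementation
    # registered above via the lazy "module:class" path.
    assert "DeepseekV4ForCausalLM" in ModelRegistry.get_supported_archs()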
vllm_mlu/model_executor/models/config.py (new file, 192 lines)
@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from math import lcm
from typing import TYPE_CHECKING

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE

from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor.models.config import (HybridAttentionMambaModelConfig,
                                               MambaModelConfig)
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec

if TYPE_CHECKING:
    from vllm.config import VllmConfig

logger = init_logger(__name__)


@classmethod
def vllm__module_executor__models__config__HybridAttentionMambaModelConfig__verify_and_update_config(
        cls,
        vllm_config: "VllmConfig"
) -> None:
    """
    Ensure that the page size of the attention layers is greater than or
    equal to that of the mamba layers. If not, automatically set the
    attention block size to ensure that it is. If the attention page size
    is strictly greater than the mamba page size, pad the mamba page size
    to make them equal.

    Args:
        vllm_config: vLLM Config
    """
    # Save the user input before it gets modified by MambaModelConfig
    mamba_block_size = vllm_config.cache_config.mamba_block_size
    # Enable FULL_AND_PIECEWISE by default
    MambaModelConfig.verify_and_update_config(vllm_config)

    cache_config = vllm_config.cache_config
    model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config

    if cache_config.cache_dtype == "auto":
        kv_cache_dtype = model_config.dtype
    else:
        kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

    # get attention page size (for 1 token)
    # Attention backend constraints:
    # - FlashAttention (FA) requires the block size to be a multiple of 16
    # - MLA (Multi-head Latent Attention) requires larger alignment:
    #   * CUTLASS_MLA backend: kernel_block_size 128 alignment
    #   * Other MLA backends: kernel_block_size 64 alignment
    if model_config.use_mla:
        use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
        kernel_block_alignment_size = 128 if use_cutlass_mla else 64
        attn_page_size_1_token = MLAAttentionSpec(
            block_size=1,
            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
            head_size=model_config.get_head_size(),
            dtype=kv_cache_dtype,
        ).page_size_bytes
    else:
        kernel_block_alignment_size = 16
        if (
            current_platform.is_device_capability(100)
            and model_config.get_head_size() == 256
            and (
                envs.VLLM_ATTENTION_BACKEND is None
                or envs.VLLM_ATTENTION_BACKEND == "FLASHINFER"
            )
        ):
            # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
            # head size 256 with block size 16 is not supported on Blackwell.
            kernel_block_alignment_size = 32
        attn_page_size_1_token = FullAttentionSpec(
            block_size=1,
            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
            head_size=model_config.get_head_size(),
            dtype=kv_cache_dtype,
        ).page_size_bytes

    model_cls, _ = ModelRegistry.resolve_model_cls(
        model_config.architecture,
        model_config=model_config,
    )

    # get mamba page size
    mamba_page_size = MambaSpec(
        shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
        dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
        block_size=model_config.max_model_len,
    ).page_size_bytes

    # The model may be marked as is_hybrid while mamba is skipped via the
    # config; return directly in that case.
    if mamba_page_size == 0:
        return

    if cache_config.enable_prefix_caching:
        # With prefix caching, select the attention block size to
        # optimize for mamba kernel performance.
        #
        # The Mamba2 SSD kernel uses a chunk_size, e.g. 256.
        # Align the block to the kernel: use the lowest multiple of chunk_size
        # of attention tokens that would fit mamba_page_size,
        # e.g. for mamba page size = 788kB:
        #   attn_1_token = 2kB -> fits ~394 tokens
        #   then round up to a multiple of 256 -> 512 tokens
        # End result:
        #   attn_block_size = 512
        #   mamba_block_size = 512 (aligned to a multiple of chunk_size)
        # TODO(tdoublep): this constraint can be relaxed fairly
        # easily by changing the way we lay out chunks in the
        # mamba2 kernels.

        base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
        attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
        chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
        attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
        cache_config.mamba_block_size = attn_block_size
    else:
        # Without prefix caching, select the minimum valid attention block size
        # to minimize mamba state padding.
        #
        # Calculate the minimum attention block size that satisfies both:
        # 1. Backend alignment requirements (kernel_block_alignment_size)
        # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
        attn_block_size = kernel_block_alignment_size * cdiv(
            mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
        )

    '''
    =============================
    Modified by vllm_mlu
    =============================
    @brief: support qwen3-next
    '''
    if vllm_config.mlu_config.enable_mamba_split_page_size:
        vllm_config.mlu_config.mamba_to_attn_block_ratio = cdiv(attn_block_size, cache_config.block_size)
        cache_config.mamba_page_size_padded = cache_config.block_size * attn_page_size_1_token
        return
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    # Override the attention block size if either (a) the user has not set it
    # or (b) the user has set it too small.
    if cache_config.block_size is None or cache_config.block_size < attn_block_size:
        cache_config.block_size = attn_block_size
        logger.info(
            "Setting attention block size to %d tokens "
            "to ensure that attention page size is >= mamba page size.",
            attn_block_size,
        )

    # compute new attention page size
    attn_page_size = cache_config.block_size * attn_page_size_1_token

    assert attn_page_size >= mamba_page_size

    if attn_page_size == mamba_page_size:
        # no need to pad mamba page size
        return

    # pad mamba page size to exactly match attention
    if (
        cache_config.mamba_page_size_padded is None
        or cache_config.mamba_page_size_padded != attn_page_size
    ):
        cache_config.mamba_page_size_padded = attn_page_size
        mamba_padding_pct = (
            100 * (attn_page_size - mamba_page_size) / mamba_page_size
        )
        logger.info(
            "Padding mamba page size by %.2f%% to ensure "
            "that mamba page size and attention page size are "
            "exactly equal.",
            mamba_padding_pct,
        )


MluHijackObject.apply_hijack(HybridAttentionMambaModelConfig,
                             HybridAttentionMambaModelConfig.verify_and_update_config,
                             vllm__module_executor__models__config__HybridAttentionMambaModelConfig__verify_and_update_config)
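To make the rounding above concrete, here is a small, self-contained sketch of the two block-size formulas. The numbers mirror the example in the comment (788kB mamba page, 2kB of attention KV per token); cdiv is re-implemented locally only for illustration, the real code uses vllm.utils.math_utils.cdiv and sizes taken from the KV-cache specs.

    from math import lcm

    def cdiv(a: int, b: int) -> int:
        # ceiling division, mirroring vllm.utils.math_utils.cdiv
        return -(-a // b)

    mamba_page_size = 788 * 1024        # bytes of mamba state per request (example)
    attn_page_size_1_token = 2 * 1024   # bytes of attention KV per token (example)
    kernel_block_alignment_size = 64    # non-CUTLASS MLA backends need 64-token alignment

    # Without prefix caching: the smallest aligned block whose page size covers the mamba state.
    attn_block_size = kernel_block_alignment_size * cdiv(
        mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token)
    print(attn_block_size)  # 394 tokens are needed -> rounded up to 448

    # With prefix caching: additionally round up to a multiple of the mamba2 chunk size (256).
    chunk_size = lcm(256, kernel_block_alignment_size)
    attn_block_size = chunk_size * cdiv(
        cdiv(mamba_page_size, attn_page_size_1_token), chunk_size)
    print(attn_block_size)  # 394 tokens -> rounded up to 512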
vllm_mlu/model_executor/models/deepseek_v4.py (new file, 1096 lines)
(File diff suppressed because it is too large.)
vllm_mlu/model_executor/models/dp_utils.py (new file, 607 lines)
@@ -0,0 +1,607 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from typing import (
    Any, List, Tuple, Optional, Dict, Union, ClassVar, Literal,
    Protocol, overload, runtime_checkable)
from typing_extensions import TypeIs

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.config import VllmConfig
from vllm.distributed.communication_op import (
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_gather_into_list,
    tensor_model_parallel_all_reduce,
    tensor_model_parallel_reduce_scatter,
)
from vllm.distributed import (
    get_tp_group,
    get_pp_group,
    get_dp_group,
    get_data_parallel_group_rank,
    get_data_parallel_group_world_size,
    get_dense_mlp_tp_world_size,
    get_tp_world_world_size,
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_rank,
    get_logits_tp_world_size,
    get_parallel_rank_with_group,
    get_tp_world_group,
    get_tp_world_rank,
    GroupCoordinator,
)
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

from vllm_mlu.mlu_forward_context import MLUDPMetadata
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm_mlu.v1.attention.backends.utils import get_common_metadata

logger = init_logger(__name__)

# alias after refactor
DataParallelRuntimeParams = MLUDPMetadata


def enable_data_parallel():
    return get_dp_group().world_size > 1


def enable_emb_logits_custom_parallel():
    return get_logits_tp_world_size() != get_tensor_model_parallel_world_size()


def enable_dense_mlp_custom_parallel():
    return get_dense_mlp_tp_world_size() != get_tp_world_world_size()


def get_runtime_infos_per_dp_group(
        num_tokens: int, num_requests: int, all_prefill: bool, seq_lens: List[int],
        device: torch.device, vllm_config: VllmConfig
) -> Tuple[List[int], List[int], List[bool], List[int]]:
    dp_tensor = torch.tensor([num_tokens, num_requests, int(all_prefill)]).to(device, non_blocking=True)
    outputs = tensor_model_parallel_all_gather_into_list(dp_tensor, get_dp_group())
    outputs = torch.cat(outputs).tolist()  # d2h
    dp_world_size = get_data_parallel_group_world_size()
    dp_is_prefill, dp_query_lens, dp_group_bs, seq_len_per_batch = [], [], [], []
    for i in range(0, 3 * dp_world_size, 3):
        dp_query_lens.append(outputs[i])
        dp_group_bs.append(outputs[i + 1])
        dp_is_prefill.append(bool(outputs[i + 2]))

    # Only run communication if mcc is enabled and every rank is in prefill.
    if vllm_config.mlu_config.is_dpsk_mcc_enabled and all(dp_is_prefill):
        assert len(seq_lens) == num_requests
        seq_len_per_batch = [torch.empty([bs], dtype=dp_tensor.dtype, device=device) for bs in dp_group_bs]
        seq_lens_tensor = torch.tensor(seq_lens, dtype=dp_tensor.dtype, device=device)
        torch.distributed.all_gather(seq_len_per_batch, seq_lens_tensor, group=get_dp_group().device_group)
        seq_len_per_batch = torch.cat(seq_len_per_batch).tolist()
    else:
        seq_len_per_batch = [0] * sum(dp_group_bs)

    return dp_query_lens, dp_group_bs, dp_is_prefill, seq_len_per_batch


def get_deepseek_layer_split_list(
    dp_query_lens: List[int], dp_group_bs: List[int]
) -> Tuple[Optional[List[int]], Optional[List[int]], Optional[List[int]]]:
    if len(dp_query_lens) != len(dp_group_bs) or len(dp_query_lens) != get_data_parallel_group_world_size():
        logger.warning(f"dp_query_lens length: {len(dp_query_lens)} != dp_group_bs length: {len(dp_group_bs)}, "
                       f"disable deepseek layer split")
        return None, None, None
    emb_query_lens, logits_batch_sizes, dense_attn_token_split_list = None, None, None
    all_dp_query_lens, all_dp_group_bs = [], []
    for i in range(len(dp_query_lens)):
        all_dp_query_lens.extend([dp_query_lens[i]] * get_tensor_model_parallel_world_size())
        all_dp_group_bs.extend([dp_group_bs[i]] * get_tensor_model_parallel_world_size())
    if get_logits_tp_world_size() != get_tensor_model_parallel_world_size():
        slice_start = get_tp_world_rank() // get_logits_tp_world_size() * get_logits_tp_world_size()
        slice_end = slice_start + get_logits_tp_world_size()
        emb_query_lens = all_dp_query_lens[slice_start:slice_end]
        logits_batch_sizes = all_dp_group_bs[slice_start:slice_end]
    if get_dense_mlp_tp_world_size() != get_tp_world_world_size():
        slice_start = get_tp_world_rank() // get_dense_mlp_tp_world_size() * get_dense_mlp_tp_world_size()
        slice_end = slice_start + get_dense_mlp_tp_world_size()
        dense_attn_token_split_list = all_dp_query_lens[slice_start:slice_end]
    return emb_query_lens, logits_batch_sizes, dense_attn_token_split_list


def get_dp_metadata(
    num_tokens: int,
    data_parallel_size: int,
    data_parallel_rank: int,
    tensor_parallel_size: int,
    prefill_dispatch_use_RS_AG: bool,
) -> DataParallelRuntimeParams:
    """
    Get dp params for dummy runs or model-graph capture. These two cases do not
    carry dp_params in the forward call, because we do not want to hijack too much.
    """
    dp_query_lens = [num_tokens] * data_parallel_size
    in_prefill = get_forward_context().attn_metadata is None  # dummy run
    dp_is_prefill = [in_prefill] * data_parallel_size
    emb_query_lens, logits_batch_sizes, dense_attn_token_split_list = None, None, None
    if get_logits_tp_world_size() != get_tensor_model_parallel_world_size():
        emb_query_lens = [num_tokens] * get_logits_tp_world_size()
        logits_batch_sizes = None  # dummy run and graph capture do not contain logits
    if get_dense_mlp_tp_world_size() != get_tp_world_world_size():
        dense_attn_token_split_list = [num_tokens] * get_dense_mlp_tp_world_size()

    return MLUDPMetadata.make_oot(data_parallel_rank,
                                  data_parallel_size,
                                  tensor_parallel_size,
                                  dp_query_lens,
                                  dp_is_prefill,
                                  prefill_dispatch_use_RS_AG,
                                  emb_query_lens=emb_query_lens,
                                  logits_batch_sizes=logits_batch_sizes,
                                  dense_attn_token_split_list=dense_attn_token_split_list)


def remove_paddings_after_all_gather(
    hidden_states: torch.Tensor,
    padding_to_token_num: int,
    token_num_list: List[int],
) -> torch.Tensor:
    dp_group_tensors = []
    offset = 0
    for token_num in token_num_list:
        if token_num != 0:
            dp_group_tensors.append(hidden_states[offset:offset + token_num])
        offset += padding_to_token_num
    if len(dp_group_tensors) == 1:
        hidden_states = dp_group_tensors[0]
    else:
        hidden_states = torch.cat(dp_group_tensors)
    return hidden_states


def tensor_model_parallel_all_gather_dp(
    group_num_tokens: List[int],
    rank: int,
    hidden_states: Optional[torch.Tensor],
    group: GroupCoordinator,
    hidden_size: int = None,
    dtype: torch.dtype = None,
    device: torch.device = None) -> torch.Tensor:
    """
    All-gather within the group.
    The input is a 2-D tensor and may have a different shape in the first dim
    on each rank, for example [4, 7, 5, 8] or [2, 5, 4, 0].
    """
    num_tokens_equal = all(x == group_num_tokens[0] for x in group_num_tokens)
    if num_tokens_equal:
        hidden_states = tensor_model_parallel_all_gather(
            input_=hidden_states, dim=0, tp_group=group)
    else:
        max_num_tokens = max(group_num_tokens)
        num_padding = max_num_tokens - group_num_tokens[rank]
        if num_padding > 0:
            if hidden_states is None:
                hidden_states = torch.empty((max_num_tokens, hidden_size),
                                            dtype=dtype, device=device)
            else:
                hidden_states = F.pad(hidden_states, (0, 0, 0, num_padding))
        hidden_states = tensor_model_parallel_all_gather(
            input_=hidden_states, dim=0, tp_group=group)
        hidden_states = remove_paddings_after_all_gather(
            hidden_states, max_num_tokens, group_num_tokens)
    return hidden_states


def tensor_model_parallel_all_gather_op_v2(
    input_: torch.Tensor,
    dim_size_list: List[int],
    group_coordinator: GroupCoordinator,
    non_leading_dim_size: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """
    All-gather the input tensor across the model parallel group using only
    communication ops.

    Note: compared to `tensor_model_parallel_all_gather_dp`, this method supports
    different sizes in the first dim and does not involve a padding operation.
    """
    all_size_equal = all([dim_size == dim_size_list[0] for dim_size in dim_size_list])

    output_shape = (sum(dim_size_list), non_leading_dim_size)
    output = torch.empty(output_shape, device=device, dtype=dtype)

    if input_ is None:
        input_ = torch.empty((0, non_leading_dim_size), device=device, dtype=dtype)

    if all_size_equal:
        torch.distributed.all_gather_into_tensor(
            output, input_, group=group_coordinator.device_group)
    else:
        # Note: torch.split splits the tensor into chunks, and each chunk
        # is a view of the original tensor.
        tensor_list = torch.split(output, dim_size_list, dim=0)
        torch.distributed.all_gather(
            list(tensor_list), input_, group=group_coordinator.device_group)
    return output


def process_post_attention_communication(
    hidden_states: Optional[torch.Tensor],
    dp_params: DataParallelRuntimeParams,
    hidden_size: int,
    dtype: torch.dtype,
    device: torch.device,
    tp_group: Any = None,
):
    """
    Processes distributed communication operations after attention computation.

    This function performs the communication needed after attention to keep data
    synchronized across the different parallel groups.
    Supports two modes:
    1. Tensor parallel mode: uses tp_group for all-reduce and all-gather operations
    2. Data parallel mode: uses reduce-scatter and all-gather for global synchronization

    Args:
        hidden_states: Hidden states tensor after attention computation, can be None
        dp_params: Data parallel runtime parameters containing token distribution and padding info
        hidden_size: Dimension size of hidden states
        dtype: Data type of the tensor
        device: Device where the tensor is located
        tp_group: Tensor parallel group; if None, data parallel mode is used

    Returns:
        Hidden states tensor after communication synchronization processing

    Note:
        - When prefill_pad_to_token_num != -1, padding and unpadding operations will be performed
        - The function selects the optimal communication path based on token count and parallel strategy
    """
    if tp_group is not None:
        if dp_params.token_num != 0:
            hidden_states = tensor_model_parallel_all_reduce(
                hidden_states)
        hidden_states = tensor_model_parallel_all_gather_dp(
            group_num_tokens=dp_params.dense_attn_token_split_list,
            rank=get_parallel_rank_with_group(tp_group),
            hidden_states=hidden_states,
            group=tp_group,
        )
    else:
        if dp_params.prefill_pad_to_token_num != -1:
            # pad hidden_states to use reduce_scatter and global all-gather
            pad_num = dp_params.prefill_pad_to_token_num - dp_params.token_num
            if pad_num != 0:
                hidden_states = F.pad(hidden_states, (0, 0, 0, pad_num))

            hidden_states = tensor_model_parallel_reduce_scatter(
                hidden_states, dim=0)

            hidden_states = tensor_model_parallel_all_gather_dp(
                group_num_tokens=dp_params.attn_token_split_list_reduce_scatter,
                rank=get_tp_world_rank(),
                hidden_states=hidden_states,
                group=get_tp_world_group(),
            )

            # get the original hidden_states for the moe compute
            hidden_states = remove_paddings_after_all_gather(
                hidden_states, dp_params.prefill_pad_to_token_num,
                dp_params.token_split_list)
        else:
            hidden_states = tensor_model_parallel_all_reduce(
                hidden_states)

            all_gather_group = get_dp_group()
            all_gather_rank = get_data_parallel_group_rank()
            hidden_states = tensor_model_parallel_all_gather_dp(
                dp_params.token_split_list, all_gather_rank, hidden_states,
                all_gather_group, hidden_size, dtype, device)

    return hidden_states


def dp_model_forward(
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor],
    dp_params: DataParallelRuntimeParams,
    embedding_layer: nn.Module,
    model_norm_layer: nn.Module,
    start_layer: int,
    end_layer: int,
    layers: List[nn.Module],
    layer_input_norm_name: str,
    prefill_dispatch_use_RS_AG: bool,
    streams: Optional[Dict[str, torch.mlu.Stream]] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the model with dp."""
    if dp_params is None:
        dp_params = get_dp_metadata(positions.numel(),
                                    get_data_parallel_group_world_size(),
                                    get_data_parallel_group_rank(),
                                    get_tensor_model_parallel_world_size(),
                                    prefill_dispatch_use_RS_AG)

    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            if embedding_layer.__class__.__name__ == "DPVocabParallelEmbedding":
                hidden_states = embedding_layer(input_ids, dp_params=dp_params)
            else:
                hidden_states = embedding_layer(input_ids)
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]

    for i in range(start_layer, end_layer):
        is_first_layer = (i == start_layer)
        is_last_layer = (i == end_layer - 1)
        next_input_layernorm = None
        if not is_last_layer:
            next_input_layernorm = getattr(layers[i + 1], layer_input_norm_name)
        hidden_states, residual = layers[i](
            positions=positions,
            hidden_states=hidden_states,
            residual=residual,
            dp_params=dp_params,
            is_first_layer=is_first_layer,
            is_last_layer=is_last_layer,
            streams=streams,
            next_input_layernorm=next_input_layernorm,
        )

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })

    hidden_states = model_norm_layer(hidden_states)
    return hidden_states


def dp_layer_forward(
    input_norm: nn.Module,
    self_attn: nn.Module,
    post_norm: nn.Module,
    mlp: nn.Module,
    mlp_kwargs: List[Dict[str, Any]],
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    dp_params: DataParallelRuntimeParams,
    hidden_size: int,
    hidden_states_dtype: torch.dtype,
    is_first_layer: bool = False,
    is_last_layer: bool = False,
    next_input_layernorm: Optional[nn.Module] = None,
    enable_all2all: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run a layer with dp, dispatching to all2all, rs+ag, or the common path.

    Because the all2all forward args often differ from the common mlp args,
    mlp_kwargs[-1] is always the all2all kwargs. For example:

    If Deepseek enables all2all, mlp_kwargs will be: [{mlp common forward kwargs}, {mlp all2all kwargs}].
    If Deepseek does not enable all2all, mlp_kwargs will be: [{mlp common forward kwargs}].
    """
    if dp_params.layer_use_reduce_scatter:
        common_metadata = get_common_metadata()
        is_decode_only = common_metadata is not None and common_metadata.is_decode_only
        use_all2all = enable_all2all and is_decode_only and isinstance(mlp, SparseMoeMlp)
        forward_func = _dp_forward_layer_all2all if use_all2all else _dp_forward_layer_rs_ag
        hidden_states, residual = forward_func(input_norm,
                                               self_attn,
                                               post_norm,
                                               mlp,
                                               mlp_kwargs,
                                               positions,
                                               hidden_states,
                                               residual,
                                               dp_params,
                                               is_first_layer,
                                               is_last_layer,
                                               next_input_layernorm)
    else:
        hidden_states, residual = _dp_forward_layer_common(input_norm,
                                                           self_attn,
                                                           post_norm,
                                                           mlp,
                                                           mlp_kwargs,
                                                           positions,
                                                           hidden_states,
                                                           residual,
                                                           dp_params,
                                                           hidden_size,
                                                           hidden_states_dtype)
    return hidden_states, residual


def _dp_forward_layer_rs_ag(
    input_norm: nn.Module,
    self_attn: nn.Module,
    post_norm: nn.Module,
    mlp: nn.Module,
    mlp_kwargs: List[Dict[str, Any]],
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    dp_params: DataParallelRuntimeParams,
    is_first_layer: bool,
    is_last_layer: bool,
    next_input_layernorm: List[Optional[nn.Module]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run a layer with rs+ag."""
    if residual is None:
        residual = hidden_states

    # We move the input_layernorm of layer i+1 to the end of layer i,
    # but for the first layer we need to do input_layernorm first.
    if is_first_layer:
        hidden_states = input_norm(hidden_states)

    # Self Attention
    hidden_states = self_attn(
        positions=positions,
        hidden_states=hidden_states,
    )

    # add the residual here for the first layer
    if is_first_layer and get_tensor_model_parallel_rank() == 0:
        hidden_states = hidden_states + residual

    hidden_states = tensor_model_parallel_reduce_scatter(
        hidden_states, dim=0)

    # move the norm between rs and ag
    if is_first_layer:
        residual = hidden_states
        hidden_states = post_norm(hidden_states)
    else:
        hidden_states, residual = post_norm(hidden_states, residual)

    hidden_states = tensor_model_parallel_all_gather_dp(
        group_num_tokens=dp_params.attn_token_split_list_reduce_scatter,
        rank=get_tp_world_rank(),
        hidden_states=hidden_states,
        group=get_tp_world_group(),
    )

    # mlp, use all cards
    hidden_states = mlp(hidden_states, **mlp_kwargs[0])

    hidden_states = tensor_model_parallel_reduce_scatter(
        hidden_states, dim=0, tp_group=get_tp_world_group())

    if is_last_layer:
        hidden_states = hidden_states + residual
        residual = None
    else:
        # To reduce layernorm computation, we move the layernorm of layer i+1 to
        # the end of layer i. Besides, we fuse the residual addition into the layernorm.
        assert next_input_layernorm is not None
        hidden_states, residual = next_input_layernorm(hidden_states, residual)

    hidden_states = tensor_model_parallel_all_gather_dp(
        group_num_tokens=dp_params.moe_token_split_list_reduce_scatter,
        rank=get_tensor_model_parallel_rank(),
        hidden_states=hidden_states,
        group=get_tp_group(),
    )

    return hidden_states, residual


def _dp_forward_layer_all2all(
    input_norm: nn.Module,
    self_attn: nn.Module,
    post_norm: nn.Module,
    mlp: nn.Module,
    mlp_kwargs: List[Dict[str, Any]],
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    dp_params: DataParallelRuntimeParams,
    is_first_layer: bool,
    is_last_layer: bool,
    next_input_layernorm: List[Optional[nn.Module]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run a layer with all2all."""
    if residual is None:
        residual = hidden_states

    # We move the input_layernorm of layer i+1 to the end of layer i,
    # but for the first layer we need to do input_layernorm first.
    if is_first_layer:
        hidden_states = input_norm(hidden_states)

    # Self Attention
    hidden_states = self_attn(
        positions=positions,
        hidden_states=hidden_states,
    )

    # add the residual here for the first layer
    if is_first_layer and get_tensor_model_parallel_rank() == 0:
        hidden_states = hidden_states + residual

    hidden_states = tensor_model_parallel_reduce_scatter(
        hidden_states, dim=0)

    # move the norm between rs and ag
    if is_first_layer:
        residual = hidden_states
        hidden_states = post_norm(hidden_states)
    else:
        # add the residual in the norm for other layers
        hidden_states, residual = post_norm(hidden_states, residual)

    hidden_states = mlp.forward_all2all(hidden_states, **mlp_kwargs[-1])

    if is_last_layer:
        hidden_states = hidden_states + residual
        residual = None
    else:
        # To reduce layernorm computation, we move the layernorm of layer i+1 to
        # the end of layer i. Besides, we fuse the residual addition into the layernorm.
        assert next_input_layernorm is not None
        hidden_states, residual = next_input_layernorm(hidden_states, residual)

    hidden_states = tensor_model_parallel_all_gather_dp(
        group_num_tokens=dp_params.moe_token_split_list_reduce_scatter,
        rank=get_tensor_model_parallel_rank(),
        hidden_states=hidden_states,
        group=get_tp_group(),
    )

    return hidden_states, residual


def _dp_forward_layer_common(
    input_norm: nn.Module,
    self_attn: nn.Module,
    post_norm: nn.Module,
    mlp: nn.Module,
    mlp_kwargs: List[Dict[str, Any]],
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    dp_params: DataParallelRuntimeParams,
    hidden_size: int,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run a layer with the common path."""
    if residual is None:
        residual = hidden_states

    hidden_states = input_norm(hidden_states)
    hidden_states = self_attn(
        positions=positions,
        hidden_states=hidden_states,
    )

    # add the residual here
    if get_tensor_model_parallel_rank() == 0:
        hidden_states = hidden_states + residual

    hidden_states = process_post_attention_communication(
        hidden_states, dp_params, hidden_size, dtype, positions.device, None
    )

    residual = hidden_states[dp_params.token_num_offset:
                             dp_params.token_num_offset + dp_params.token_num]

    hidden_states = post_norm(hidden_states)

    hidden_states = mlp(hidden_states, **mlp_kwargs[0])

    hidden_states = tensor_model_parallel_all_reduce(
        hidden_states, tp_group=get_tp_world_group())

    # add the residual here
    hidden_states = hidden_states[dp_params.token_num_offset:
                                  dp_params.token_num_offset + dp_params.token_num]
    hidden_states = hidden_states + residual
    residual = hidden_states

    return hidden_states, residual
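To illustrate the uneven-first-dimension all-gather used above, the following single-process sketch simulates what the pad + all-gather + unpad path computes; the rank count and token counts are hypothetical and no actual collective is issued.

    import torch
    import torch.nn.functional as F

    # Per-rank token counts for a hypothetical 4-rank group, as in the docstring above.
    group_num_tokens = [4, 7, 5, 8]
    hidden = 16
    max_tokens = max(group_num_tokens)

    # Each rank pads its (tokens, hidden) activation to (max_tokens, hidden) ...
    padded = [F.pad(torch.randn(n, hidden), (0, 0, 0, max_tokens - n))
              for n in group_num_tokens]
    # ... the all-gather then concatenates the equally sized slices ...
    gathered = torch.cat(padded)                 # shape: (4 * max_tokens, hidden)
    # ... and remove_paddings_after_all_gather() slices the real rows back out.
    chunks, offset = [], 0
    for n in group_num_tokens:
        chunks.append(gathered[offset:offset + n])
        offset += max_tokens
    unpadded = torch.cat(chunks)                 # shape: (sum(group_num_tokens), hidden)
    assert unpadded.shape == (sum(group_num_tokens), hidden)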
vllm_mlu/model_executor/models/layer_utils.py (new executable file, 245 lines)
@@ -0,0 +1,245 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch
from typing import Callable, Optional, List, Union, Tuple

from vllm_mlu import _mlu_ops as mlu_ops
from vllm.attention import AttentionMetadata
from vllm.sequence import IntermediateTensors
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from transformers import PretrainedConfig


def hunyuan_decoder_layer_forward_base(
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    input_layernorm: Callable,
    self_attn: Callable,
    post_layernorm: Callable,
    mlp: Callable,
    kv_states: Optional[Tuple[torch.Tensor]] = None,
    apply_residual_connection_post_layernorm: bool = False,
    position_name: str = 'positions',
    input_norm_fuse_en: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    smooth_quant_scale = None
    if input_norm_fuse_en:
        layernorm_output, smooth_quant_scale = input_layernorm(hidden_states)
    else:
        layernorm_output = input_layernorm(hidden_states)
        smooth_quant_scale = None
    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = hidden_states

    # Self Attention
    attention_output, ori_kv_states = self_attn(
        **{position_name: positions},
        hidden_states=layernorm_output,
        residual=residual,
        kv_states=kv_states,
        smooth_quant_scale=smooth_quant_scale,
    )

    layernorm_output = post_layernorm(attention_output)
    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = attention_output

    # Fully Connected
    hidden_states = mlp(layernorm_output, residual)
    return hidden_states, ori_kv_states


def decoder_layer_forward_base(
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    input_layernorm: Callable,
    self_attn: Callable,
    post_layernorm: Callable,
    mlp: Callable,
    apply_residual_connection_post_layernorm: bool = False,
    position_name: str = 'positions',
    input_norm_fuse_en: bool = False,
    post_norm_fuse_en: bool = False,
) -> torch.Tensor:
    if input_norm_fuse_en:
        layernorm_output, smooth_quant_scale = input_layernorm(hidden_states)
    else:
        layernorm_output = input_layernorm(hidden_states)
        smooth_quant_scale = None

    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = hidden_states

    # Self Attention
    attention_output = self_attn(
        **{position_name: positions},
        hidden_states=layernorm_output,
        residual=residual,
        smooth_quant_scale=smooth_quant_scale,
    )

    if post_norm_fuse_en:
        layernorm_output, smooth_quant_scale = post_layernorm(attention_output)
    else:
        layernorm_output = post_layernorm(attention_output)
        smooth_quant_scale = None

    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = attention_output

    # Fully Connected
    kwargs = dict()
    if post_norm_fuse_en:
        kwargs['smooth_quant_scale'] = smooth_quant_scale
    hidden_states = mlp(layernorm_output, residual, **kwargs)
    return hidden_states


def decoder_model_forward_base(
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    layers: torch.nn.ModuleList,
    embed_input_ids: Callable,
    norm: Callable
) -> torch.Tensor:
    hidden_states = embed_input_ids(input_ids)
    for i in range(len(layers)):
        layer = layers[i]
        hidden_states = layer(
            positions,
            hidden_states,
        )
    hidden_states = norm(hidden_states)
    return hidden_states


def hunyuan_decoder_model_forward_base_pp(
    config: PretrainedConfig,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    layers: torch.nn.ModuleList,
    start_layer: int,
    end_layer: int,
    embed_input_ids: Callable,
    norm: Callable,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = embed_input_ids(input_ids)
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]

    cla_factor = getattr(config, "cla_share_factor", 1)
    prev_kv_states = None
    for i in range(start_layer, end_layer):
        layer = layers[i]
        hidden_states, kv_states = layer(
            positions,
            hidden_states,
            prev_kv_states,
        )
        if (i - start_layer) % cla_factor == 0:
            prev_kv_states = kv_states
        else:
            prev_kv_states = None

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states": hidden_states,
        })

    hidden_states = norm(hidden_states)
    return hidden_states


def decoder_model_forward_base_pp(
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    layers: torch.nn.ModuleList,
    start_layer: int,
    end_layer: int,
    embed_input_ids: Callable,
    norm: Callable,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = embed_input_ids(input_ids)
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]

    for i in range(start_layer, end_layer):
        layer = layers[i]
        hidden_states = layer(
            positions,
            hidden_states,
        )

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states": hidden_states,
        })

    hidden_states = norm(hidden_states)
    return hidden_states


def is_smoothquant(quant_config: QuantizationConfig) -> bool:
    return (quant_config is not None
            and quant_config.get_name() == "SmoothQuant")


def is_per_token_smoothquant(quant_config: QuantizationConfig) -> bool:
    return (is_smoothquant(quant_config)
            and quant_config.input_quant_method == "per_token")


def compute_in_loop(func: Callable,
                    input: torch.Tensor,
                    chunk_size: int,
                    feature_size: Optional[int] = None,
                    **kwargs):
    """
    Divides the input into chunks along the leading dimension (dimension 0) and
    computes the chunks in a loop, instead of in a single batch at once.

    Args:
        feature_size: size of the output feature dimension. Provide it when the
            output's feature dimension would differ from the input's feature
            dimension.
    """

    total = input.shape[0]
    # directly compute if there is only one chunk
    if chunk_size >= total:
        return func(input, **kwargs)

    feature_size = feature_size or input.shape[1]
    output = input.new_empty(total, feature_size)
    num_chunks = (total + chunk_size - 1) // chunk_size

    for i in range(num_chunks):
        start = i * chunk_size
        end = min((i + 1) * chunk_size, total)
        output[start:end] = func(input[start:end], **kwargs)

    return output
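A short usage sketch of compute_in_loop() defined above; the chunk size and the row-wise function are illustrative, the point is only that chunked and unchunked evaluation agree while the peak activation size of the wrapped function is capped.

    import torch

    def row_wise_square(x: torch.Tensor) -> torch.Tensor:
        # stand-in for an expensive, memory-hungry kernel
        return x * x

    x = torch.randn(10, 8)
    # Process 10 rows in chunks of 4: three calls of at most 4 rows each.
    out = compute_in_loop(row_wise_square, x, chunk_size=4)
    assert torch.equal(out, row_wise_square(x))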
vllm_mlu/model_executor/models/partition_utils.py (new file, 507 lines)
@@ -0,0 +1,507 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import itertools
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm_mlu.mlu_forward_context import MLUDPMetadata
from vllm_mlu.model_executor.models.dp_utils import DataParallelRuntimeParams
from vllm_mlu.v1.attention.backends.mla.flashmla import (
    FlashMLAPrefillMetadata, FlashMLAMetadata, MLACommonMetadata
)
from vllm_mlu.v1.attention.backends.utils import (
    COMMON_METADATA_STR,
    MLUCommonAttentionMetadata,
)

SEQUENCE_DIM_PARITION_THRESHOLD = 1024


def get_common_and_layer_metadata(
    attn_metadata: Optional[dict],
) -> Tuple[Optional[MLUCommonAttentionMetadata], Optional[AttentionMetadata]]:
    """
    Returns the common metadata and layer metadata from the given attention metadata.
    """
    if attn_metadata is None:
        return None, None

    if isinstance(attn_metadata, dict):
        assert COMMON_METADATA_STR in attn_metadata, (
            f"attn_metadata must contain {COMMON_METADATA_STR} key"
        )
        assert len({id(v) for v in attn_metadata.values()}) == 2, (
            f"attn_metadata should be a dict with two values, one for {COMMON_METADATA_STR} and "
            f"the other for layers."
        )
        common_metadata = attn_metadata[COMMON_METADATA_STR]
        layer_metadata = next((v for k, v in attn_metadata.items() if k != COMMON_METADATA_STR), None)

        return common_metadata, layer_metadata
    return None, attn_metadata


def should_skip_partition(layer_metadata, common_metadata) -> bool:
    """Helper function to simplify the partition condition check."""
    is_layer_metadata_invalid = (layer_metadata is None
                                 or layer_metadata.prefill is None
                                 or layer_metadata.query_start_loc is None
                                 or layer_metadata.query_start_loc.numel() == 0)
    is_common_metadata_invalid = common_metadata is None or not common_metadata.is_prefill_only
    return is_layer_metadata_invalid or is_common_metadata_invalid


def attn_mcc_plan(
    attn_metadata: Any,
    dp_params: DataParallelRuntimeParams,
    parts_to_split: int,
) -> Tuple[int, int]:
    """
    Returns the number of parts for the batch size dimension and the number of
    parts for the sequence length dimension.
    """
    # During a dummy run, attn_metadata is an instance of MLACommonMetadata
    if not isinstance(attn_metadata, (dict, MLACommonMetadata, type(None))):
        raise TypeError(f"attn_metadata must be dict or MLACommonMetadata, got {type(attn_metadata)}")

    if isinstance(attn_metadata, dict):
        common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
    else:
        common_metadata, layer_metadata = None, attn_metadata

    if dp_params is None:
        # We don't support mcc with decode yet.
        if should_skip_partition(layer_metadata, common_metadata):
            return 1, 1

        # Splitting the batch size dimension has higher priority than the sequence
        # length dimension, and without dp we ensure each subtask is not empty.
        num_prefills = layer_metadata.query_start_loc.numel() - 1
        if num_prefills > 1:
            return min(parts_to_split, num_prefills), 1

        try:
            max_query_len = torch.diff(layer_metadata.query_start_loc).max().item()
        except RuntimeError:
            return 1, 1

        if max_query_len < SEQUENCE_DIM_PARITION_THRESHOLD:
            return 1, 1
        return 1, min(parts_to_split, max_query_len)
    else:
        if not all(is_prefill for is_prefill in dp_params.dp_is_prefill):
            return 1, 1

        max_bs = max(dp_params.batch_sizes)
        if max_bs > 1:
            # Ensure parts_to_split does not exceed max_bs to avoid unnecessary splits
            if max(dp_params.token_split_list) < SEQUENCE_DIM_PARITION_THRESHOLD:
                return 1, 1
            return min(parts_to_split, max_bs), 1
        else:
            if max(dp_params.token_split_list) < SEQUENCE_DIM_PARITION_THRESHOLD:
                return 1, 1
            return 1, parts_to_split


def get_data_num_and_offset(total_size, parts_to_split):
    """
    Get the data size and offset for each partition.
    For example, total batch 11, parallel_num 4: result is [3, 3, 3, 2], offsets are [0, 3, 6, 9];
    total batch 8, parallel_num 4: result is [2, 2, 2, 2], offsets are [0, 2, 4, 6].
    """
    # Calculate the quotient and remainder of total_size divided by parts_to_split
    quotient = total_size // parts_to_split
    remainder = total_size % parts_to_split
    data_num_list = [quotient + 1] * remainder + [quotient] * (parts_to_split - remainder)
    offset_list = [0] + list(itertools.accumulate(data_num_list))
    return data_num_list, offset_list[:-1]


def split_dp_params(
    dp_params: DataParallelRuntimeParams,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
    attn_data_parallel_size: int,
    attn_tensor_parallel_size: int,
    prefill_dispatch_use_RS_AG: bool,
    dp_rank_: int,
) -> List[DataParallelRuntimeParams]:
    assert bs_parts_to_split == 1 or seq_parts_to_split == 1, \
        "We don't support splitting the batch and sequence dimensions concurrently."

    if dp_params is None:
        return [None] * bs_parts_to_split * seq_parts_to_split

    if bs_parts_to_split * seq_parts_to_split == 1:
        return list([dp_params])

    if bs_parts_to_split == 1:
        results: List[DataParallelRuntimeParams] = []
        dp_seq_lens = []
        for seq_len in dp_params.seq_lens:
            tokens, _ = get_data_num_and_offset(seq_len, seq_parts_to_split)
            dp_seq_lens.append(tokens)

        query_lens_per_dp_rank = []

        # For each dp rank, the batch size is 0 or 1.
        bs_offset = 0
        for i in range(attn_data_parallel_size):
            if dp_params.batch_sizes[i] > 0:
                seq_len = dp_params.seq_lens[bs_offset]
                tokens, _ = get_data_num_and_offset(seq_len, seq_parts_to_split)
                query_lens_per_dp_rank.append(tokens)
                bs_offset += dp_params.batch_sizes[i]
            else:
                query_lens_per_dp_rank.append([0] * seq_parts_to_split)

        for i in range(seq_parts_to_split):
            dp_is_prefill = []
            for dp_rank in range(attn_data_parallel_size):
                dp_is_prefill.append(True)

            results.append(MLUDPMetadata.make_oot(
                data_parallel_rank=dp_rank_,
                data_parallel_size=attn_data_parallel_size,
                tensor_parallel_size=attn_tensor_parallel_size,
                dp_token_nums=[query_lens_per_dp_rank[j][i] for j in range(attn_data_parallel_size)],
                dp_is_prefill=dp_is_prefill,
                prefill_dispatch_use_RS_AG=prefill_dispatch_use_RS_AG,
                seq_lens=[seq_lens[i] for seq_lens in dp_seq_lens],
                batch_sizes=dp_params.batch_sizes,
            ))
        return results

    bs_per_dp = dp_params.batch_sizes  # [bs_rank_0, bs_rank_1, ...]
    seq_lens_per_dp = dp_params.seq_lens  # [seq_len_bs_0, seq_len_bs_1, ...]

    # [[bs_rank_0_part_0, bs_rank_0_part_1, ...], [bs_rank_1_part_0, bs_rank_1_part_1, ...], ...]
    split_bs_per_dp = []
    # [[
    #    [bs0_part_0_rank_0, bs1_part_0_rank_0, ...],
    #    [bs0_part_1_rank_0, bs1_part_1_rank_0, ...],
    #    ...
    #  ],
    #  [
    #    [bs0_part_0_rank_1, bs1_part_0_rank_1, ...],
    #    [bs0_part_1_rank_1, bs1_part_1_rank_1, ...],
    #    ...
    #  ],
    # ]
    split_query_lens_per_dp = []
    for dp_rank in range(attn_data_parallel_size):
        _bs, _offset = get_data_num_and_offset(bs_per_dp[dp_rank], bs_parts_to_split)
        split_bs_per_dp.append(_bs)
        split_query_lens_per_dp.append([])
        for i in range(bs_parts_to_split):
            start = sum(bs_per_dp[:dp_rank]) + _offset[i]
            end = start + _bs[i]
            split_query_lens_per_dp[-1].append(dp_params.seq_lens[start:end])

    results: List[DataParallelRuntimeParams] = []
    for i in range(bs_parts_to_split):
        dp_query_lens = [sum(split_query_lens_per_dp[dp_rank][i]) for dp_rank in range(attn_data_parallel_size)]
        seq_lens = []
        for dp_rank in range(attn_data_parallel_size):
            seq_lens += split_query_lens_per_dp[dp_rank][i]
        batch_sizes = []
        for dp_rank in range(attn_data_parallel_size):
            batch_sizes.append(split_bs_per_dp[dp_rank][i])

        dp_is_prefill = []
        for dp_rank in range(attn_data_parallel_size):
            dp_is_prefill.append(True)
        results.append(MLUDPMetadata.make_oot(
            data_parallel_rank=dp_rank_,
            data_parallel_size=attn_data_parallel_size,
            tensor_parallel_size=attn_tensor_parallel_size,
            dp_token_nums=dp_query_lens,
            dp_is_prefill=dp_is_prefill,
            prefill_dispatch_use_RS_AG=prefill_dispatch_use_RS_AG,
            seq_lens=seq_lens,
            batch_sizes=batch_sizes,
        ))

    return results


def split_input(
    input: torch.Tensor,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
    attn_metadata_list: List[AttentionMetadata],
) -> List[torch.Tensor]:
    assert seq_parts_to_split == 1 or bs_parts_to_split == 1, \
        "We don't support splitting the batch and sequence dimensions concurrently."

    if input is None:
        return [None] * bs_parts_to_split * seq_parts_to_split

    if bs_parts_to_split * seq_parts_to_split == 1:
        return list([input])

    token_num_list = [0] * len(attn_metadata_list)
    for i, metadata in enumerate(attn_metadata_list):
        common_metadata, layer_metadata = get_common_and_layer_metadata(metadata)
        if layer_metadata is not None:
            token_num_list[i] = layer_metadata.num_actual_tokens

        # A special case for dummy run
        if layer_metadata is None and i == 0:
            token_num_list[i] = input.shape[0]

    results = list()
    for i in range(bs_parts_to_split * seq_parts_to_split):
        start = sum(token_num_list[:i])
        end = start + token_num_list[i]
        results.append(input[start:end])
    return results


def split_positions(
    positions: torch.Tensor,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
    attn_metadata: AttentionMetadata,
) -> List[torch.Tensor]:
    if seq_parts_to_split == 1:
        return [positions] * bs_parts_to_split

    common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
    total_tokens = layer_metadata.num_actual_tokens if layer_metadata is not None else 0

    tokens, offsets = get_data_num_and_offset(total_tokens, seq_parts_to_split)
    positions_list = []
    for i in range(seq_parts_to_split):
        positions_list.append(positions[offsets[i]:offsets[i] + tokens[i]])

    return positions_list

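# Illustration (hypothetical numbers) of the sequence-dimension bookkeeping that
# split_attn_metadata() below performs for a single long prefill: splitting a
# 5000-token request into 4 parts gives
#   get_data_num_and_offset(5000, 4) == ([1250, 1250, 1250, 1250], [0, 1250, 2500, 3750]).
# Each part keeps query_start_loc = [0, 1250], but part i attends to all
# offsets[i] + tokens[i] tokens processed so far, so cu_seqlens_kv becomes
# [0, 1250], [0, 2500], [0, 3750], [0, 5000] for the four parts.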
|
||||
def split_attn_metadata(
|
||||
attn_metadata: dict,
|
||||
bs_parts_to_split: int,
|
||||
seq_parts_to_split: int,
|
||||
) -> List[Any]:
|
||||
""" attn_metdata is a dict, which contains common and layer metadata."""
|
||||
assert bs_parts_to_split == 1 or seq_parts_to_split == 1, \
|
||||
"We don't support split batch and sequence dimensions concurrently."
|
||||
if bs_parts_to_split == 1 and seq_parts_to_split == 1:
|
||||
return list([attn_metadata])
|
||||
|
||||
if attn_metadata is None:
|
||||
return [None] * bs_parts_to_split * seq_parts_to_split
|
||||
|
||||
    if seq_parts_to_split > 1:
        common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
        if common_metadata is None or not hasattr(common_metadata, 'num_actual_tokens'):
            raise ValueError("common_metadata is invalid or missing num_actual_tokens")
        num_prefill_tokens = common_metadata.num_actual_tokens
        tokens, offsets = get_data_num_and_offset(num_prefill_tokens, seq_parts_to_split)
        device = common_metadata.seq_lens.device
        sub_common_metadata, sub_layer_metadata = [], []
        for i in range(seq_parts_to_split):
            # query_start_loc tensor, which indexes positions in the input.
            query_start_loc_tensor = torch.empty_like(common_metadata.query_start_loc)
            query_start_loc_tensor[0] = 0
            query_start_loc_tensor[1] = tokens[i]
            # seq_lens tensor
            seq_lens_tensor = torch.tensor(
                [offsets[i] + tokens[i]],
                dtype=common_metadata.seq_lens.dtype,
                device=device
            )
            # seq_start_loc tensor, which indexes positions in the sequence (kv cache).
            seq_start_loc_tensor = torch.empty_like(common_metadata.seq_start_loc)
            seq_start_loc_tensor[0] = offsets[i]
            seq_start_loc_tensor[1] = offsets[i] + tokens[i]
            # max_query_len scalar
            max_query_len = tokens[i]
            # num_actual_tokens scalar
            num_actual_tokens = tokens[i]
            # num_input_tokens scalar
            num_input_tokens = num_actual_tokens
            # infer_mode
            infer_mode = common_metadata.infer_mode
            # update common metadata
            sub_common_metadata.append(MLUCommonAttentionMetadata(
                query_start_loc=query_start_loc_tensor,
                query_start_loc_cpu=common_metadata.query_start_loc_cpu,  # FIXME: split when used
                seq_lens=seq_lens_tensor,
                seq_lens_cpu=common_metadata.seq_lens_cpu,  # FIXME: split when used
                num_computed_tokens_cpu=common_metadata.num_computed_tokens_cpu,  # FIXME: split when used
                num_reqs=common_metadata.num_reqs,  # FIXME: split when used
                num_actual_tokens=num_actual_tokens,
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                block_table_tensor=common_metadata.block_table_tensor,  # FIXME: split when used
                slot_mapping=common_metadata.slot_mapping,  # FIXME: split when used
                seq_start_loc=seq_start_loc_tensor,
                num_input_tokens=num_input_tokens,
                infer_mode=infer_mode,
                num_prefill_query_tokens=tokens[i],
                num_prefill_kv_tokens=offsets[i] + tokens[i],
            ))
            # slot_mapping tensor
            slot_mapping = layer_metadata.slot_mapping[offsets[i]:offsets[i] + tokens[i]]
            # update layer metadata
            REQUIRED_NUM_DECODES = 0
            REQUIRED_NUM_DECODE_TOKENS = 0
            REQUIRED_NUM_PREFILLS = 1
            if not hasattr(layer_metadata, 'num_prefills') or \
                    layer_metadata.num_prefills is None:
                raise ValueError("layer_metadata.num_prefills is required")

            assert layer_metadata.num_decodes == REQUIRED_NUM_DECODES and \
                layer_metadata.num_decode_tokens == REQUIRED_NUM_DECODE_TOKENS and \
                layer_metadata.num_prefills == REQUIRED_NUM_PREFILLS, (
                    f"num_decodes, num_decode_tokens, num_prefills must be {REQUIRED_NUM_DECODES}, "
                    f"{REQUIRED_NUM_DECODE_TOKENS}, {REQUIRED_NUM_PREFILLS}, but got "
                    f"{layer_metadata.num_decodes}, {layer_metadata.num_decode_tokens}, "
                    f"{layer_metadata.num_prefills}."
                )
            assert layer_metadata.prefill.chunked_context is None, (
                "chunked_context is only available for prefill with chunked context, "
                "and it is not supported when enabling mcc."
            )
            prefill_metadata = FlashMLAPrefillMetadata(
                block_table=layer_metadata.prefill.block_table,
                query_start_loc=query_start_loc_tensor,
                max_query_len=max_query_len,
                chunked_context=None,
                num_prefills=layer_metadata.prefill.num_prefills,
                max_seq_len=layer_metadata.prefill.max_seq_len,
            )
            # Note: for the sequence-dimension partition, we provide the cu_seqlens_kv field to
            # indicate the key/value size for the flash attention operator.
            prefill_metadata.cu_seqlens_kv = torch.empty_like(prefill_metadata.query_start_loc)
            prefill_metadata.cu_seqlens_kv[0] = 0
            prefill_metadata.cu_seqlens_kv[1] = offsets[i] + tokens[i]

            sub_layer_metadata.append(FlashMLAMetadata(
                num_reqs=layer_metadata.num_reqs,
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                num_actual_tokens=num_actual_tokens,
                query_start_loc=query_start_loc_tensor,
                slot_mapping=slot_mapping,
                num_decodes=layer_metadata.num_decodes,
                num_decode_tokens=layer_metadata.num_decode_tokens,
                num_prefills=layer_metadata.num_prefills,
                num_prefill_tokens=tokens[i],
                head_dim=layer_metadata.head_dim,
                decode=layer_metadata.decode,
                prefill=prefill_metadata,
            ))

        sub_attn_metadata_list = []
        for common_meta, layer_meta in zip(sub_common_metadata, sub_layer_metadata):
            sub_attn_metadata_dict = {}
            for key, value in attn_metadata.items():
                if key == COMMON_METADATA_STR:
                    sub_attn_metadata_dict[key] = common_meta
                else:
                    sub_attn_metadata_dict[key] = layer_meta
            sub_attn_metadata_list.append(sub_attn_metadata_dict)
        return sub_attn_metadata_list
    elif bs_parts_to_split > 1:
        common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
        if not hasattr(layer_metadata, 'num_prefills') or layer_metadata.num_prefills is None:
            raise ValueError("layer_metadata.num_prefills is required")
        total_batch = layer_metadata.num_prefills
        batch_sizes, offsets = get_data_num_and_offset(total_batch, bs_parts_to_split)
        sub_common_metadata, sub_layer_metadata = [], []
        for i in range(bs_parts_to_split):
            # query_start_loc tensor
            start, end = offsets[i], offsets[i] + batch_sizes[i]
            query_start_loc_tensor = common_metadata.query_start_loc[start:end + 1].clone()
            if i > 0:
                query_start_loc_tensor -= common_metadata.query_start_loc[start]
            # block_table
            block_tables = torch.empty(
                (batch_sizes[i], 0),
                dtype=layer_metadata.prefill.block_table.dtype,
                device=layer_metadata.prefill.block_table.device,
            )
            # seq_lens tensor
            seq_lens_tensor = common_metadata.seq_lens[start:end].clone()
            # seq_start_loc tensor
            seq_start_loc_tensor = query_start_loc_tensor
            # max_query_len scalar
            max_query_len = seq_lens_tensor.max().item() if seq_lens_tensor.numel() > 0 else 0
            # num_actual_tokens scalar
            num_actual_tokens = seq_start_loc_tensor[-1].item()
            # num_input_tokens scalar
            num_input_tokens = num_actual_tokens
            # infer_mode
            infer_mode = common_metadata.infer_mode
            # slot_mapping tensor
            slot_mapping_start = 0
            for data in sub_common_metadata:
                slot_mapping_start += data.num_actual_tokens
            slot_mapping_tensor = layer_metadata.slot_mapping[
                slot_mapping_start:slot_mapping_start + num_actual_tokens
            ]
            # update common metadata
            sub_common_metadata.append(MLUCommonAttentionMetadata(
                query_start_loc=query_start_loc_tensor,
                query_start_loc_cpu=common_metadata.query_start_loc_cpu,  # FIXME: split when used
                seq_lens=seq_lens_tensor,
                seq_lens_cpu=common_metadata.seq_lens_cpu,  # FIXME: split when used
                num_computed_tokens_cpu=common_metadata.num_computed_tokens_cpu,  # FIXME: split when used
                num_reqs=common_metadata.num_reqs,  # FIXME: split when used
                block_table_tensor=common_metadata.block_table_tensor,  # FIXME: split when used
                slot_mapping=common_metadata.slot_mapping,  # FIXME: split when used
                seq_start_loc=seq_start_loc_tensor,
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                num_actual_tokens=num_actual_tokens,
                num_input_tokens=num_input_tokens,
                infer_mode=infer_mode,
                num_prefill_query_tokens=num_actual_tokens,
                num_prefill_kv_tokens=num_actual_tokens,
            ))
            # update layer metadata
            prefill_metadata = FlashMLAPrefillMetadata(
                block_table=block_tables,
                query_start_loc=query_start_loc_tensor,
                max_query_len=max_query_len,
                chunked_context=None,
                num_prefills=batch_sizes[i],
                max_seq_len=max_query_len,
            )
            sub_layer_metadata.append(FlashMLAMetadata(
                num_reqs=batch_sizes[i],
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                num_actual_tokens=num_actual_tokens,
                query_start_loc=query_start_loc_tensor,
                slot_mapping=slot_mapping_tensor,
                num_decodes=layer_metadata.num_decodes,  # unused field
                num_decode_tokens=0,  # unused field
                num_prefills=batch_sizes[i],
                num_prefill_tokens=num_actual_tokens,
                head_dim=layer_metadata.head_dim,
                decode=layer_metadata.decode,
                prefill=prefill_metadata,
            ))

        sub_attn_metadata_list = []
        for common_meta, layer_meta in zip(sub_common_metadata, sub_layer_metadata):
            sub_attn_metadata_dict = {}
            for key, value in attn_metadata.items():
                if key == COMMON_METADATA_STR:
                    sub_attn_metadata_dict[key] = common_meta
                else:
                    sub_attn_metadata_dict[key] = layer_meta
            sub_attn_metadata_list.append(sub_attn_metadata_dict)
        return sub_attn_metadata_list
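

# Illustrative sketch only (not part of this commit): `get_data_num_and_offset` is used
# above to split the prefill tokens (or requests) into parts. One plausible contract,
# assumed here purely for illustration, is an even split with the remainder absorbed by
# the trailing parts, returning per-part sizes and their start offsets.
def _example_data_num_and_offset(total: int, parts: int):
    base, rem = divmod(total, parts)
    sizes = [base + (1 if i >= parts - rem else 0) for i in range(parts)]
    offsets, acc = [], 0
    for size in sizes:
        offsets.append(acc)
        acc += size
    return sizes, offsets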


def execute_with_updated_forward_context(
    vllm_config: VllmConfig,
    attn_metadata: AttentionMetadata,
    func: Callable,
    kwargs: Dict[str, Any],
):
    with set_forward_context(attn_metadata, vllm_config):
        return func(**kwargs)
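

# Illustrative sketch only (not part of this commit): how the helper above could consume
# the per-part metadata produced by the splitting code. `split_attn_metadata`, `model`,
# and the kwargs layout are assumptions used purely for illustration.
def _example_run_split_parts(vllm_config, attn_metadata, model, input_chunks):
    sub_attn_metadata_list = split_attn_metadata(attn_metadata, len(input_chunks))
    outputs = []
    for sub_meta, chunk in zip(sub_attn_metadata_list, input_chunks):
        outputs.append(
            execute_with_updated_forward_context(
                vllm_config, sub_meta, model.forward, {"input_ids": chunk}))
    return outputs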
81
vllm_mlu/model_executor/models/registry.py
Normal file
@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from typing import Type, Union

import torch.nn as nn

from vllm.logger import init_logger
from vllm.model_executor.models.registry import (
    _LazyRegisteredModel, _RegisteredModel, _ModelRegistry)

from vllm_mlu.mlu_hijack_utils import MluHijackObject

logger = init_logger(__name__)


def vllm__model_executor__models__registry___ModelRegistry__register_model(
    self,
    model_arch: str,
    model_cls: Union[type[nn.Module], str],
) -> None:
    """
    Register an external model to be used in vLLM.

    `model_cls` can be either:

    - A [`torch.nn.Module`][] class directly referencing the model.
    - A string in the format `<module>:<class>` which can be used to
      lazily import the model. This is useful to avoid initializing CUDA
      when importing the model and thus the related error
      `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
    """
    if not isinstance(model_arch, str):
        msg = f"`model_arch` should be a string, not a {type(model_arch)}"
        raise TypeError(msg)

    '''
    =============================
        Modify by vllm_mlu
    =============================
    @brief: change mlu models register log level
    '''
    if model_arch in self.models:
        if isinstance(model_cls, str) and "MLU" in model_cls:
            logger.debug(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)
        else:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)
    '''
    ==================
     End of MLU Hijack
    ==================
    '''

    if isinstance(model_cls, str):
        split_str = model_cls.split(":")
        if len(split_str) != 2:
            msg = "Expected a string in the format `<module>:<class>`"
            raise ValueError(msg)

        model = _LazyRegisteredModel(*split_str)
    elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
        model = _RegisteredModel.from_model_cls(model_cls)
    else:
        msg = ("`model_cls` should be a string or PyTorch model class, "
               f"not a {type(model_cls)}")
        raise TypeError(msg)

    self.models[model_arch] = model


MluHijackObject.apply_hijack(
    _ModelRegistry,
    _ModelRegistry.register_model,
    vllm__model_executor__models__registry___ModelRegistry__register_model
)
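

# Illustrative sketch only (not part of this commit): `MluHijackObject.apply_hijack` is a
# vllm_mlu helper whose real implementation lives elsewhere; it is not shown here. A
# minimal stand-in with the same call shape (target, original callable, replacement)
# might look like the following.
class _ExampleHijack:

    @staticmethod
    def apply_hijack(target, original, replacement):
        # Rebind the attribute that carried the original callable to the replacement.
        setattr(target, original.__name__, replacement)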
67
vllm_mlu/model_executor/models/utils.py
Normal file
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import json
import os

import torch

import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.forward_context import get_forward_context


def set_attn_compute_dtype_v1(attn_metadata, dtype: torch.dtype):
    '''
    Set attn compute_dtype for v1.
    '''
    if isinstance(attn_metadata, dict):
        for _, metadata in attn_metadata.items():
            metadata.compute_dtype = dtype
    else:
        attn_metadata.compute_dtype = dtype


def set_attn_compute_dtype(dtype: torch.dtype):
    '''
    Set attn compute_dtype.

    TODO: FA may standardize on half-precision computation in the future;
    set_attn_compute_dtype might then be deprecated and removed.
    '''
    forward_context = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    set_attn_compute_dtype_v1(attn_metadata, dtype)
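

def _example_half_precision_attention():
    # Illustrative sketch only (not part of this commit): a model forward pass could call
    # the helper above just before its attention layers so every layer's metadata carries
    # the same compute dtype. Assumes a forward context has already been set by the runner.
    set_attn_compute_dtype(torch.float16)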


def is_tie_word_embeddings(
    model_config: ModelConfig,
    org_tie_word_embeddings: bool
) -> bool:
    '''
    The vLLM language-model config of a multimodal model may carry a wrong
    tie_word_embeddings value, for example InternVL3.5-38B, InternVL3.5-30B-A3B, etc.

    This function is a workaround.
    '''
    from vllm.lora.utils import get_adapter_absolute_path

    if not model_config.is_multimodal_model:
        return org_tie_word_embeddings

    model_path = get_adapter_absolute_path(model_config.model)
    config_path = os.path.join(model_path, "config.json")
    if not os.path.exists(config_path):
        return org_tie_word_embeddings

    tie_word_embeddings = org_tie_word_embeddings
    with open(config_path) as f:
        config = json.load(f)
        # First, check whether tie_word_embeddings is set in the overall config.
        if config.get("tie_word_embeddings") is not None:
            tie_word_embeddings = config["tie_word_embeddings"]
        # Then, check whether tie_word_embeddings is set in the language model config.
        if (config.get("llm_config") is not None
                and config["llm_config"].get("tie_word_embeddings") is not None):
            tie_word_embeddings = config["llm_config"]["tie_word_embeddings"]

    return tie_word_embeddings
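

def _example_tie_word_embeddings(model_config: ModelConfig) -> bool:
    # Illustrative sketch only (not part of this commit): a model loader could consult the
    # workaround above before deciding whether to share input embedding and lm_head weights.
    # Reading the default from `model_config.hf_config` is an assumption for illustration.
    hf_default = getattr(model_config.hf_config, "tie_word_embeddings", False)
    return is_tie_word_embeddings(model_config, hf_default)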