v1.0
This commit is contained in:
225
model_executor/layers/mamba/mamba_utils.py
Normal file
225
model_executor/layers/mamba/mamba_utils.py
Normal file
@@ -0,0 +1,225 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config.cache import MambaDType
|
||||
from vllm.config.model import ModelDType
|
||||
from vllm.distributed import divide
|
||||
from vllm.utils.torch_utils import (
|
||||
STR_DTYPE_TO_TORCH_DTYPE,
|
||||
get_kv_cache_torch_dtype,
|
||||
)
|
||||
|
||||
|
||||
class MambaStateDtypeCalculator:
|
||||
@classmethod
|
||||
def linear_attention_state_dtype(
|
||||
cls,
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
# TODO (tdoublep) requires testing
|
||||
if mamba_cache_dtype == "float32":
|
||||
raise ValueError("fp32 state for minimax is not yet supported")
|
||||
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
return (state_dtype,)
|
||||
|
||||
@classmethod
|
||||
def mamba1_state_dtype(
|
||||
cls,
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
return cls._mamba_state_dtype(
|
||||
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def mamba2_state_dtype(
|
||||
cls,
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
return cls._mamba_state_dtype(
|
||||
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _mamba_state_dtype(
|
||||
cls,
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
if mamba_ssm_cache_dtype == "auto":
|
||||
temporal_state_dtype = conv_state_dtype
|
||||
else:
|
||||
temporal_state_dtype = STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype]
|
||||
|
||||
return (conv_state_dtype, temporal_state_dtype)
|
||||
|
||||
@classmethod
|
||||
def short_conv_state_dtype(
|
||||
cls,
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
return (conv_state_dtype,)
|
||||
|
||||
@classmethod
|
||||
def gated_delta_net_state_dtype(
|
||||
cls,
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, torch.dtype]:
|
||||
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
return (state_dtype, state_dtype)
|
||||
|
||||
@classmethod
|
||||
def kda_state_dtype(
|
||||
cls,
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
):
|
||||
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
return (state_dtype, state_dtype, state_dtype, torch.float32)
|
||||
|
||||
|
||||
class MambaStateShapeCalculator:
|
||||
@classmethod
|
||||
def linear_attention_state_shape(
|
||||
cls,
|
||||
num_heads: int,
|
||||
tp_size: int,
|
||||
head_dim: int,
|
||||
) -> tuple[tuple[int, int, int], ...]:
|
||||
state_shape = (num_heads // tp_size, head_dim, head_dim)
|
||||
return (state_shape,)
|
||||
|
||||
@classmethod
|
||||
def mamba1_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
intermediate_size: int,
|
||||
state_size: int,
|
||||
conv_kernel: int,
|
||||
) -> tuple[tuple[int, int], tuple[int, int]]:
|
||||
conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1)
|
||||
|
||||
temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size)
|
||||
|
||||
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
|
||||
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
@classmethod
|
||||
def mamba2_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
intermediate_size: int,
|
||||
n_groups: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
state_size: int,
|
||||
conv_kernel: int,
|
||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||
# if n_groups is not divisible by world_size, need to extend the shards
|
||||
# to ensure all groups needed by a head is sharded along with it
|
||||
n_groups = n_groups + cls.extra_groups_for_head_shards(n_groups, tp_world_size)
|
||||
# heads and n_groups are TP-ed
|
||||
conv_dim = intermediate_size + 2 * n_groups * state_size
|
||||
|
||||
# contiguous along 'dim' axis
|
||||
conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
|
||||
|
||||
# These are not TP-ed as they depend on A, dt_bias, D
|
||||
# - they are typically small
|
||||
# e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
|
||||
temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
@classmethod
|
||||
def short_conv_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
intermediate_size: int,
|
||||
conv_kernel: int,
|
||||
) -> tuple[tuple[int, int]]:
|
||||
conv_dim = divide(intermediate_size, tp_world_size)
|
||||
conv_state_shape = (conv_kernel - 1, conv_dim)
|
||||
return (conv_state_shape,)
|
||||
|
||||
@classmethod
|
||||
def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
|
||||
"""Compute the increase in group numbers to account for
|
||||
replication in order to accompany the head shards."""
|
||||
|
||||
# in the case ngoups % tp_size == 0, this will be zero
|
||||
if ngroups % tp_size == 0:
|
||||
return 0
|
||||
|
||||
# for n_groups == 1, this is exactly tp_size - n_groups
|
||||
return tp_size - ngroups
|
||||
|
||||
@classmethod
|
||||
def gated_delta_net_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
num_k_heads: int,
|
||||
num_v_heads: int,
|
||||
head_k_dim: int,
|
||||
head_v_dim: int,
|
||||
conv_kernel_size: int,
|
||||
num_spec: int = 0,
|
||||
):
|
||||
conv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads
|
||||
conv_state_shape = (
|
||||
divide(conv_dim, tp_world_size),
|
||||
conv_kernel_size - 1 + num_spec,
|
||||
)
|
||||
|
||||
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
|
||||
|
||||
temporal_state_shape = (
|
||||
divide(num_v_heads, tp_world_size),
|
||||
head_k_dim,
|
||||
head_v_dim,
|
||||
)
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
@classmethod
|
||||
def kda_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
num_k_heads: int | None = None,
|
||||
head_k_dim: int | None = None,
|
||||
conv_kernel_size: int = 4,
|
||||
num_spec: int = 0,
|
||||
) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int, int]]:
|
||||
if num_k_heads is None:
|
||||
num_k_heads = num_heads
|
||||
if head_k_dim is None:
|
||||
head_k_dim = head_dim
|
||||
|
||||
proj_size = num_heads * head_dim
|
||||
proj_k_size = num_k_heads * head_k_dim
|
||||
|
||||
conv_state_shape = (divide(proj_size, tp_world_size), conv_kernel_size - 1)
|
||||
conv_state_k_shape = (divide(proj_k_size, tp_world_size), conv_kernel_size - 1)
|
||||
recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)
|
||||
|
||||
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
|
||||
conv_state_k_shape = conv_state_k_shape[1], conv_state_k_shape[0]
|
||||
return (
|
||||
conv_state_shape,
|
||||
conv_state_k_shape,
|
||||
conv_state_k_shape,
|
||||
recurrent_state_shape,
|
||||
)
|
||||
Reference in New Issue
Block a user