model: Support Hybrid Mamba2 NemotronHForCausalLM (nvidia/NVIDIA-Nemotron-Nano-9B-v2) (#10909)
Signed-off-by: Netanel Haber <nhaber@nvidia.com>
375  test/srt/layers/attention/mamba/test_causal_conv1d.py  Normal file
@@ -0,0 +1,375 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/mamba/test_causal_conv1d.py

from typing import Optional

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange

from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
    PAD_SLOT_ID,
    causal_conv1d_fn,
    causal_conv1d_update,
)

def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)

    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    if initial_states is None:
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    else:
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = out[..., :seqlen]
    if return_final_states:
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
            dtype_in
        )  # (batch, dim, width - 1)
        if final_states_out is not None:
            final_states_out.copy_(final_states)
        else:
            final_states_out = final_states
    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
    return (out, None) if not return_final_states else (out, final_states_out)
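For intuition, the reference above is equivalent to sliding a width-w window over a left-padded input. A minimal sketch that checks this on tiny sizes (illustration only, reusing causal_conv1d_ref from above):

    import torch

    x = torch.randn(1, 2, 5)                 # (batch, dim, seqlen)
    w = torch.randn(2, 3)                    # (dim, width)
    out, _ = causal_conv1d_ref(x, w, activation=None)
    xp = torch.nn.functional.pad(x, (2, 0))  # left-pad by width - 1
    manual = torch.stack(
        [(xp[:, :, t : t + 3] * w).sum(-1) for t in range(5)], dim=-1
    )
    assert torch.allclose(out, manual, atol=1e-5)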

def causal_conv1d_update_ref(
    x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
):
    """
    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the
        conv_state starting at the index
        @cache_seqlens % state_len before performing the convolution.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)
    if cache_seqlens is None:
        x_new = torch.cat([conv_state, x], dim=-1).to(
            weight.dtype
        )  # (batch, dim, state_len + seqlen)
        conv_state.copy_(x_new[:, :, -state_len:])
    else:
        width_idx = torch.arange(
            -(width - 1), 0, dtype=torch.long, device=x.device
        ).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        width_idx = (
            torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        )
        x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
        copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(
            0
        ) + cache_seqlens.unsqueeze(1)
        copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, copy_idx, x)
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[
        :, :, -seqlen:
    ]
    if unsqueeze:
        out = out.squeeze(-1)
    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
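The cache_seqlens branch above treats conv_state as a circular buffer. A minimal sketch of the wraparound index arithmetic (illustration only, tiny sizes):

    import torch

    state_len, width, seqlen = 4, 3, 2
    cache_seqlens = torch.tensor([3])  # next write position, per batch entry
    # window of the width - 1 most recent entries, modulo the buffer length
    width_idx = torch.arange(-(width - 1), 0).unsqueeze(0) + cache_seqlens.unsqueeze(1)
    print(torch.remainder(width_idx, state_len))  # tensor([[1, 2]])
    # positions that the seqlen new tokens will be written to
    copy_idx = torch.arange(seqlen).unsqueeze(0) + cache_seqlens.unsqueeze(1)
    print(torch.remainder(copy_idx, state_len))   # tensor([[3, 0]])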

def causal_conv1d_opcheck_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    cu_seq_len: Optional[torch.Tensor] = None,
    cache_indices: Optional[torch.Tensor] = None,
    has_initial_state: Optional[torch.Tensor] = None,
    conv_states: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype):
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    torch.manual_seed(0)
    batch = 2
    x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
    x_ref = x.clone()
    conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state.detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
    out_ref = causal_conv1d_update_ref(
        x_ref, conv_state_ref, weight, bias, activation=activation
    )

    assert torch.equal(conv_state, conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 3])
@pytest.mark.parametrize("width", [3, 4])
@pytest.mark.parametrize("dim", [2048 + 16, 4096])
# tests correctness in case a subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch_size", [3])
def test_causal_conv1d_update_with_batch_gather(
    batch_size, with_padding, dim, width, seqlen, has_bias, silu_activation, itype
):
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2

    # set seed
    torch.manual_seed(0)

    padding = 5 if with_padding else 0
    padded_batch_size = batch_size + padding
    # total_entries = number of cache lines
    total_entries = 10 * batch_size

    # x will be (batch, dim, seqlen), contiguous along the dim axis
    x = torch.randn(
        padded_batch_size, seqlen, dim, device=device, dtype=itype
    ).transpose(1, 2)

    x_ref = x.clone()

    conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device
    )
    unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
    unused_states_bool[conv_state_indices] = False
    padded_state_indices = torch.concat(
        [
            conv_state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=0,
    )

    # conv_state will be (cache_lines, dim, state_len),
    # contiguous along the dim axis
    conv_state = torch.randn(
        total_entries, width - 1, dim, device=device, dtype=itype
    ).transpose(1, 2)

    conv_state_for_padding_test = conv_state.clone()

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
    activation = None if not silu_activation else "silu"

    out = causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias,
        activation=activation,
        conv_state_indices=padded_state_indices,
        pad_slot_id=PAD_SLOT_ID,
    )
    out_ref = causal_conv1d_update_ref(
        x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
    )

    assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
    assert torch.equal(
        conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool]
    )
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)

@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("seqlen", [8, 30, 249, 2049, 4096])
@pytest.mark.parametrize("dim", [64, 4096])
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch", [4, 10])
def test_causal_conv1d_varlen(
    batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype
):
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    device = "cuda"
    torch.cuda.empty_cache()
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    torch.manual_seed(0)
    seqlens = []
    batch_size = batch
    padding = 3 if with_padding else 0
    padded_batch_size = batch_size + padding
    nsplits = padded_batch_size - 1

    eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values

    seqlens.append(
        torch.diff(
            torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])
        ).tolist()
    )
    assert sum(seqlens[-1]) == seqlen
    assert all(s > 0 for s in seqlens[-1])

    total_entries = batch_size * 10
    cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0)
    x = rearrange(
        torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype),
        "b s d -> b d s",
    )[:, 4096 : 4096 + dim, :]

    weight = torch.randn(dim, width, device=device, dtype=itype)

    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    activation = None if not silu_activation else "silu"
    final_states = torch.randn(
        total_entries, width - 1, dim, device=x.device, dtype=x.dtype
    ).transpose(1, 2)
    final_states_ref = final_states.clone()
    has_initial_states = torch.randint(
        0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
    )
    state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[
        :batch_size
    ]
    padded_state_indices = torch.concat(
        [
            state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=-1,
    )
    out = causal_conv1d_fn(
        x.squeeze(0),
        weight,
        bias=bias,
        conv_states=final_states,
        query_start_loc=cumsum.cuda(),
        seq_lens_cpu=torch.tensor(seqlens[0]),
        cache_indices=padded_state_indices,
        has_initial_state=has_initial_states,
        activation=activation,
        pad_slot_id=PAD_SLOT_ID,
    )

    out_ref = []
    out_ref_b = []

    # iterating over x_ref (batch dim of size 1) yields the (dim, seqlen) view
    splits = [torch.split(var, seqlens[0], dim=-1) for var in x_ref]
    for i in range(len(seqlens[0])):
        x_s = [v[i].unsqueeze(0) for v in splits][0]
        if padded_state_indices[i] == PAD_SLOT_ID:
            continue
        out_ref_b.append(
            causal_conv1d_ref(
                x_s,
                weight_ref,
                bias_ref,
                activation=activation,
                return_final_states=True,
                final_states_out=final_states_ref[padded_state_indices[i]].unsqueeze(0),
                initial_states=(
                    final_states_ref[padded_state_indices[i]].unsqueeze(0)
                    if has_initial_states[i]
                    else None
                ),
            )
        )
    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
    out_ref_tensor = torch.cat(out_ref, dim=0)

    assert torch.allclose(
        final_states[state_indices],
        final_states_ref[state_indices],
        rtol=rtol,
        atol=atol,
    )
    unpadded_out = out[:, : out_ref_tensor.shape[-1]]
    assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
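All of these tests skip when no CUDA device is present. A sketch of a targeted invocation from a repository checkout (path as in the file header above):

    import pytest

    # run only the varlen cases of this file, verbosely
    raise SystemExit(
        pytest.main(
            ["-v", "-k", "varlen", "test/srt/layers/attention/mamba/test_causal_conv1d.py"]
        )
    )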
138  test/srt/layers/attention/mamba/test_mamba2_mixer.py  Normal file
@@ -0,0 +1,138 @@
# Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/tests/kernels/mamba/test_mamba_mixer2.py

from unittest.mock import patch

import pytest
import torch

from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
    update_environment_variables,
)
from sglang.srt.distributed.parallel_state import (
    init_distributed_environment,
    initialize_model_parallel,
)

NUM_GPUS = 2

@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize(
    "hidden_size_n_groups",
    [
        (64, 1),  # hidden_size must be divisible by num_gpus
        (100, 4),  # and n_groups must divide hidden_size
    ],
)
@pytest.mark.parametrize("dtype", [torch.float16])
def test_mixer2_gated_norm_multi_gpu(
    batch_size: int,
    seq_len: int,
    hidden_size_n_groups: tuple[int, int],
    dtype: torch.dtype,
    device: str = "cuda",
):
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    assert torch.cuda.device_count() == NUM_GPUS

    hidden_size, n_groups = hidden_size_n_groups
    num_processes = NUM_GPUS

    def run_torch_spawn(fn, nprocs):
        # need to use torch.multiprocessing.spawn; otherwise there are
        # problems with torch.distributed and CUDA
        torch.multiprocessing.spawn(
            fn,
            args=(
                num_processes,
                batch_size,
                seq_len,
                hidden_size,
                n_groups,
                dtype,
                device,
            ),
            nprocs=nprocs,
        )

    run_torch_spawn(mixer2_gated_norm_tensor_parallel, NUM_GPUS)

def mixer2_gated_norm_tensor_parallel(
    local_rank: int,
    world_size: int,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    n_groups: int,
    dtype: torch.dtype,
    device: str,
):
    torch.manual_seed(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    update_environment_variables(
        {
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        }
    )

    # initialize distributed
    init_distributed_environment(
        world_size=world_size, rank=local_rank, local_rank=local_rank
    )
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # create random weights and inputs
    weight = torch.rand((hidden_size,), dtype=dtype, device=device)
    hidden_states = torch.randn(batch_size, seq_len, hidden_size)
    gate_states = torch.randn(batch_size, seq_len, hidden_size)

    import sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated as m2
    import sglang.srt.model_loader.weight_utils as wu

    # Convenience: avoid calling initialize_dp_attention
    with patch.object(wu, "get_attention_tp_rank", return_value=local_rank):
        # create gated-norm with TP
        mixer = m2.Mixer2RMSNormGated(
            full_hidden_size=hidden_size,
            full_n_groups=n_groups,
        )
        mixer.weight.weight_loader(mixer.weight, weight)

    with (
        patch.object(m2, "get_tensor_model_parallel_world_size", return_value=1),
        patch.object(m2, "get_tensor_model_parallel_rank", return_value=0),
    ):
        # create gated-norm without TP to compute the reference
        mixer_single_gpu = m2.Mixer2RMSNormGated(
            full_hidden_size=hidden_size,
            full_n_groups=n_groups,
        )
    # assign the weight to the single-GPU mixer
    mixer_single_gpu.weight.data = weight

    # generate and compare
    N = hidden_size // world_size
    output = mixer(
        hidden_states[..., local_rank * N : (local_rank + 1) * N],
        gate_states[..., local_rank * N : (local_rank + 1) * N],
    )
    ref_output = mixer_single_gpu(hidden_states, gate_states)
    torch.testing.assert_close(
        output,
        ref_output[..., local_rank * N : (local_rank + 1) * N],
        atol=5e-3,
        rtol=1e-3,
    )
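For reference: in the single-GPU, n_groups=1 case the gated norm under test corresponds to the standard Mamba2 formulation y = RMSNorm(x * silu(gate)) * weight. A minimal sketch under that assumption (gated_rms_norm_ref is illustrative, not the module's API):

    import torch
    import torch.nn.functional as F

    def gated_rms_norm_ref(x, gate, weight, eps=1e-5):
        x = x * F.silu(gate)                       # gate first
        var = x.pow(2).mean(dim=-1, keepdim=True)  # RMS over the hidden dim
        return x * torch.rsqrt(var + eps) * weight

    x = torch.randn(2, 4, 8)
    g = torch.randn(2, 4, 8)
    w = torch.ones(8)
    print(gated_rms_norm_ref(x, g, w).shape)       # torch.Size([2, 4, 8])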
291  test/srt/layers/attention/mamba/test_mamba_ssm.py  Normal file
@@ -0,0 +1,291 @@
# Adapted from https://github.com/vllm-project/vllm/blob/633f943e30a4444d890d26b81850f7217736f840/tests/kernels/mamba/test_mamba_ssm_ssd.py

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from sglang.srt.layers.attention.mamba.causal_conv1d_triton import PAD_SLOT_ID
from sglang.srt.layers.attention.mamba.ops import selective_state_update

def selective_state_update_ref(
    state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
):
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
        dt = dt + dt_bias
    dt = F.softplus(dt) if dt_softplus else dt
    dA = torch.exp(
        rearrange(dt, "b h d -> b h d 1") * A
    )  # (batch, nheads, dim, dstate)
    B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
    C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
    dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
        B, "b h n -> b h 1 n"
    )  # (batch, nheads, dim, dstate)
    state.copy_(
        state * dA + dB * rearrange(x, "b h d -> b h d 1")
    )  # (batch, nheads, dim, dstate)
    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
    if D is not None:
        out += (x * D).to(out.dtype)
    out = (out if z is None else out * F.silu(z)).to(x.dtype)
    if not has_heads:
        out = out.squeeze(1)
    return out
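Stripped of the head/group bookkeeping, the reference above is the discretized selective-SSM step; for a single (head, channel) slice it reduces to the scalar recurrence below (illustration only):

    import torch
    import torch.nn.functional as F

    dstate = 4
    state = torch.zeros(dstate)
    A = -torch.rand(dstate) - 1.0
    x = torch.randn(())
    dt = F.softplus(torch.randn(()))
    B, C = torch.randn(dstate), torch.randn(dstate)

    state = state * torch.exp(dt * A) + (dt * B) * x  # state * dA + dB * x
    out = (C * state).sum()                           # y = C · state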

@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_selective_state_update(dim, dstate, has_z, itype):
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    device = "cuda"

    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    if torch.version.hip:
        atol *= 2
    # set seed
    torch.manual_seed(0)
    batch_size = 1
    state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
    x = torch.randn(batch_size, dim, device=device, dtype=itype)
    out = torch.empty_like(x)
    dt = torch.randn(batch_size, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(batch_size, dstate, device=device)
    C = torch.randn(batch_size, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    state_ref = state.detach().clone()
    selective_state_update(
        state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True, out=out
    )
    out_ref = selective_state_update_ref(
        state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True
    )

    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in case a subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_selective_state_update_with_batch_indices(
    with_padding, dim, dstate, has_z, itype
):
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 1e-1, 1e-1
    if torch.version.hip:
        atol *= 2
    # set seed
    torch.random.manual_seed(0)
    batch_size = 3
    padding = 5 if with_padding else 0
    padded_batch_size = batch_size + padding
    total_entries = 10 * batch_size
    state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
    state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device
    )
    unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
    unused_states_bool[state_indices] = False
    padded_state_indices = torch.concat(
        [
            state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=0,
    )
    x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
    out = torch.empty_like(x)
    dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(padded_batch_size, dstate, device=device)
    C = torch.randn(padded_batch_size, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    state_ref = state[state_indices, :].clone()
    state_before = state.clone()
    # snapshot the inputs so the padded entries can be checked afterwards
    x_before, dt_before = x.clone(), dt.clone()
    B_before, C_before = B.clone(), C.clone()
    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        D=D,
        z=z,
        dt_bias=dt_bias,
        dt_softplus=True,
        state_batch_indices=padded_state_indices,
        pad_slot_id=PAD_SLOT_ID,
        out=out,
    )
    out_ref = selective_state_update_ref(
        state_ref,
        x[:batch_size],
        dt[:batch_size],
        A,
        B[:batch_size],
        C[:batch_size],
        D=D,
        z=z[:batch_size],
        dt_bias=dt_bias,
        dt_softplus=True,
    )

    print("Output diff max", (out[:batch_size] - out_ref).max())
    print("Output diff mean", (out[:batch_size] - out_ref).mean())
    print("Output state diff max", (state[state_indices, :] - state_ref).max())
    print("Output state diff mean", (state[state_indices, :] - state_ref).mean())
    # test that padded entries stay the same
    if with_padding:
        assert torch.equal(state_before[unused_states_bool], state[unused_states_bool])
        assert torch.equal(x[batch_size:], x_before[batch_size:])
        assert torch.equal(dt[batch_size:], dt_before[batch_size:])
        assert torch.equal(B[batch_size:], B_before[batch_size:])
        assert torch.equal(C[batch_size:], C_before[batch_size:])

    # test "real" entries
    assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)

@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("tie_hdim", [False, True])
@pytest.mark.parametrize("ngroups", [1, 2, 4])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
def test_selective_state_update_with_heads_with_batch_indices(
    dim, dstate, ngroups, has_z, tie_hdim, itype
):
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
    if itype == torch.bfloat16:
        rtol, atol = 1e-1, 1e-1
    # set seed
    torch.random.manual_seed(0)
    batch_size = 3
    headdim = 64
    nheads = dim // headdim

    total_entries = 10 * batch_size
    state = torch.randn(
        total_entries, nheads, headdim, dstate, dtype=itype, device=device
    )
    state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device
    )

    x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
    out = torch.empty_like(x)
    if not tie_hdim:
        dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
        dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
        A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
        D = torch.randn(nheads, headdim, device=device)
    else:
        dt = repeat(
            torch.randn(batch_size, nheads, device=device, dtype=itype),
            "b h -> b h p",
            p=headdim,
        )
        dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim)
        A = repeat(
            -torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate
        )
        D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
    B = torch.randn(batch_size, ngroups, dstate, device=device)
    C = torch.randn(batch_size, ngroups, dstate, device=device)
    z = torch.randn_like(x) if has_z else None
    state_ref = state[state_indices, :].detach().clone()
    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        D=D,
        z=z,
        dt_bias=dt_bias,
        dt_softplus=True,
        state_batch_indices=state_indices,
        pad_slot_id=PAD_SLOT_ID,
        out=out,
    )
    out_ref = selective_state_update_ref(
        state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
581  test/srt/layers/attention/mamba/test_mamba_ssm_ssd.py  Normal file
@@ -0,0 +1,581 @@
# Adapted from https://github.com/vllm-project/vllm/blob/633f943e30a4444d890d26b81850f7217736f840/tests/kernels/mamba/test_mamba_ssm_ssd.py

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
from sglang.srt.layers.attention.mamba.ops import mamba_chunk_scan_combined

# Added by the IBM Team, 2024

# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py

# TODO: These take a long time to run - we should cut down the parameterization matrix.
# this is the segsum implementation taken from above
|
||||
def segsum(x):
|
||||
"""Calculates segment sum."""
|
||||
T = x.size(-1)
|
||||
x = repeat(x, "... d -> ... d e", e=T)
|
||||
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
|
||||
x = x.masked_fill(~mask, 0)
|
||||
x_segsum = torch.cumsum(x, dim=-2)
|
||||
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
|
||||
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
|
||||
return x_segsum
|
||||
|
||||
|
||||
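On a small input, segsum(x)[i, j] is the sum x[j+1] + ... + x[i] for i >= j (zero on the diagonal) and -inf above it; for example:

    import torch

    x = torch.tensor([1.0, 2.0, 3.0])
    print(segsum(x))
    # tensor([[0., -inf, -inf],
    #         [2., 0., -inf],
    #         [5., 3., 0.]])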

def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
    """
    Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads)
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
    Return:
        Y: (batch, length, n_heads, d_head)
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    assert X.shape[1] % block_len == 0

    # Rearrange into blocks/chunks
    X, A, B, C = (
        rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)
    )

    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Compute the output for each intra-chunk (diagonal blocks)
    L = torch.exp(segsum(A))
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at
    # chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)

    # Add output of intra-chunk and inter-chunk terms
    # (diagonal and off-diagonal blocks)
    Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
    return Y, final_state
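A quick way to see what ssd_minimal_discrete computes: it matches the plain sequential recurrence h_t = exp(A_t) * h_{t-1} + B_t * x_t with y_t = C_t · h_t. A minimal check on tiny sizes (illustration only):

    import torch

    torch.manual_seed(0)
    b, t, h, p, n = 1, 8, 1, 2, 3  # batch, length, heads, d_head, d_state
    X = torch.randn(b, t, h, p)
    A = -torch.rand(b, t, h)       # assumed to already include the dt factor
    B = torch.randn(b, t, h, n)
    C = torch.randn(b, t, h, n)

    Y, final = ssd_minimal_discrete(X, A, B, C, block_len=4)

    state = torch.zeros(b, h, p, n)
    for i in range(t):
        state = torch.exp(A[:, i])[..., None, None] * state + (
            B[:, i, :, None, :] * X[:, i, :, :, None]
        )
        y = (C[:, i, :, None, :] * state).sum(-1)
        assert torch.allclose(Y[:, i], y, atol=1e-4)
    assert torch.allclose(final, state, atol=1e-4)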

def generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype, device="cuda"):

    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    torch.manual_seed(0)
    A = -torch.exp(torch.rand(n_heads, dtype=itype, device=device))
    dt = F.softplus(
        torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - 4
    )
    X = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
    B = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
    C = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)

    return A, dt, X, B, C

def generate_continuous_batched_examples(
    example_lens_by_batch,
    num_examples,
    full_length,
    last_taken,
    exhausted,
    n_heads,
    d_head,
    itype,
    device="cuda",
    return_naive_ref=True,
):

    # This function generates random examples of a certain length,
    # cuts them up according to "example_lens_by_batch", and feeds
    # them in continuous batches to the kernels.
    # If return_naive_ref=True, the naive torch implementation
    # ssd_minimal_discrete will be used to compute and return the
    # reference output.

    # generate the full-length example
    A, dt, X, B, C = generate_random_inputs(
        num_examples, full_length, n_heads, d_head, itype
    )

    if return_naive_ref:
        Y_min, final_state_min = ssd_minimal_discrete(
            X * dt.unsqueeze(-1), A * dt, B, C, block_len=full_length // 4
        )

    # internal function that outputs a cont batch of examples
    # given a tuple of lengths for each example in the batch,
    # e.g., example_lens=(8, 4) means take 8 samples from the first
    # example, 4 samples from the second, etc.
    def get_continuous_batch(example_lens: tuple[int, ...]):

        indices = []
        for i, x in enumerate(example_lens):
            c = last_taken.get(i, 0)
            indices.append((c, c + x))
            last_taken[i] = (c + x) % full_length
            exhausted[i] = last_taken[i] == 0

        return (
            torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)]).unsqueeze(0)
            for x in (dt, X, B, C)
        )

    # internal function that maps "n" to the appropriate right boundary
    # value when forming continuous batches from examples of length given
    # by "full_length"
    # - e.g., when n > full_length, returns n % full_length
    #         when n == full_length, returns full_length
    def end_boundary(n: int):
        return n - ((n - 1) // full_length) * full_length

    IND_E = None
    for spec in example_lens_by_batch:

        # get the (maybe partial) example seen in this cont batch
        dt2, X2, B2, C2 = get_continuous_batch(spec)

        # get the metadata
        cu_seqlens = torch.tensor((0,) + spec, device=device).cumsum(dim=0)
        seq_idx = torch.zeros(
            cu_seqlens[-1], dtype=torch.int32, device=cu_seqlens.device
        )
        for i, (srt, end) in enumerate(
            zip(
                cu_seqlens,
                cu_seqlens[1:],
            )
        ):
            seq_idx[srt:end] = i

        # for cont batch
        if IND_E is None:
            IND_S = [0 for _ in range(len(spec))]
        else:
            IND_S = [x % full_length for x in IND_E]
        IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]

        yield (
            (
                [Y_min[s, IND_S[s] : IND_E[s]] for s in range(num_examples)]
                if return_naive_ref
                else None
            ),
            cu_seqlens,
            seq_idx.unsqueeze(0),
            (A, dt2, X2, B2, C2),
        )

@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype):
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    # this tests the kernels on a single example (no batching)

    # TODO: the bfloat16 case requires higher thresholds. To be investigated

    if itype == torch.bfloat16:
        atol, rtol = 5e-2, 5e-2
    else:
        atol, rtol = 8e-3, 5e-3

    batch_size = 1
    # ssd_minimal_discrete requires chunk_size to divide seqlen
    # - this is only required for generating the reference sequences,
    #   it is not an operational limitation.
    seqlen, chunk_size = seq_len_chunk_size

    A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype)

    Y_min, final_state_min = ssd_minimal_discrete(
        X * dt.unsqueeze(-1), A * dt, B, C, chunk_size
    )
    Y = torch.empty_like(X)
    final_state = mamba_chunk_scan_combined(
        X, dt, A, B, C, chunk_size, D=None, return_final_states=True, out=Y
    )

    # just test the last position in the sequence
    torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)

    # just test the last head
    # NOTE: in the kernel we always cast states to fp32
    torch.testing.assert_close(
        final_state[:, -1],
        final_state_min[:, -1].to(torch.float32),
        atol=atol,
        rtol=rtol,
    )

@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@pytest.mark.parametrize("n_heads", [4, 8, 13])
@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
@pytest.mark.parametrize(
    "seq_len_chunk_size_cases",
    [
        # small-ish chunk_size (8)
        (64, 8, 2, [(64, 32), (64, 32)]),
        (64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
        (64, 8, 2, [(8, 8), (8, 8), (8, 8)]),  # chunk size boundary
        (
            64,
            8,
            2,
            [(4, 4), (4, 4), (4, 4), (4, 4)],
        ),  # chunk_size larger than cont batches
        (
            64,
            8,
            5,
            [
                (64, 32, 16, 8, 8),
                (8, 16, 32, 16, 8),
                (8, 8, 16, 32, 16),
            ],
        ),  # more examples with varied lengths
        # large-ish chunk_size (256)
        (64, 256, 1, [(5,), (1,), (1,), (1,)]),  # irregular sizes with small sequences
        (
            64,
            256,
            2,
            [(5, 30), (1, 2), (1, 2), (1, 2)],
        ),  # irregular sizes with small sequences
        # we also need to test some large seqlen
        # to catch errors with init states decay
        (768, 128, 2, [(138, 225), (138, 225)]),
    ],
)
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, itype):
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    # this tests multiple examples in a continuous batch
    # (i.e. chunked prefill)

    seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases

    # This test can have larger error for longer sequences
    if seqlen > 256:
        atol, rtol = 1e-2, 5e-3
    else:
        atol, rtol = 5e-3, 5e-3

    # hold state during the cutting process so we know if an
    # example has been exhausted and needs to cycle
    last_taken: dict = {}  # map: eg -> pointer to last taken sample
    exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted

    states = None
    for (
        Y_min,
        cu_seqlens,
        seq_idx,
        (A, dt, X, B, C),
    ) in generate_continuous_batched_examples(
        cases, num_examples, seqlen, last_taken, exhausted, n_heads, d_head, itype
    ):

        chunk_indices, chunk_offsets = (
            Mamba2Metadata._query_start_loc_to_chunk_indices_offsets(
                cu_seqlens, chunk_size, cu_seqlens[-1]
            )
        )

        Y = torch.empty_like(X)
        new_states = mamba_chunk_scan_combined(
            X,
            dt,
            A,
            B,
            C,
            chunk_size,
            D=None,
            cu_seqlens=cu_seqlens,
            seq_idx=seq_idx,
            chunk_indices=chunk_indices,
            chunk_offsets=chunk_offsets,
            return_varlen_states=True,
            initial_states=states,
            out=Y,
        )

        # check each example against the naive reference
        for i in range(num_examples):

            # just test one dim and dstate
            Y_eg = Y[0, cu_seqlens[i] : cu_seqlens[i + 1], 0, 0]
            Y_min_eg = Y_min[i][:, 0, 0]
            torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)

        # update states
        states = new_states
        for i, clear in exhausted.items():
            if clear:
                states[i].fill_(0.0)
                exhausted[i] = False

@pytest.mark.parametrize("chunk_size", [8, 256])
@pytest.mark.parametrize(
    "seqlens",
    [
        (16, 2, 8, 13),
        (270, 88, 212, 203),
        (16, 20),
    ],
)
def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    # This test verifies the correctness of the chunked prefill implementation
    # in the mamba2 ssd kernels, by comparing the concatenation (in the sequence
    # dimension) of chunked results with the full-sequence result.
    # It differs from test_mamba_chunk_scan_cont_batch in two ways:
    # 1. It does not use the naive torch implementation (ssd_minimal_discrete)
    #    to get reference outputs. Instead, it compares chunked kernel outputs
    #    to full-sequence kernel outputs. This is the most straightforward way
    #    to assert chunked prefill correctness.
    # 2. It focuses on cases where sequences change in the middle of mamba
    #    chunks, and not necessarily on chunk boundaries.

    max_seqlen = max(seqlens)
    # This test can have larger error for longer sequences
    if max_seqlen > 256:
        atol, rtol = 1e-2, 5e-3
    else:
        atol, rtol = 5e-3, 5e-3

    num_sequences = len(seqlens)
    n_heads = 16
    d_head = 64
    itype = torch.float32

    # hold state during the cutting process so we know if an
    # example has been exhausted and needs to cycle
    last_taken: dict = {}  # map: eg -> pointer to last taken sample
    exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted
    _, cu_seqlens, seq_idx, (A, dt, X, B, C) = next(
        generate_continuous_batched_examples(
            [seqlens],
            num_sequences,
            max_seqlen,
            last_taken,
            exhausted,
            n_heads,
            d_head,
            itype,
            return_naive_ref=False,
        )
    )
    seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device)
    device = X.device

    ## full seqlen computation
    chunk_indices, chunk_offsets = (
        Mamba2Metadata._query_start_loc_to_chunk_indices_offsets(
            cu_seqlens, chunk_size, cu_seqlens[-1]
        )
    )
    Y_ref = torch.empty_like(X)
    state_ref = mamba_chunk_scan_combined(
        X,
        dt,
        A,
        B,
        C,
        chunk_size,
        D=None,
        cu_seqlens=cu_seqlens,
        seq_idx=seq_idx,
        chunk_indices=chunk_indices,
        chunk_offsets=chunk_offsets,
        return_varlen_states=True,
        initial_states=None,
        out=Y_ref,
    )

    ## chunked seqlen computation
    # first chunk
    chunked_seqlens = seqlens // 2
    chunked_cu_seqlens = torch.cat(
        [torch.tensor([0], device=device), torch.cumsum(chunked_seqlens, dim=0)], dim=0
    )
    chunked_seq_idx = (
        torch.repeat_interleave(
            torch.arange(len(chunked_seqlens), device=device),
            chunked_seqlens,
            output_size=chunked_cu_seqlens[-1],
        )
        .unsqueeze(0)
        .to(torch.int32)
    )
    chunked_input_seq_len = chunked_cu_seqlens[-1]
    X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...]
    dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...]
    B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...]
    C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...]
    for i in range(num_sequences):
        # fmt: off
        chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...]  # noqa: E501

        X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i)  # noqa: E501
        dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i)  # noqa: E501
        B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i)  # noqa: E501
        C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i)  # noqa: E501
        # fmt: on

    chunk_indices, chunk_offsets = (
        Mamba2Metadata._query_start_loc_to_chunk_indices_offsets(
            chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1]
        )
    )
    Y_partial = torch.empty_like(X_chunked)
    partial_state = mamba_chunk_scan_combined(
        X_chunked,
        dt_chunked,
        A,
        B_chunked,
        C_chunked,
        chunk_size,
        D=None,
        cu_seqlens=chunked_cu_seqlens,
        seq_idx=chunked_seq_idx,
        chunk_indices=chunk_indices,
        chunk_offsets=chunk_offsets,
        return_varlen_states=True,
        initial_states=None,
        out=Y_partial,
    )

    # remaining chunk
    remaining_chunked_seqlens = seqlens - chunked_seqlens
    remaining_chunked_cu_seqlens = torch.cat(
        [
            torch.tensor([0], device=device),
            torch.cumsum(remaining_chunked_seqlens, dim=0),
        ],
        dim=0,
    )
    remaining_chunked_seq_idx = (
        torch.repeat_interleave(
            torch.arange(len(remaining_chunked_seqlens), device=device),
            remaining_chunked_seqlens,
            output_size=remaining_chunked_cu_seqlens[-1],
        )
        .unsqueeze(0)
        .to(torch.int32)
    )
    remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
    # fmt: off
    remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
    remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
    remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
    remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
    for i in range(num_sequences):
        remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...]  # noqa: E501

        remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i)  # noqa: E501
        remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i)  # noqa: E501
        remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i)  # noqa: E501
        remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i)  # noqa: E501

    # assert input chunking is correct
    concat_chunk_f = lambda pt1, pt2, i: torch.cat([
        pt1[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...],
        pt2[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...],
    ],
    dim=1)
    concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1)  # noqa: E501
    # fmt: on

    assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
    assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt)
    assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B)
    assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)

    chunk_indices, chunk_offsets = (
        Mamba2Metadata._query_start_loc_to_chunk_indices_offsets(
            remaining_chunked_cu_seqlens, chunk_size, remaining_chunked_cu_seqlens[-1]
        )
    )

    Y_chunked = torch.empty_like(remaining_X_chunked)
    state_chunked = mamba_chunk_scan_combined(
        remaining_X_chunked,
        remaining_dt_chunked,
        A,
        remaining_B_chunked,
        remaining_C_chunked,
        chunk_size,
        D=None,
        cu_seqlens=remaining_chunked_cu_seqlens,
        seq_idx=remaining_chunked_seq_idx,
        chunk_indices=chunk_indices,
        chunk_offsets=chunk_offsets,
        return_varlen_states=True,
        initial_states=partial_state,
        out=Y_chunked,
    )
    Y = concat_batch_f(Y_partial, Y_chunked)

    # the chunked kernel result must match the full-sequence kernel result
    for i in range(num_sequences):
        Y_seq = Y[:, cu_seqlens[i] : cu_seqlens[i + 1], ...]
        Y_ref_seq = Y_ref[:, cu_seqlens[i] : cu_seqlens[i + 1], ...]
        torch.testing.assert_close(
            Y_seq[:, : chunked_seqlens[i], ...],
            Y_ref_seq[:, : chunked_seqlens[i], ...],
            atol=atol,
            rtol=rtol,
            msg=lambda x: f"seq{i} output part1 " + x,
        )  # noqa: B023
        torch.testing.assert_close(
            Y_seq[:, chunked_seqlens[i] :, ...],
            Y_ref_seq[:, chunked_seqlens[i] :, ...],
            atol=atol,
            rtol=rtol,
            msg=lambda x: f"seq{i} output part2 " + x,
        )  # noqa: B023

        state_seq = state_chunked[i]
        state_seq_ref = state_ref[i]
        torch.testing.assert_close(
            state_seq,
            state_seq_ref,
            atol=atol,
            rtol=rtol,
            msg=lambda x: f"seq{i} state " + x,
        )  # noqa: B023
@@ -91,6 +91,11 @@ ALL_MODELS = [
         trust_remote_code=True,
         skip_long_prompt=True,
     ),
+    ModelCase(
+        "nvidia/NVIDIA-Nemotron-Nano-9B-v2",
+        trust_remote_code=True,
+        skip_long_prompt=True,
+    ),
     ModelCase(
         "swiss-ai/Apertus-8B",
         trust_remote_code=True,
44  test/srt/models/test_nvidia_nemotron_nano_v2.py  Normal file
@@ -0,0 +1,44 @@
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestNvidiaNemotronNanoV2(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--max-mamba-cache-size",
                "256",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.87)
@@ -127,6 +127,10 @@ suites = {
         TestFile("test_vlm_input_format.py", 300),
         TestFile("test_vision_openai_server_a.py", 724),
         TestFile("test_vision_openai_server_b.py", 446),
+        TestFile("layers/attention/mamba/test_causal_conv1d.py", 85),
+        TestFile("layers/attention/mamba/test_mamba_ssm.py", 85),
+        TestFile("layers/attention/mamba/test_mamba_ssm_ssd.py", 220),
+        TestFile("models/test_nvidia_nemotron_nano_v2.py", 180),
         TestFile("test_modelopt_loader.py", 30),
     ],
     "per-commit-2-gpu": [
@@ -142,6 +146,7 @@ suites = {
         TestFile("hicache/test_hicache_storage_file_backend.py", 200),
         TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400),
         TestFile("hicache/test_hicache_storage_3fs_backend.py", 200),
+        TestFile("layers/attention/mamba/test_mamba2_mixer.py", 110),
     ],
     "per-commit-4-gpu": [
         TestFile("test_gpt_oss_4gpu.py", 300),