Sync from v0.13
This commit is contained in:
373
tests/kernels/mamba/test_causal_conv1d.py
Normal file
373
tests/kernels/mamba/test_causal_conv1d.py
Normal file
@@ -0,0 +1,373 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn,
|
||||
causal_conv1d_update,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def causal_conv1d_ref(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
initial_states: torch.Tensor | None = None,
|
||||
return_final_states: bool = False,
|
||||
final_states_out: torch.Tensor | None = None,
|
||||
activation: str | None = "silu",
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen)
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
initial_states: (batch, dim, width - 1)
|
||||
final_states_out: (batch, dim, width - 1)
|
||||
|
||||
out: (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
dtype_in = x.dtype
|
||||
x = x.to(weight.dtype)
|
||||
seqlen = x.shape[-1]
|
||||
dim, width = weight.shape
|
||||
if initial_states is None:
|
||||
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
|
||||
else:
|
||||
x = torch.cat([initial_states, x], dim=-1)
|
||||
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
|
||||
out = out[..., :seqlen]
|
||||
if return_final_states:
|
||||
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
|
||||
dtype_in
|
||||
) # (batch, dim, width - 1)
|
||||
if final_states_out is not None:
|
||||
final_states_out.copy_(final_states)
|
||||
else:
|
||||
final_states_out = final_states
|
||||
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
||||
return (out, None) if not return_final_states else (out, final_states_out)
|
||||
|
||||
|
||||
def causal_conv1d_update_ref(
|
||||
x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
|
||||
):
|
||||
"""
|
||||
x: (batch, dim) or (batch, dim, seqlen)
|
||||
conv_state: (batch, dim, state_len), where state_len >= width - 1
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
cache_seqlens: (batch,), dtype int32.
|
||||
If not None, the conv_state is treated as a circular buffer.
|
||||
The conv_state will be updated by copying x to the
|
||||
conv_state starting at the index
|
||||
@cache_seqlens % state_len before performing the convolution.
|
||||
|
||||
out: (batch, dim) or (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
dtype_in = x.dtype
|
||||
unsqueeze = x.dim() == 2
|
||||
if unsqueeze:
|
||||
x = x.unsqueeze(-1)
|
||||
batch, dim, seqlen = x.shape
|
||||
width = weight.shape[1]
|
||||
state_len = conv_state.shape[-1]
|
||||
assert conv_state.shape == (batch, dim, state_len)
|
||||
assert weight.shape == (dim, width)
|
||||
if cache_seqlens is None:
|
||||
x_new = torch.cat([conv_state, x], dim=-1).to(
|
||||
weight.dtype
|
||||
) # (batch, dim, state_len + seqlen)
|
||||
conv_state.copy_(x_new[:, :, -state_len:])
|
||||
else:
|
||||
width_idx = torch.arange(
|
||||
-(width - 1), 0, dtype=torch.long, device=x.device
|
||||
).unsqueeze(0) + cache_seqlens.unsqueeze(1)
|
||||
width_idx = (
|
||||
torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
|
||||
)
|
||||
x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
|
||||
copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(
|
||||
0
|
||||
) + cache_seqlens.unsqueeze(1)
|
||||
copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
|
||||
conv_state.scatter_(2, copy_idx, x)
|
||||
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[
|
||||
:, :, -seqlen:
|
||||
]
|
||||
if unsqueeze:
|
||||
out = out.squeeze(-1)
|
||||
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
|
||||
@pytest.mark.parametrize("silu_activation", [True])
|
||||
@pytest.mark.parametrize("has_bias", [True])
|
||||
def causal_conv1d_opcheck_fn(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
cu_seq_len: torch.Tensor | None = None,
|
||||
cache_indices: torch.Tensor | None = None,
|
||||
has_initial_state: torch.Tensor | None = None,
|
||||
conv_states: torch.Tensor | None = None,
|
||||
activation: str | None = "silu",
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen)
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
seq_idx: (batch, seqlen)
|
||||
initial_states: (batch, dim, width - 1)
|
||||
final_states_out: (batch, dim, width - 1), to be written to
|
||||
activation: either None or "silu" or "swish"
|
||||
|
||||
out: (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
bias = bias.contiguous() if bias is not None else None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||
@pytest.mark.parametrize("has_bias", [False, True])
|
||||
@pytest.mark.parametrize("seqlen", [1])
|
||||
@pytest.mark.parametrize("width", [4])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
batch = 2
|
||||
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
|
||||
x_ref = x.clone()
|
||||
conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
|
||||
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
conv_state_ref = conv_state.detach().clone()
|
||||
activation = None if not silu_activation else "silu"
|
||||
|
||||
conv_state_indices = torch.arange(batch, dtype=torch.int32, device=device)
|
||||
|
||||
out = causal_conv1d_update(
|
||||
x,
|
||||
conv_state,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
conv_state_indices=conv_state_indices,
|
||||
)
|
||||
out_ref = causal_conv1d_update_ref(
|
||||
x_ref, conv_state_ref, weight, bias, activation=activation
|
||||
)
|
||||
|
||||
assert torch.equal(conv_state, conv_state_ref)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||
@pytest.mark.parametrize("has_bias", [False, True])
|
||||
@pytest.mark.parametrize("seqlen", [1, 3])
|
||||
@pytest.mark.parametrize("width", [3, 4])
|
||||
@pytest.mark.parametrize("dim", [2048 + 16, 4096])
|
||||
# tests correctness in case subset of the sequences are padded
|
||||
@pytest.mark.parametrize("with_padding", [True, False])
|
||||
@pytest.mark.parametrize("batch_size", [3])
|
||||
def test_causal_conv1d_update_with_batch_gather(
|
||||
batch_size, with_padding, dim, width, seqlen, has_bias, silu_activation, itype
|
||||
):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
padding = 5 if with_padding else 0
|
||||
padded_batch_size = batch_size + padding
|
||||
# total_entries = number of cache line
|
||||
total_entries = 10 * batch_size
|
||||
|
||||
# x will be (batch, dim, seqlen) with contiguous along dim-axis
|
||||
x = torch.randn(
|
||||
padded_batch_size, seqlen, dim, device=device, dtype=itype
|
||||
).transpose(1, 2)
|
||||
|
||||
x_ref = x.clone()
|
||||
|
||||
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||
dtype=torch.int32, device=device
|
||||
)
|
||||
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
|
||||
unused_states_bool[conv_state_indices] = False
|
||||
padded_state_indices = torch.concat(
|
||||
[
|
||||
conv_state_indices,
|
||||
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# conv_state will be (cache_lines, dim, state_len)
|
||||
# with contiguous along dim-axis
|
||||
conv_state = torch.randn(
|
||||
total_entries, width - 1, dim, device=device, dtype=itype
|
||||
).transpose(1, 2)
|
||||
|
||||
conv_state_for_padding_test = conv_state.clone()
|
||||
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
|
||||
activation = None if not silu_activation else "silu"
|
||||
|
||||
out = causal_conv1d_update(
|
||||
x,
|
||||
conv_state,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
conv_state_indices=padded_state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
)
|
||||
out_ref = causal_conv1d_update_ref(
|
||||
x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
|
||||
)
|
||||
|
||||
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
|
||||
assert torch.equal(
|
||||
conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool]
|
||||
)
|
||||
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [True])
|
||||
@pytest.mark.parametrize("has_bias", [True])
|
||||
@pytest.mark.parametrize("width", [4])
|
||||
@pytest.mark.parametrize("seqlen", [8, 249, 4096])
|
||||
@pytest.mark.parametrize("dim", [64, 4096])
|
||||
@pytest.mark.parametrize("with_padding", [True, False])
|
||||
@pytest.mark.parametrize("batch", [4, 10])
|
||||
def test_causal_conv1d_varlen(
|
||||
batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype
|
||||
):
|
||||
device = "cuda"
|
||||
torch.cuda.empty_cache()
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
seqlens = []
|
||||
batch_size = batch
|
||||
padding = 3 if with_padding else 0
|
||||
padded_batch_size = batch_size + padding
|
||||
nsplits = padded_batch_size - 1
|
||||
|
||||
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
|
||||
|
||||
seqlens.append(
|
||||
torch.diff(
|
||||
torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])
|
||||
).tolist()
|
||||
)
|
||||
assert sum(seqlens[-1]) == seqlen
|
||||
assert all(s > 0 for s in seqlens[-1])
|
||||
|
||||
total_entries = batch_size * 10
|
||||
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
|
||||
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0)
|
||||
x = rearrange(
|
||||
torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype),
|
||||
"b s d -> b d s",
|
||||
)[:, 4096 : 4096 + dim, :]
|
||||
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
x_ref = x.clone()
|
||||
weight_ref = weight.clone()
|
||||
bias_ref = bias.clone() if bias is not None else None
|
||||
activation = None if not silu_activation else "silu"
|
||||
final_states = torch.randn(
|
||||
total_entries, width - 1, dim, device=x.device, dtype=x.dtype
|
||||
).transpose(1, 2)
|
||||
final_states_ref = final_states.clone()
|
||||
has_initial_states = torch.randint(
|
||||
0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
|
||||
)
|
||||
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[
|
||||
:batch_size
|
||||
]
|
||||
padded_state_indices = torch.concat(
|
||||
[
|
||||
state_indices,
|
||||
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
out = causal_conv1d_fn(
|
||||
x.squeeze(0),
|
||||
weight,
|
||||
bias=bias,
|
||||
conv_states=final_states,
|
||||
query_start_loc=cumsum.cuda(),
|
||||
cache_indices=padded_state_indices,
|
||||
has_initial_state=has_initial_states,
|
||||
activation=activation,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
)
|
||||
|
||||
out_ref = []
|
||||
out_ref_b = []
|
||||
|
||||
splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
|
||||
for i in range(len(seqlens[0])):
|
||||
x_s = [v[i].unsqueeze(0) for v in splits][0]
|
||||
if padded_state_indices[i] == PAD_SLOT_ID:
|
||||
continue
|
||||
out_ref_b.append(
|
||||
causal_conv1d_ref(
|
||||
x_s,
|
||||
weight_ref,
|
||||
bias_ref,
|
||||
activation=activation,
|
||||
return_final_states=True,
|
||||
final_states_out=final_states_ref[padded_state_indices[i]].unsqueeze(0),
|
||||
initial_states=final_states_ref[padded_state_indices[i]].unsqueeze(0)
|
||||
if has_initial_states[i]
|
||||
else None,
|
||||
)
|
||||
)
|
||||
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
|
||||
out_ref_tensor = torch.cat(out_ref, dim=0)
|
||||
|
||||
assert torch.allclose(
|
||||
final_states[state_indices],
|
||||
final_states_ref[state_indices],
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
)
|
||||
unpadded_out = out[:, : out_ref_tensor.shape[-1]]
|
||||
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
|
||||
137
tests/kernels/mamba/test_mamba_mixer2.py
Normal file
137
tests/kernels/mamba/test_mamba_mixer2.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import multi_gpu_test
|
||||
from vllm.distributed.parallel_state import (
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [128])
|
||||
@pytest.mark.parametrize(
|
||||
"hidden_size_n_groups",
|
||||
[
|
||||
(64, 1),
|
||||
(64, 2),
|
||||
(64, 4), # hidden_size be divisible by num_gpus
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||
def test_mixer2_gated_norm_multi_gpu(
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
hidden_size_n_groups: tuple[int, int],
|
||||
dtype: torch.dtype,
|
||||
device: str = "cuda",
|
||||
):
|
||||
hidden_size, n_groups = hidden_size_n_groups
|
||||
num_processes = 2
|
||||
|
||||
def run_torch_spawn(fn, nprocs):
|
||||
# need to use torch.mp.spawn otherwise will have problems with
|
||||
# torch.distributed and cuda
|
||||
torch.multiprocessing.spawn(
|
||||
fn,
|
||||
args=(
|
||||
num_processes,
|
||||
batch_size,
|
||||
seq_len,
|
||||
hidden_size,
|
||||
n_groups,
|
||||
dtype,
|
||||
device,
|
||||
),
|
||||
nprocs=nprocs,
|
||||
)
|
||||
|
||||
run_torch_spawn(mixer2_gated_norm_tensor_parallel, 2)
|
||||
|
||||
|
||||
def mixer2_gated_norm_tensor_parallel(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
n_groups: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
update_environment_variables(
|
||||
{
|
||||
"RANK": str(local_rank),
|
||||
"LOCAL_RANK": str(local_rank),
|
||||
"WORLD_SIZE": str(world_size),
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "12345",
|
||||
}
|
||||
)
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# create random weights an inputs
|
||||
weight = torch.rand((hidden_size,), dtype=dtype, device=device)
|
||||
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
|
||||
gate_states = torch.randn(batch_size, seq_len, hidden_size)
|
||||
|
||||
# create gated-norm with TP
|
||||
mixer = Mixer2RMSNormGated(
|
||||
full_hidden_size=hidden_size,
|
||||
full_n_groups=n_groups,
|
||||
)
|
||||
mixer.weight.weight_loader(mixer.weight, weight) # load
|
||||
|
||||
# create gated-norm without TP to compute reference
|
||||
# - utilize mock patching to disable TP when
|
||||
with (
|
||||
unittest.mock.patch(
|
||||
"vllm.model_executor.layers.mamba.mamba_mixer2."
|
||||
"get_tensor_model_parallel_world_size",
|
||||
return_value=1,
|
||||
),
|
||||
unittest.mock.patch(
|
||||
"vllm.model_executor.layers.mamba.mamba_mixer2."
|
||||
"get_tensor_model_parallel_rank",
|
||||
return_value=0,
|
||||
),
|
||||
):
|
||||
mixer_single_gpu = Mixer2RMSNormGated(
|
||||
full_hidden_size=hidden_size,
|
||||
full_n_groups=n_groups,
|
||||
)
|
||||
# assign weight to single-gpu mixer
|
||||
mixer_single_gpu.weight.data = weight
|
||||
|
||||
# generate and compare
|
||||
N = hidden_size // world_size
|
||||
output = mixer(
|
||||
hidden_states[..., local_rank * N : (local_rank + 1) * N],
|
||||
gate_states[..., local_rank * N : (local_rank + 1) * N],
|
||||
)
|
||||
ref_output = mixer_single_gpu(hidden_states, gate_states)
|
||||
torch.testing.assert_close(
|
||||
output,
|
||||
ref_output[..., local_rank * N : (local_rank + 1) * N],
|
||||
atol=5e-3,
|
||||
rtol=1e-3,
|
||||
)
|
||||
1093
tests/kernels/mamba/test_mamba_ssm.py
Normal file
1093
tests/kernels/mamba/test_mamba_ssm.py
Normal file
File diff suppressed because it is too large
Load Diff
569
tests/kernels/mamba/test_mamba_ssm_ssd.py
Normal file
569
tests/kernels/mamba/test_mamba_ssm_ssd.py
Normal file
@@ -0,0 +1,569 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
||||
mamba_chunk_scan_combined_varlen,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.mamba2_attn import compute_varlen_chunk_metadata
|
||||
|
||||
# Added by the IBM Team, 2024
|
||||
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py
|
||||
|
||||
|
||||
# this is the segsum implementation taken from above
|
||||
def segsum(x):
|
||||
"""Calculates segment sum."""
|
||||
T = x.size(-1)
|
||||
x = repeat(x, "... d -> ... d e", e=T)
|
||||
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
|
||||
x = x.masked_fill(~mask, 0)
|
||||
x_segsum = torch.cumsum(x, dim=-2)
|
||||
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
|
||||
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
|
||||
return x_segsum
|
||||
|
||||
|
||||
def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
|
||||
"""
|
||||
Arguments:
|
||||
X: (batch, length, n_heads, d_head)
|
||||
A: (batch, length, n_heads)
|
||||
B: (batch, length, n_heads, d_state)
|
||||
C: (batch, length, n_heads, d_state)
|
||||
Return:
|
||||
Y: (batch, length, n_heads, d_head)
|
||||
"""
|
||||
assert X.dtype == A.dtype == B.dtype == C.dtype
|
||||
assert X.shape[1] % block_len == 0
|
||||
|
||||
# Rearrange into blocks/chunks
|
||||
X, A, B, C = (
|
||||
rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)
|
||||
)
|
||||
|
||||
A = rearrange(A, "b c l h -> b h c l")
|
||||
A_cumsum = torch.cumsum(A, dim=-1)
|
||||
|
||||
# 1. Compute the output for each intra-chunk (diagonal blocks)
|
||||
L = torch.exp(segsum(A))
|
||||
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
|
||||
|
||||
# 2. Compute the state for each intra-chunk
|
||||
# (right term of low-rank factorization of off-diagonal blocks; B terms)
|
||||
decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
|
||||
states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
|
||||
|
||||
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at
|
||||
# chunk boundaries
|
||||
# (middle term of factorization of off-diag blocks; A terms)
|
||||
if initial_states is None:
|
||||
initial_states = torch.zeros_like(states[:, :1])
|
||||
states = torch.cat([initial_states, states], dim=1)
|
||||
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
|
||||
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
|
||||
states, final_state = new_states[:, :-1], new_states[:, -1]
|
||||
|
||||
# 4. Compute state -> output conversion per chunk
|
||||
# (left term of low-rank factorization of off-diagonal blocks; C terms)
|
||||
state_decay_out = torch.exp(A_cumsum)
|
||||
Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)
|
||||
|
||||
# Add output of intra-chunk and inter-chunk terms
|
||||
# (diagonal and off-diagonal blocks)
|
||||
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
|
||||
return Y, final_state
|
||||
|
||||
|
||||
def generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype, device="cuda"):
|
||||
current_platform.seed_everything(0)
|
||||
A = -torch.exp(torch.rand(n_heads, dtype=itype, device=device))
|
||||
dt = F.softplus(
|
||||
torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - 4
|
||||
)
|
||||
X = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
|
||||
B = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
|
||||
C = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
|
||||
|
||||
return A, dt, X, B, C
|
||||
|
||||
|
||||
def generate_continuous_batched_examples(
|
||||
example_lens_by_batch,
|
||||
num_examples,
|
||||
full_length,
|
||||
last_taken,
|
||||
exhausted,
|
||||
n_heads,
|
||||
d_head,
|
||||
itype,
|
||||
device="cuda",
|
||||
return_naive_ref=True,
|
||||
):
|
||||
# this function generates a random examples of certain length
|
||||
# and then cut according to "example_lens_by_batch" and feed
|
||||
# them in continuous batches to the kernels.
|
||||
# If if return_naive_ref=True, the naive torch implementation
|
||||
# ssd_minimal_discrete will be used to compute and return
|
||||
# reference output.
|
||||
|
||||
# generate the full-length example
|
||||
A, dt, X, B, C = generate_random_inputs(
|
||||
num_examples, full_length, n_heads, d_head, itype
|
||||
)
|
||||
|
||||
if return_naive_ref:
|
||||
Y_min, final_state_min = ssd_minimal_discrete(
|
||||
X * dt.unsqueeze(-1), A * dt, B, C, block_len=full_length // 4
|
||||
)
|
||||
|
||||
# internal function that outputs a cont batch of examples
|
||||
# given a tuple of lengths for each example in the batch
|
||||
# e.g., example_lens=(8, 4) means take 8 samples from first eg,
|
||||
# 4 examples from second eg, etc
|
||||
def get_continuous_batch(example_lens: tuple[int, ...]):
|
||||
indices = []
|
||||
for i, x in enumerate(example_lens):
|
||||
c = last_taken.get(i, 0)
|
||||
indices.append((c, c + x))
|
||||
last_taken[i] = (c + x) % full_length
|
||||
exhausted[i] = last_taken[i] == 0
|
||||
|
||||
return (
|
||||
torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)]).unsqueeze(0)
|
||||
for x in (dt, X, B, C)
|
||||
)
|
||||
|
||||
# internal function that maps "n" to the appropriate right boundary
|
||||
# value when forming continuous batches from examples of length given
|
||||
# by "full_length".
|
||||
# - e.g., when n > full_length, returns n % full_length
|
||||
# when n == full_length, returns full_length
|
||||
def end_boundary(n: int):
|
||||
return n - ((n - 1) // full_length) * full_length
|
||||
|
||||
IND_E = None
|
||||
for spec in example_lens_by_batch:
|
||||
# get the (maybe partial) example seen in this cont batch
|
||||
dt2, X2, B2, C2 = get_continuous_batch(spec)
|
||||
|
||||
# get the metadata
|
||||
cu_seqlens = torch.tensor((0,) + spec, device=device).cumsum(dim=0)
|
||||
seq_idx = torch.zeros(
|
||||
cu_seqlens[-1], dtype=torch.int32, device=cu_seqlens.device
|
||||
)
|
||||
for i, (srt, end) in enumerate(
|
||||
zip(
|
||||
cu_seqlens,
|
||||
cu_seqlens[1:],
|
||||
)
|
||||
):
|
||||
seq_idx[srt:end] = i
|
||||
|
||||
# for cont batch
|
||||
if IND_E is None:
|
||||
IND_S = [0 for _ in range(len(spec))]
|
||||
else:
|
||||
IND_S = [x % full_length for x in IND_E]
|
||||
IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
|
||||
|
||||
# varlen has implicit batch=1
|
||||
dt2 = dt2.squeeze(0)
|
||||
X2 = X2.squeeze(0)
|
||||
B2 = B2.squeeze(0)
|
||||
C2 = C2.squeeze(0)
|
||||
yield (
|
||||
[Y_min[s, IND_S[s] : IND_E[s]] for s in range(num_examples)]
|
||||
if return_naive_ref
|
||||
else None,
|
||||
cu_seqlens,
|
||||
seq_idx,
|
||||
(A, dt2, X2, B2, C2),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("n_heads", [4, 16, 32])
|
||||
@pytest.mark.parametrize("d_head", [5, 8, 32, 128])
|
||||
@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
|
||||
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype):
|
||||
# this tests the kernels on a single example (bs=1)
|
||||
|
||||
# TODO: the bfloat16 case requires higher thresholds. To be investigated
|
||||
|
||||
if itype == torch.bfloat16:
|
||||
atol, rtol = 5e-2, 5e-2
|
||||
else:
|
||||
atol, rtol = 8e-3, 5e-3
|
||||
|
||||
# set seed
|
||||
batch_size = 1 # batch_size
|
||||
# ssd_minimal_discrete requires chunk_size divide seqlen
|
||||
# - this is only required for generating the reference seqs,
|
||||
# it is not an operational limitation.
|
||||
seqlen, chunk_size = seq_len_chunk_size
|
||||
|
||||
A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype)
|
||||
|
||||
Y_min, final_state_min = ssd_minimal_discrete(
|
||||
X * dt.unsqueeze(-1), A * dt, B, C, chunk_size
|
||||
)
|
||||
|
||||
cu_seqlens = torch.tensor((0, seqlen), device="cuda").cumsum(dim=0)
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(cu_seqlens, chunk_size)
|
||||
)
|
||||
# varlen has implicit batch=1
|
||||
X = X.squeeze(0)
|
||||
dt = dt.squeeze(0)
|
||||
A = A.squeeze(0)
|
||||
B = B.squeeze(0)
|
||||
C = C.squeeze(0)
|
||||
Y = torch.empty_like(X)
|
||||
final_state = mamba_chunk_scan_combined_varlen(
|
||||
X,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
cu_seqlens=cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y,
|
||||
D=None,
|
||||
)
|
||||
|
||||
# just test the last in sequence
|
||||
torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol)
|
||||
|
||||
# just test the last head
|
||||
# NOTE, in the kernel we always cast states to fp32
|
||||
torch.testing.assert_close(
|
||||
final_state[:, -1].to(torch.float32),
|
||||
final_state_min[:, -1].to(torch.float32),
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.float32])
|
||||
@pytest.mark.parametrize("n_heads", [4, 8])
|
||||
@pytest.mark.parametrize("d_head", [5, 16, 32])
|
||||
@pytest.mark.parametrize(
|
||||
"seq_len_chunk_size_cases",
|
||||
[
|
||||
# small-ish chunk_size (8)
|
||||
(64, 8, 2, [(64, 32), (64, 32)]),
|
||||
(64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary
|
||||
(
|
||||
64,
|
||||
8,
|
||||
2,
|
||||
[(4, 4), (4, 4), (4, 4), (4, 4)],
|
||||
), # chunk_size larger than cont batches
|
||||
(64, 8, 5, [(64, 32, 16, 8, 8)]),
|
||||
# large-ish chunk_size (256)
|
||||
(64, 256, 1, [(5,), (1,), (1,), (1,)]), # irregular sizes with small sequences
|
||||
(
|
||||
64,
|
||||
256,
|
||||
2,
|
||||
[(5, 30), (1, 2), (1, 2), (1, 2)],
|
||||
), # irregular sizes with small sequences
|
||||
# we also need to test some large seqlen
|
||||
# to catch errors with init states decay
|
||||
(768, 128, 2, [(138, 225), (138, 225)]),
|
||||
],
|
||||
)
|
||||
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, itype):
|
||||
# this test with multiple examples in a continuous batch
|
||||
# (i.e. chunked prefill)
|
||||
|
||||
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
|
||||
|
||||
# This test can have larger error for longer sequences
|
||||
if seqlen > 256:
|
||||
atol, rtol = 1e-2, 5e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
# hold state during the cutting process so we know if an
|
||||
# example has been exhausted and needs to cycle
|
||||
last_taken: dict = {} # map: eg -> pointer to last taken sample
|
||||
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
|
||||
|
||||
states = None
|
||||
for Y_min, cu_seqlens, _token_seq_idx, (
|
||||
A,
|
||||
dt,
|
||||
X,
|
||||
B,
|
||||
C,
|
||||
) in generate_continuous_batched_examples(
|
||||
cases, num_examples, seqlen, last_taken, exhausted, n_heads, d_head, itype
|
||||
):
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(cu_seqlens, chunk_size)
|
||||
)
|
||||
|
||||
Y = torch.empty_like(X)
|
||||
new_states = mamba_chunk_scan_combined_varlen(
|
||||
X,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
cu_seqlens=cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y,
|
||||
D=None,
|
||||
initial_states=states,
|
||||
)
|
||||
|
||||
# just test the last in sequence
|
||||
for i in range(num_examples):
|
||||
# just test one dim and dstate
|
||||
Y_eg = Y[cu_seqlens[i] : cu_seqlens[i + 1], 0, 0]
|
||||
Y_min_eg = Y_min[i][:, 0, 0]
|
||||
torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
|
||||
|
||||
# update states
|
||||
states = new_states
|
||||
for i, clear in exhausted.items():
|
||||
if clear:
|
||||
states[i].fill_(0.0)
|
||||
exhausted[i] = False
|
||||
|
||||
|
||||
@pytest.mark.parametrize("chunk_size", [8, 256])
|
||||
@pytest.mark.parametrize(
|
||||
"seqlens",
|
||||
[(16, 20), (270, 88, 212, 203)],
|
||||
)
|
||||
def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
# This test verifies the correctness of the chunked prefill implementation
|
||||
# in the mamba2 ssd kernels, by comparing concatenation (in the sequence
|
||||
# dimension) of chunked results with the full sequence result.
|
||||
# It is different from test_mamba_chunk_scan_cont_batch by:
|
||||
# 1. Not using the naive torch implementation (ssd_minimal_discrete) to get
|
||||
# reference outputs. Instead, it compares chunked kernel outputs to full
|
||||
# sequence kernel outputs. This is the most straightforward way to
|
||||
# assert chunked prefill correctness.
|
||||
# 2. It focuses on cases where sequences change in the middle of mamba
|
||||
# chunks, and not necessarily on chunk boundaries.
|
||||
|
||||
max_seqlen = max(seqlens)
|
||||
# This test can have larger error for longer sequences
|
||||
if max_seqlen > 256:
|
||||
atol, rtol = 1e-2, 5e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
num_sequences = len(seqlens)
|
||||
n_heads = 16
|
||||
d_head = 64
|
||||
itype = torch.float32
|
||||
|
||||
# hold state during the cutting process so we know if an
|
||||
# example has been exhausted and needs to cycle
|
||||
last_taken: dict = {} # map: eg -> pointer to last taken sample
|
||||
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
|
||||
_, cu_seqlens, seq_idx, (A, dt, X, B, C) = next(
|
||||
generate_continuous_batched_examples(
|
||||
[seqlens],
|
||||
num_sequences,
|
||||
max_seqlen,
|
||||
last_taken,
|
||||
exhausted,
|
||||
n_heads,
|
||||
d_head,
|
||||
itype,
|
||||
return_naive_ref=False,
|
||||
)
|
||||
)
|
||||
seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device)
|
||||
device = X.device
|
||||
|
||||
## full seqlen computation
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(cu_seqlens, chunk_size)
|
||||
)
|
||||
Y_ref = torch.empty_like(X)
|
||||
state_ref = mamba_chunk_scan_combined_varlen(
|
||||
X,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
cu_seqlens=cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y_ref,
|
||||
D=None,
|
||||
initial_states=None,
|
||||
)
|
||||
|
||||
## chunked seqlen computation
|
||||
# first chunk
|
||||
chunked_seqlens = seqlens // 2
|
||||
chunked_cu_seqlens = torch.cat(
|
||||
[torch.tensor([0], device=device), torch.cumsum(chunked_seqlens, dim=0)], dim=0
|
||||
)
|
||||
chunked_input_seq_len = chunked_cu_seqlens[-1]
|
||||
X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...]
|
||||
dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...]
|
||||
B_chunked = torch.zeros_like(B)[:chunked_input_seq_len, ...]
|
||||
C_chunked = torch.zeros_like(C)[:chunked_input_seq_len, ...]
|
||||
for i in range(num_sequences):
|
||||
chunk_f = lambda x, i: x[
|
||||
cu_seqlens[i] : cu_seqlens[i] + chunked_seqlens[i], ...
|
||||
]
|
||||
|
||||
X_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
|
||||
X, i
|
||||
)
|
||||
dt_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
|
||||
dt, i
|
||||
)
|
||||
B_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
|
||||
B, i
|
||||
)
|
||||
C_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
|
||||
C, i
|
||||
)
|
||||
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size)
|
||||
)
|
||||
Y_partial = torch.empty_like(X_chunked)
|
||||
partial_state = mamba_chunk_scan_combined_varlen(
|
||||
X_chunked,
|
||||
dt_chunked,
|
||||
A,
|
||||
B_chunked,
|
||||
C_chunked,
|
||||
chunk_size,
|
||||
cu_seqlens=chunked_cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y_partial,
|
||||
D=None,
|
||||
initial_states=None,
|
||||
)
|
||||
|
||||
# remaining chunk
|
||||
remaining_chunked_seqlens = seqlens - chunked_seqlens
|
||||
remaining_chunked_cu_seqlens = torch.cat(
|
||||
[
|
||||
torch.tensor([0], device=device),
|
||||
torch.cumsum(remaining_chunked_seqlens, dim=0),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
|
||||
remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...]
|
||||
remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...]
|
||||
remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...]
|
||||
remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...]
|
||||
for i in range(num_sequences):
|
||||
remaining_chunk_f = lambda x, i: x[
|
||||
cu_seqlens[i] + chunked_seqlens[i] : cu_seqlens[i + 1], ...
|
||||
]
|
||||
|
||||
remaining_X_chunked[
|
||||
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
|
||||
] = remaining_chunk_f(X, i)
|
||||
remaining_dt_chunked[
|
||||
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
|
||||
] = remaining_chunk_f(dt, i)
|
||||
remaining_B_chunked[
|
||||
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
|
||||
] = remaining_chunk_f(B, i)
|
||||
remaining_C_chunked[
|
||||
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
|
||||
] = remaining_chunk_f(C, i)
|
||||
|
||||
# assert input chunking is correct
|
||||
concat_chunk_f = lambda pt1, pt2, i: torch.cat(
|
||||
[
|
||||
pt1[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...],
|
||||
pt2[
|
||||
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1],
|
||||
...,
|
||||
],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
concat_batch_f = lambda pt1, pt2: torch.cat(
|
||||
[concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0
|
||||
)
|
||||
|
||||
assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
|
||||
assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt)
|
||||
assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B)
|
||||
assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)
|
||||
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(remaining_chunked_cu_seqlens, chunk_size)
|
||||
)
|
||||
|
||||
Y_chunked = torch.empty_like(remaining_X_chunked)
|
||||
state_chunked = mamba_chunk_scan_combined_varlen(
|
||||
remaining_X_chunked,
|
||||
remaining_dt_chunked,
|
||||
A,
|
||||
remaining_B_chunked,
|
||||
remaining_C_chunked,
|
||||
chunk_size,
|
||||
cu_seqlens=remaining_chunked_cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y_chunked,
|
||||
D=None,
|
||||
initial_states=partial_state,
|
||||
)
|
||||
Y = concat_batch_f(Y_partial, Y_chunked)
|
||||
|
||||
# kernel chunked is same as kernel overall
|
||||
for i in range(num_sequences):
|
||||
Y_seq = Y[cu_seqlens[i] : cu_seqlens[i + 1], ...]
|
||||
Y_ref_seq = Y_ref[cu_seqlens[i] : cu_seqlens[i + 1], ...]
|
||||
torch.testing.assert_close(
|
||||
Y_seq[: chunked_seqlens[i], ...],
|
||||
Y_ref_seq[: chunked_seqlens[i], ...],
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg=lambda x, i=i: f"seq{i} output part1 " + x,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
Y_seq[chunked_seqlens[i] :, ...],
|
||||
Y_ref_seq[chunked_seqlens[i] :, ...],
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg=lambda x, i=i: f"seq{i} output part2 " + x,
|
||||
)
|
||||
|
||||
state_seq = state_chunked[i]
|
||||
state_seq_ref = state_ref[i]
|
||||
torch.testing.assert_close(
|
||||
state_seq,
|
||||
state_seq_ref,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg=lambda x, i=i: f"seq{i} state " + x,
|
||||
)
|
||||
Reference in New Issue
Block a user