85 lines
3.6 KiB
Python
85 lines
3.6 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
|
|
"""This function is intended to align with the l2norm implementation in the FLA library."""
|
|
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
|
|
return x * inv_norm
|
|
|
|
def torch_chunk_gated_delta_rule(
|
|
query,
|
|
key,
|
|
value,
|
|
g,
|
|
beta,
|
|
chunk_size=64,
|
|
initial_state=None,
|
|
output_final_state=False,
|
|
use_qk_l2norm_in_kernel=False,
|
|
):
|
|
initial_dtype = query.dtype
|
|
if use_qk_l2norm_in_kernel:
|
|
query = l2norm(query, dim=-1, eps=1e-6)
|
|
key = l2norm(key, dim=-1, eps=1e-6)
|
|
query, key, value, beta, g = [
|
|
x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
|
|
]
|
|
|
|
batch_size, num_heads, sequence_length, k_head_dim = key.shape
|
|
v_head_dim = value.shape[-1]
|
|
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
|
|
query = F.pad(query, (0, 0, 0, pad_size))
|
|
key = F.pad(key, (0, 0, 0, pad_size))
|
|
value = F.pad(value, (0, 0, 0, pad_size))
|
|
beta = F.pad(beta, (0, pad_size))
|
|
g = F.pad(g, (0, pad_size))
|
|
total_sequence_length = sequence_length + pad_size
|
|
scale = 1 / (query.shape[-1] ** 0.5)
|
|
query = query * scale
|
|
|
|
v_beta = value * beta.unsqueeze(-1)
|
|
k_beta = key * beta.unsqueeze(-1)
|
|
# reshape to chunks
|
|
query, key, value, k_beta, v_beta = [
|
|
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
|
|
]
|
|
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
|
|
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
|
|
|
|
# chunk decay
|
|
g = g.cumsum(dim=-1)
|
|
decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
|
|
attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
|
|
for i in range(1, chunk_size):
|
|
row = attn[..., i, :i].clone()
|
|
sub = attn[..., :i, :i].clone()
|
|
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
|
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
|
|
value = attn @ v_beta
|
|
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
|
|
last_recurrent_state = (
|
|
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
|
|
if initial_state is None
|
|
else initial_state.to(value)
|
|
)
|
|
core_attn_out = torch.zeros_like(value)
|
|
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
|
|
|
|
# for each chunk
|
|
for i in range(0, total_sequence_length // chunk_size):
|
|
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
|
|
attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
|
|
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
|
|
v_new = v_i - v_prime
|
|
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
|
|
core_attn_out[:, :, i] = attn_inter + attn @ v_new
|
|
last_recurrent_state = (
|
|
last_recurrent_state * g[:, :, i, -1, None, None].exp()
|
|
+ (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
|
|
)
|
|
|
|
if not output_final_state:
|
|
last_recurrent_state = None
|
|
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
|
|
core_attn_out = core_attn_out[:, :, :sequence_length]
|
|
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
|
|
return core_attn_out, last_recurrent_state |