First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,8 @@
__version__ = "2.5.8"
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
from flash_attn.flash_attn_interface import flash_attn_varlen_func

View File

@@ -0,0 +1,132 @@
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
class IndexFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(rearrange(input, 'b ... -> b (...)'), 0,
repeat(indices, 'z -> z d', d=second_dim)).reshape(-1, *other_shape)
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
grad_output = rearrange(grad_output, 'b ... -> b (...)')
grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device, dtype=grad_output.dtype)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
# grad_input[indices] = grad_output
grad_input.scatter_(0, repeat(indices, 'z -> z d', d=grad_output.shape[1]), grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis = IndexFirstAxis.apply
class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device,
dtype=values.dtype)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
return output
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
grad_values = grad_output[indices]
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
return grad_values, None, None
index_put_first_axis = IndexPutFirstAxis.apply
class IndexFirstAxisResidual(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
output = input[indices]
# We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
# memory format to channel_first. In other words, input might not be contiguous.
# If we don't detach, Pytorch complains about output being a view and is being modified inplace
return output, input.detach()
@staticmethod
def backward(ctx, grad_output, grad_residual):
indices, = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
assert grad_residual.shape[1:] == other_shape
grad_input = grad_residual
# grad_input[indices] += grad_output
indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
indices = indices.expand_as(grad_output)
grad_input.scatter_add_(0, indices, grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis_residual = IndexFirstAxisResidual.apply
def unpad_input(hidden_states, attention_mask):
"""
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices), indices,
cu_seqlens, max_seqlen_in_batch)
def pad_input(hidden_states, indices, batch, seqlen):
"""
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz)
Return:
hidden_states: (batch, seqlen, ...)
"""
dim = hidden_states.shape[-1]
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, '(b s) ... -> b s ...', b=batch)

View File

@@ -0,0 +1,966 @@
import os
import torch
import torch.nn as nn
import flash_attn_2_cuda as flash_attn_cuda
from einops import rearrange
import warnings
global enable_ixdnn
enable_ixdnn = os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') != '0'
def _get_block_size(device, head_dim, is_dropout, is_causal):
# This should match the block sizes in the CUDA kernel
assert head_dim <= 256
major, minor = torch.cuda.get_device_capability(device)
is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100)
is_sm80 = major == 8 and minor == 0
is_sm90 = major == 9 and minor == 0
if head_dim <= 32:
return 128, 128
if head_dim <= 64:
return (128, 128) if not is_dropout else (128, 64)
elif head_dim <= 96:
return (64, 64) if (is_sm8x and is_causal) else (128, 64)
elif head_dim <= 128:
if is_sm8x:
return (64, 64) if (not is_dropout and is_causal) else (128, 32)
else:
return 128, (64 if not is_dropout else 32)
elif head_dim <= 160:
if is_sm8x:
return (128, 64) if not is_causal else (64, 64)
else:
return 128, 32
elif head_dim <= 192:
return (128, 64) if not is_dropout else (64, 64)
elif head_dim <= 224:
return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
elif head_dim <= 256:
return (128, 64) if is_sm80 else (64, 64)
def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax, use_alibi, alibi_mode):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd(
q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, use_alibi, alibi_mode, None
)
return out, q, k, v, out_padded, softmax_lse, S_dmask
def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal, return_softmax, use_alibi, alibi_mode):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd(
q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, False, causal, return_softmax, use_alibi, alibi_mode, None
)
# if out.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
return out, q, k, v, out_padded, softmax_lse, S_dmask
def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
dropout_p, softmax_scale, causal, use_alibi, alibi_mode):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, use_alibi, alibi_mode, None
)
return dq, dk, dv, softmax_d
def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal, use_alibi, alibi_mode):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, use_alibi, alibi_mode, None
)
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
return dq, dk, dv, softmax_d
def _flash_attn_forward_ixdnn(q, k, v, cu_seqlens_q, cu_seqlens_k, dropout_p,
causal, return_softmax, use_alibi, alibi_mode, imp_mode, window_size):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd_ixdnn(
q, k, v, None, cu_seqlens_q, cu_seqlens_k, dropout_p, causal, return_softmax,
use_alibi, alibi_mode, imp_mode, window_size[0], window_size[1], None
)
return out, q, k, v, out_padded, softmax_lse, S_dmask
def _flash_attn_backward_ixdnn(dout, q, k, v, out, softmax_lse, dq, dk, dv,
cu_seqlens_q, cu_seqlens_k, dropout_p, causal, use_alibi, alibi_mode, imp_mode, window_size):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flash_attn_cuda.bwd_ixdnn(
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, dropout_p,
causal, use_alibi, alibi_mode, imp_mode, window_size[0], window_size[1], None
)
return dq, dk, dv, softmax_d
def _flash_attn_varlen_forward_ixdnn(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, causal, use_alibi, alibi_mode, imp_mode, window_size):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd_ixdnn(
q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
causal, use_alibi, alibi_mode, imp_mode, window_size[0], window_size[1], None
)
return out, q, k, v, out_padded, softmax_lse, S_dmask
def _flash_attn_varlen_backward_ixdnn(dout, q, k, v, out, softmax_lse, dq, dk, dv,
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, causal, use_alibi, alibi_mode, imp_mode, window_size):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d = flash_attn_cuda.varlen_bwd_ixdnn(
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, causal, use_alibi, alibi_mode, imp_mode,
window_size[0], window_size[1], None
)
return dq, dk, dv, softmax_d
class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, use_alibi, alibi_mode, imp_mode):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if alibi_slopes is not None and use_alibi == False:
warnings.warn("Parameter 'alibi_slopes' is not supported, automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning)
use_alibi = True
imp_mode = 1
if enable_ixdnn:
cu_seqlens = torch.full((qkv.shape[0],), qkv.shape[1], dtype=torch.int32, device=qkv.device)
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward_ixdnn(
qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], cu_seqlens, cu_seqlens, dropout_p,
causal=causal, return_softmax=return_softmax and dropout_p > 0,
use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode,
window_size=window_size
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
else:
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], dropout_p, softmax_scale,
causal=causal, return_softmax=return_softmax and dropout_p > 0,
use_alibi=use_alibi, alibi_mode=alibi_mode
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.use_alibi = use_alibi
ctx.alibi_mode = alibi_mode
ctx.imp_mode = imp_mode
ctx.window_size = window_size
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
if enable_ixdnn:
q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
else:
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
if enable_ixdnn:
_flash_attn_backward_ixdnn(
dout, q, k, v, out, softmax_lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
cu_seqlens, cu_seqlens, ctx.dropout_p, ctx.causal, ctx.use_alibi, ctx.alibi_mode, ctx.imp_mode,
ctx.window_size
)
else:
_flash_attn_backward(
dout, q, k, v, out, softmax_lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.use_alibi, ctx.alibi_mode
)
dqkv = dqkv[..., :dout.shape[-1]] # We could have padded the head dimension
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None, None, None, None, None
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, use_alibi, alibi_mode, imp_mode):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if alibi_slopes is not None and use_alibi == False:
warnings.warn("Parameter 'alibi_slopes' is not supported, automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning)
use_alibi = True
imp_mode = 1
if enable_ixdnn:
cu_seqlens = torch.full((cu_seqlens.numel(),), max_seqlen, dtype=torch.int32, device=cu_seqlens.device)
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward_ixdnn(
qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
dropout_p, causal=causal, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode,
window_size=window_size
)
else:
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0,
use_alibi=use_alibi, alibi_mode=alibi_mode
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
ctx.dropout_p = dropout_p
ctx.max_seqlen = max_seqlen
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.use_alibi = use_alibi
ctx.alibi_mode = alibi_mode
ctx.imp_mode = imp_mode
ctx.window_size = window_size
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
if enable_ixdnn:
_flash_attn_varlen_backward_ixdnn(
dout, q, k, v, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2],
cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
ctx.dropout_p, ctx.causal, ctx.use_alibi, ctx.alibi_mode, ctx.imp_mode,
ctx.window_size
)
else:
_flash_attn_varlen_backward(
dout, q, k, v, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2],
cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.use_alibi, ctx.alibi_mode
)
dqkv = dqkv[..., :dout.shape[-1]] # We could have padded the head dimension
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None, None, None, None, None, None, None
class FlashAttnKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, use_alibi, alibi_mode, imp_mode):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if alibi_slopes is not None and use_alibi == False:
warnings.warn("Parameter 'alibi_slopes' is not supported, automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning)
use_alibi = True
imp_mode = 1
if enable_ixdnn:
cu_seqlens_q = torch.full((q.shape[0],), q.shape[1], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.full((kv.shape[0],), kv.shape[1], dtype=torch.int32, device=kv.device)
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward_ixdnn(
q, kv[:, :, 0], kv[:, :, 1], cu_seqlens_q, cu_seqlens_k, dropout_p, causal=causal,
return_softmax=return_softmax and dropout_p > 0, use_alibi=use_alibi,
alibi_mode=alibi_mode, imp_mode=imp_mode, window_size=window_size
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
else:
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal=causal,
return_softmax=return_softmax and dropout_p > 0, use_alibi=use_alibi,
alibi_mode=alibi_mode
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.use_alibi = use_alibi
ctx.alibi_mode = alibi_mode
ctx.imp_mode = imp_mode
ctx.window_size = window_size
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
if enable_ixdnn:
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
else:
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dq = torch.empty_like(q)
kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
if enable_ixdnn:
_flash_attn_backward_ixdnn(
dout, q, k, v, out, softmax_lse, dq, dkv[:, :, 0], dkv[:, :, 1],
cu_seqlens_q, cu_seqlens_k, ctx.dropout_p, ctx.causal, ctx.use_alibi, ctx.alibi_mode, ctx.imp_mode, ctx.window_size
)
else:
_flash_attn_backward(
dout, q, k, v, out, softmax_lse,
dq, dkv[:, :, 0], dkv[:, :, 1], ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.use_alibi, ctx.alibi_mode
)
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., :dout.shape[-1]]
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dq, dkv, None, None, None, None, None, None, None, None, None, None
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax,
use_alibi, alibi_mode, imp_mode):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if alibi_slopes is not None and use_alibi == False:
warnings.warn("Parameter 'alibi_slopes' is not supported, automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning)
use_alibi = True
imp_mode = 1
if enable_ixdnn:
cu_seqlens_q = torch.full((cu_seqlens_q.numel(),), max_seqlen_q, dtype=torch.int32, device=cu_seqlens_q.device)
cu_seqlens_k = torch.full((cu_seqlens_k.numel(),), max_seqlen_k, dtype=torch.int32, device=cu_seqlens_k.device)
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward_ixdnn(
q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, causal=causal, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode,
window_size=window_size
)
else:
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0,
use_alibi=use_alibi, alibi_mode=alibi_mode
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse,
cu_seqlens_q, cu_seqlens_k, rng_state)
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.use_alibi = use_alibi
ctx.alibi_mode = alibi_mode
ctx.imp_mode = imp_mode
ctx.window_size = window_size
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dq = torch.empty_like(q)
kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
if enable_ixdnn:
_flash_attn_varlen_backward_ixdnn(
dout, q, k, v, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1],
cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k,
ctx.dropout_p, ctx.causal, ctx.use_alibi, ctx.alibi_mode, ctx.imp_mode,
ctx.window_size
)
else:
_flash_attn_varlen_backward(
dout, q, k, v, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1],
cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.use_alibi, ctx.alibi_mode
)
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., :dout.shape[-1]]
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, use_alibi, alibi_mode, imp_mode):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if alibi_slopes is not None and use_alibi == False:
warnings.warn("Parameter 'alibi_slopes' is not supported, automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning)
use_alibi = True
imp_mode = 1
if enable_ixdnn:
cu_seqlens_q = torch.full((q.shape[0],), q.shape[1], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.full((k.shape[0],), k.shape[1], dtype=torch.int32, device=k.device)
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward_ixdnn(
q, k, v, cu_seqlens_q, cu_seqlens_k, dropout_p, causal=causal,
return_softmax=return_softmax and dropout_p > 0, use_alibi=use_alibi,
alibi_mode = alibi_mode, imp_mode = imp_mode, window_size=window_size
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
else:
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
q, k, v, dropout_p, softmax_scale, causal=causal,
return_softmax=return_softmax and dropout_p > 0, use_alibi=use_alibi,
alibi_mode = alibi_mode
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.use_alibi = use_alibi
ctx.alibi_mode = alibi_mode
ctx.imp_mode = imp_mode
ctx.window_size = window_size
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
if enable_ixdnn:
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
else:
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
if enable_ixdnn:
_flash_attn_backward_ixdnn(
dout, q, k, v, out, softmax_lse,
dq, dk, dv, cu_seqlens_q, cu_seqlens_k, ctx.dropout_p,
ctx.causal, ctx.use_alibi, ctx.alibi_mode, ctx.imp_mode,
ctx.window_size
)
else:
_flash_attn_backward(
dout, q, k, v, out, softmax_lse,
dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.use_alibi, ctx.alibi_mode
)
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., :dout.shape[-1]]
dv = dv[..., :dout.shape[-1]]
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class FlashAttnVarlenFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax,
use_alibi, alibi_mode, imp_mode):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if alibi_slopes is not None and use_alibi == False:
warnings.warn("Parameter 'alibi_slopes' is not supported, automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning)
use_alibi = True
imp_mode = 1
if enable_ixdnn:
cu_seqlens_q = torch.full((cu_seqlens_q.numel(),), max_seqlen_q, dtype=torch.int32, device=cu_seqlens_q.device)
cu_seqlens_k = torch.full((cu_seqlens_k.numel(),), max_seqlen_k, dtype=torch.int32, device=cu_seqlens_k.device)
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward_ixdnn(
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, causal=causal, use_alibi=use_alibi, alibi_mode = alibi_mode, imp_mode=imp_mode,
window_size=window_size
)
else:
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0,
use_alibi = use_alibi, alibi_mode = alibi_mode
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse,
cu_seqlens_q, cu_seqlens_k, rng_state)
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.use_alibi = use_alibi
ctx.alibi_mode = alibi_mode
ctx.imp_mode = imp_mode
ctx.window_size = window_size
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
if enable_ixdnn:
_flash_attn_varlen_backward_ixdnn(
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.causal, ctx.use_alibi,
ctx.alibi_mode, ctx.imp_mode, ctx.window_size
)
else:
_flash_attn_varlen_backward(
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
ctx.use_alibi, ctx.alibi_mode
)
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., :dout.shape[-1]]
dv = dv[..., :dout.shape[-1]]
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), alibi_slopes=None,
deterministic=False, return_attn_probs=False, use_alibi=False, alibi_mode=1, imp_mode=0):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments:
qkv: (batch_size, seqlen, 3, nheads, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
use_alibi: bool. Whether to apply attention with linear bias(ALiBi). if True, causal must
be True. The calculation formula(QK^T + bias * slops_m) is shown as follows,
q1*k1 0
q2*k1 q2*k2 -1 0
q3*k1 q3*k2 q3*k3 + -2 -1 0 * slops_m
q4*k1 q4*k2 q4*k3 q4*k4 -3 -2 -1 0
q5*k1 q5*k2 q5*k3 q5*k4 q5*k5 -4 -3 -2 -1 0
alibi_mode: int. The bias mode of ALiBi, default to 1.
alibi_mode=0: alibi_mode=1:
0 -√0
-1 0 -√1 -√0
-2 -1 0 -√2 -√1 -√0
-3 -2 -1 0 -√3 -√2 -√1 -√0
-4 -3 -2 -1 0 -√4 -√3 -√2 -√1 -√0
-5 -4 -3 -2 -1 0 -√5 -√4 -√3 -√2 -√1 -√0
imp_mode: int. Support two modes of backward implementation, default to 0.
imp_mode=0(CUDNN_FATTN_BALANCE_MODE): Implement dQ reduction calculation by launching a new kernel,
which will occupy additional memory. This mode usually has better performance.
imp_mode=1(CUDNN_FATTN_LEAST_MEM_MODE): Implement dQ reduction calculation within the block,
which will not occupy additional memory. There will be a slight performance loss in this mode.
This mode can be considered in memory-limited scenarios.
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_attn_probs, use_alibi, alibi_mode, imp_mode)
def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), alibi_slopes=None,
deterministic=False, return_attn_probs=False, use_alibi=False, alibi_mode=1, imp_mode=0):
"""dropout_p should be set to 0.0 during evaluation
If K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of K, V.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
kv: (batch_size, seqlen, 2, nheads_k, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
use_alibi: bool. Whether to apply attention with linear bias(ALiBi). if True, causal must
be True. The calculation formula(QK^T + bias * slops_m) is shown as follows,
q1*k1 0
q2*k1 q2*k2 -1 0
q3*k1 q3*k2 q3*k3 + -2 -1 0 * slops_m
q4*k1 q4*k2 q4*k3 q4*k4 -3 -2 -1 0
q5*k1 q5*k2 q5*k3 q5*k4 q5*k5 -4 -3 -2 -1 0
alibi_mode: int. The bias mode of ALiBi, default to 1.
alibi_mode=0: alibi_mode=1:
0 -√0
-1 0 -√1 -√0
-2 -1 0 -√2 -√1 -√0
-3 -2 -1 0 -√3 -√2 -√1 -√0
-4 -3 -2 -1 0 -√4 -√3 -√2 -√1 -√0
-5 -4 -3 -2 -1 0 -√5 -√4 -√3 -√2 -√1 -√0
imp_mode: int. Support two modes of backward implementation, default to 0.
imp_mode=0(CUDNN_FATTN_BALANCE_MODE): Implement dQ reduction calculation by launching a new kernel,
which will occupy additional memory. This mode usually has better performance.
imp_mode=1(CUDNN_FATTN_LEAST_MEM_MODE): Implement dQ reduction calculation within the block,
which will not occupy additional memory. There will be a slight performance loss in this mode.
This mode can be considered in memory-limited scenarios.
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_attn_probs, use_alibi, alibi_mode, imp_mode)
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), alibi_slopes=None,
deterministic=False, return_attn_probs=False, use_alibi=False, alibi_mode=1, imp_mode=0):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
use_alibi: bool. Whether to apply attention with linear bias(ALiBi). if True, causal must
be True. The calculation formula(QK^T + bias * slops_m) is shown as follows,
q1*k1 0
q2*k1 q2*k2 -1 0
q3*k1 q3*k2 q3*k3 + -2 -1 0 * slops_m
q4*k1 q4*k2 q4*k3 q4*k4 -3 -2 -1 0
q5*k1 q5*k2 q5*k3 q5*k4 q5*k5 -4 -3 -2 -1 0
alibi_mode: int. The bias mode of ALiBi, default to 1.
alibi_mode=0: alibi_mode=1:
0 -√0
-1 0 -√1 -√0
-2 -1 0 -√2 -√1 -√0
-3 -2 -1 0 -√3 -√2 -√1 -√0
-4 -3 -2 -1 0 -√4 -√3 -√2 -√1 -√0
-5 -4 -3 -2 -1 0 -√5 -√4 -√3 -√2 -√1 -√0
imp_mode: int. Support two modes of backward implementation, default to 0.
imp_mode=0(CUDNN_FATTN_BALANCE_MODE): Implement dQ reduction calculation by launching a new kernel,
which will occupy additional memory. This mode usually has better performance.
imp_mode=1(CUDNN_FATTN_LEAST_MEM_MODE): Implement dQ reduction calculation within the block,
which will not occupy additional memory. There will be a slight performance loss in this mode.
This mode can be considered in memory-limited scenarios.
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_attn_probs, use_alibi, alibi_mode, imp_mode)
def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None,
causal=False, window_size=(-1, -1), alibi_slopes=None, deterministic=False,
return_attn_probs=False, use_alibi=False, alibi_mode=1, imp_mode=0):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
Arguments:
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into qkv.
max_seqlen: int. Maximum sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
use_alibi: bool. Whether to apply attention with linear bias(ALiBi). if True, causal must
be True. The calculation formula(QK^T + bias * slops_m) is shown as follows,
q1*k1 0
q2*k1 q2*k2 -1 0
q3*k1 q3*k2 q3*k3 + -2 -1 0 * slops_m
q4*k1 q4*k2 q4*k3 q4*k4 -3 -2 -1 0
q5*k1 q5*k2 q5*k3 q5*k4 q5*k5 -4 -3 -2 -1 0
alibi_mode: int. The bias mode of ALiBi, default to 1.
alibi_mode=0: alibi_mode=1:
0 -√0
-1 0 -√1 -√0
-2 -1 0 -√2 -√1 -√0
-3 -2 -1 0 -√3 -√2 -√1 -√0
-4 -3 -2 -1 0 -√4 -√3 -√2 -√1 -√0
-5 -4 -3 -2 -1 0 -√5 -√4 -√3 -√2 -√1 -√0
imp_mode: int. Support two modes of backward implementation, default to 0.
imp_mode=0(CUDNN_FATTN_BALANCE_MODE): Implement dQ reduction calculation by launching a new kernel,
which will occupy additional memory. This mode usually has better performance.
imp_mode=1(CUDNN_FATTN_LEAST_MEM_MODE): Implement dQ reduction calculation within the block,
which will not occupy additional memory. There will be a slight performance loss in this mode.
This mode can be considered in memory-limited scenarios.
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnVarlenQKVPackedFunc.apply(
qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, window_size,
alibi_slopes, deterministic, return_attn_probs, use_alibi, alibi_mode, imp_mode
)
def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1),
alibi_slopes=None, deterministic=False, return_attn_probs=False,
use_alibi=False, alibi_mode=1, imp_mode=0):
"""dropout_p should be set to 0.0 during evaluation
If K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of K, V.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
use_alibi: bool. Whether to apply attention with linear bias(ALiBi). if True, causal must
be True. The calculation formula(QK^T + bias * slops_m) is shown as follows,
q1*k1 0
q2*k1 q2*k2 -1 0
q3*k1 q3*k2 q3*k3 + -2 -1 0 * slops_m
q4*k1 q4*k2 q4*k3 q4*k4 -3 -2 -1 0
q5*k1 q5*k2 q5*k3 q5*k4 q5*k5 -4 -3 -2 -1 0
alibi_mode: int. The bias mode of ALiBi, default to 1.
alibi_mode=0: alibi_mode=1:
0 -√0
-1 0 -√1 -√0
-2 -1 0 -√2 -√1 -√0
-3 -2 -1 0 -√3 -√2 -√1 -√0
-4 -3 -2 -1 0 -√4 -√3 -√2 -√1 -√0
-5 -4 -3 -2 -1 0 -√5 -√4 -√3 -√2 -√1 -√0
imp_mode: int. Support two modes of backward implementation, default to 0.
imp_mode=0(CUDNN_FATTN_BALANCE_MODE): Implement dQ reduction calculation by launching a new kernel,
which will occupy additional memory. This mode usually has better performance.
imp_mode=1(CUDNN_FATTN_LEAST_MEM_MODE): Implement dQ reduction calculation within the block,
which will not occupy additional memory. There will be a slight performance loss in this mode.
This mode can be considered in memory-limited scenarios.
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnVarlenKVPackedFunc.apply(
q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal, window_size, alibi_slopes,
deterministic, return_attn_probs, use_alibi, alibi_mode, imp_mode
)
def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1),
alibi_slopes=None, deterministic=False, return_attn_probs=False,
use_alibi=False, alibi_mode=1, imp_mode=0):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
than Q. Note that the number of heads in K, V must be divisible by the number of heads in Q.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
use_alibi: bool. Whether to apply attention with linear bias(ALiBi). if True, causal must
be True. The calculation formula(QK^T + bias * slops_m) is shown as follows,
q1*k1 0
q2*k1 q2*k2 -1 0
q3*k1 q3*k2 q3*k3 + -2 -1 0 * slops_m
q4*k1 q4*k2 q4*k3 q4*k4 -3 -2 -1 0
q5*k1 q5*k2 q5*k3 q5*k4 q5*k5 -4 -3 -2 -1 0
alibi_mode: int. The bias mode of ALiBi, default to 1.
alibi_mode=0: alibi_mode=1:
0 -√0
-1 0 -√1 -√0
-2 -1 0 -√2 -√1 -√0
-3 -2 -1 0 -√3 -√2 -√1 -√0
-4 -3 -2 -1 0 -√4 -√3 -√2 -√1 -√0
-5 -4 -3 -2 -1 0 -√5 -√4 -√3 -√2 -√1 -√0
imp_mode: int. Support two modes of backward implementation, default to 0.
imp_mode=0(CUDNN_FATTN_BALANCE_MODE): Implement dQ reduction calculation by launching a new kernel,
which will occupy additional memory. This mode usually has better performance.
imp_mode=1(CUDNN_FATTN_LEAST_MEM_MODE): Implement dQ reduction calculation within the block,
which will not occupy additional memory. There will be a slight performance loss in this mode.
This mode can be considered in memory-limited scenarios.
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnVarlenFunc.apply(
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic,
return_attn_probs, use_alibi, alibi_mode, imp_mode
)

View File

@@ -0,0 +1,832 @@
"""
*Experimental* implementation of FlashAttention in Triton.
Tested with triton==2.0.0.dev20221202.
Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
other than 64:
https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
We'll update this implementation with the new Triton backend once this is fixed.
We use the FlashAttention implementation from Phil Tillet a starting point.
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
Changes:
- Implement both causal and non-causal attention.
- Implement both self-attention and cross-attention.
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
- Support attention bias.
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
- Make the backward for d=128 much faster by reducing register spilling.
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
small batch size * nheads.
Caution:
- This is an *experimental* implementation. The forward pass should be quite robust but
I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
- This implementation has only been tested on A100.
- If you plan to use headdim other than 64 and 128, you should test for race conditions
(due to the Triton compiler), as done in tests/test_flash_attn.py
"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
that there are none left for other head dimensions.
Differences between this Triton version and the CUDA version:
- Triton version doesn't support dropout.
- Triton forward is generally faster than CUDA forward, while Triton backward is
generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
than CUDA forward + backward.
- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
- Triton version supports attention bias, while CUDA version doesn't.
"""
import math
import torch
import triton
import triton.language as tl
# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
# # This config has a race condition when EVEN_M == False, disabling it for now.
# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
# ],
# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
# )
@triton.heuristics(
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@triton.jit
def _fwd_kernel(
Q, K, V, Bias, Out,
Lse, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
softmax_scale,
stride_qb, stride_qh, stride_qm,
stride_kb, stride_kh, stride_kn,
stride_vb, stride_vh, stride_vn,
stride_bb, stride_bh, stride_bm,
stride_ob, stride_oh, stride_om,
nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
BIAS_TYPE: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hb = tl.program_id(1)
off_b = off_hb // nheads
off_h = off_hb % nheads
# off_b = tl.program_id(1)
# off_h = tl.program_id(2)
# off_hb = off_b * nheads + off_h
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# Initialize pointers to Q, K, V
# Adding parenthesis around indexing might use int32 math instead of int64 math?
# https://github.com/openai/triton/issues/741
# I'm seeing a tiny bit of difference (5-7us)
q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
if BIAS_TYPE == 'vector':
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
elif BIAS_TYPE == 'matrix':
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
# initialize pointer to m and l
t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
# load q: it will stay in SRAM throughout
# [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
# tl.load(q_ptrs), we get the wrong output!
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(q_ptrs)
else:
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
else:
q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
other=0.0)
# loop over k, v and update accumulator
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
for start_n in range(0, end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn)
else:
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0)
else:
k = tl.load(k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
# Trying to combine the two masks seem to make the result wrong
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
if IS_CAUSAL:
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
if BIAS_TYPE != 'none':
if BIAS_TYPE == 'vector':
if EVEN_N:
bias = tl.load(b_ptrs + start_n).to(tl.float32)
else:
bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32)
bias = bias[None, :]
elif BIAS_TYPE == 'matrix':
if EVEN_M & EVEN_N:
bias = tl.load(b_ptrs + start_n).to(tl.float32)
else:
bias = tl.load(b_ptrs + start_n,
mask=(offs_m[:, None] < seqlen_q)
& ((start_n + offs_n)[None, :] < seqlen_k),
other=0.0).to(tl.float32)
# Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
# can then fuse the mult and add into an fma instruction. But if we have bias we need to
# to multiply with softmax_scale here.
qk = qk * softmax_scale + bias
m_ij = tl.maximum(tl.max(qk, 1), lse_i)
p = tl.exp(qk - m_ij[:, None])
else:
m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
p = tl.exp(qk * softmax_scale - m_ij[:, None])
l_ij = tl.sum(p, 1)
# scale acc_o
acc_o_scale = tl.exp(m_i - m_ij)
# # -- update output accumulator --
# BUG: have to store and immediately load
tl.store(t_ptrs, acc_o_scale)
acc_o_scale = tl.load(t_ptrs)
acc_o = acc_o * acc_o_scale[:, None]
# update acc_o
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
if EVEN_HEADDIM:
v = tl.load(v_ptrs + start_n * stride_vn)
else:
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0)
else:
v = tl.load(v_ptrs + start_n * stride_vn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0)
p = p.to(v.dtype)
acc_o += tl.dot(p, v)
# -- update statistics
m_i = m_ij
l_i_new = tl.exp(lse_i - m_ij) + l_ij
lse_i = m_ij + tl.log(l_i_new)
o_scale = tl.exp(m_i - lse_i)
# BUG: have to store and immediately load
tl.store(t_ptrs, o_scale)
o_scale = tl.load(t_ptrs)
acc_o = acc_o * o_scale[:, None]
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# write back l and m
lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
tl.store(lse_ptrs, lse_i)
# initialize pointers to output
offs_d = tl.arange(0, BLOCK_HEADDIM)
out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
if EVEN_M:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o)
else:
tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
else:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
else:
tl.store(out_ptrs, acc_o,
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
@triton.jit
def _bwd_preprocess_do_o_dot(
Out, DO, Delta,
stride_ob, stride_oh, stride_om,
stride_dob, stride_doh, stride_dom,
nheads, seqlen_q, seqlen_q_rounded, headdim,
BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
):
start_m = tl.program_id(0)
off_hb = tl.program_id(1)
off_b = off_hb // nheads
off_h = off_hb % nheads
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# load
o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
@triton.jit
def _bwd_store_dk_dv(
dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
):
# [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
# if we just call tl.store(dv_ptrs), there's a race condition
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
else:
tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
else:
if EVEN_HEADDIM:
tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
else:
tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
@triton.jit
def _bwd_kernel_one_col_block(
start_n,
Q, K, V, Bias,
DO, DQ, DK, DV,
LSE, D,
softmax_scale,
stride_qm, stride_kn, stride_vn, stride_bm,
stride_dom, stride_dqm, stride_dkn, stride_dvn,
seqlen_q, seqlen_k, headdim,
ATOMIC_ADD: tl.constexpr,
BIAS_TYPE: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
# We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
# initialize row/col offsets
offs_qm = begin_m + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
if BIAS_TYPE == 'vector':
b_ptrs = Bias + offs_n
elif BIAS_TYPE == 'matrix':
b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
# initialize dv and dk
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
# There seems to be some problem with Triton pipelining that makes results wrong for
# headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
# may have zero step, and pipelining with the bias matrix could screw it up.
# So we just exit early.
if begin_m >= seqlen_q:
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
_bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
return
# k and v stay in SRAM throughout
# [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
# if we just call tl.load(k_ptrs), we get the wrong output!
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
else:
k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
else:
k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0)
v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0)
# loop over rows
num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
start_m = tl.multiple_of(start_m, BLOCK_M)
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
# Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
if EVEN_M & EVEN_HEADDIM:
q = tl.load(q_ptrs)
else:
if EVEN_HEADDIM:
q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
else:
q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
& (offs_d[None, :] < headdim), other=0.0)
# recompute p = softmax(qk, dim=-1).T
qk = tl.dot(q, k, trans_b=True)
# Trying to combine the two masks seem to make the result wrong
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
if IS_CAUSAL:
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
if BIAS_TYPE != 'none':
tl.debug_barrier() # Race condition otherwise
if BIAS_TYPE == 'vector':
if EVEN_N:
bias = tl.load(b_ptrs).to(tl.float32)
else:
bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
bias = bias[None, :]
elif BIAS_TYPE == 'matrix':
if EVEN_M & EVEN_N:
bias = tl.load(b_ptrs).to(tl.float32)
else:
bias = tl.load(b_ptrs,
mask=(offs_m_curr[:, None] < seqlen_q)
& (offs_n[None, :] < seqlen_k),
other=0.0).to(tl.float32)
qk = qk * softmax_scale + bias
# There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
# Also wrong for headdim=64.
if not (EVEN_M & EVEN_HEADDIM):
tl.debug_barrier()
lse_i = tl.load(LSE + offs_m_curr)
if BIAS_TYPE == 'none':
p = tl.exp(qk * softmax_scale - lse_i[:, None])
else:
p = tl.exp(qk - lse_i[:, None])
# compute dv
# [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
# do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
# in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
# the output is correct.
if EVEN_M & EVEN_HEADDIM:
do = tl.load(do_ptrs)
else:
# [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
& (offs_d[None, :] < headdim), other=0.0)
# if EVEN_M:
# if EVEN_HEADDIM:
# do = tl.load(do_ptrs)
# else:
# do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
# else:
# if EVEN_HEADDIM:
# do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
# else:
# do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
# & (offs_d[None, :] < headdim), other=0.0)
dv += tl.dot(p.to(do.dtype), do, trans_a=True)
# compute dp = dot(v, do)
# There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
# Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
# Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
if not (EVEN_M & EVEN_HEADDIM):
tl.debug_barrier()
dp = tl.dot(do, v, trans_b=True)
# There's a race condition for headdim=48
if not EVEN_HEADDIM:
tl.debug_barrier()
# compute ds = p * (dp - delta[:, None])
# Putting the subtraction after the dp matmul (instead of before) is slightly faster
Di = tl.load(D + offs_m_curr)
# Converting ds to q.dtype here reduces register pressure and makes it much faster
# for BLOCK_HEADDIM=128
ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
# compute dk = dot(ds.T, q)
dk += tl.dot(ds, q, trans_a=True)
# compute dq
if not (EVEN_M & EVEN_HEADDIM): # Otherewise there's a race condition when BIAS_TYPE='matrix'
tl.debug_barrier()
if not ATOMIC_ADD:
if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
dq += tl.dot(ds, k)
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
else:
if EVEN_HEADDIM:
dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
eviction_policy="evict_last")
dq += tl.dot(ds, k)
tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q,
eviction_policy="evict_last")
else:
dq = tl.load(dq_ptrs,
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
other=0.0, eviction_policy="evict_last")
dq += tl.dot(ds, k)
tl.store(dq_ptrs, dq,
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
eviction_policy="evict_last")
else: # If we're parallelizing across the seqlen_k dimension
dq = tl.dot(ds, k)
if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
tl.atomic_add(dq_ptrs, dq)
else:
if EVEN_HEADDIM:
tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
else:
tl.atomic_add(dq_ptrs, dq,
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
# increment pointers
dq_ptrs += BLOCK_M * stride_dqm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_dom
if BIAS_TYPE == 'matrix':
b_ptrs += BLOCK_M * stride_bm
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
_bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
# Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
# # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
],
key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
)
@triton.heuristics(
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@triton.jit
def _bwd_kernel(
Q, K, V, Bias,
DO, DQ, DK, DV,
LSE, D,
softmax_scale,
stride_qb, stride_qh, stride_qm,
stride_kb, stride_kh, stride_kn,
stride_vb, stride_vh, stride_vn,
stride_bb, stride_bh, stride_bm,
stride_dob, stride_doh, stride_dom,
stride_dqb, stride_dqh, stride_dqm,
stride_dkb, stride_dkh, stride_dkn,
stride_dvb, stride_dvh, stride_dvn,
nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
BIAS_TYPE: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
off_hb = tl.program_id(1)
off_b = off_hb // nheads
off_h = off_hb % nheads
# offset pointers for batch/head
Q += off_b * stride_qb + off_h * stride_qh
K += off_b * stride_kb + off_h * stride_kh
V += off_b * stride_vb + off_h * stride_vh
DO += off_b * stride_dob + off_h * stride_doh
DQ += off_b * stride_dqb + off_h * stride_dqh
DK += off_b * stride_dkb + off_h * stride_dkh
DV += off_b * stride_dvb + off_h * stride_dvh
if BIAS_TYPE != 'none':
Bias += off_b * stride_bb + off_h * stride_bh
# pointer to row-wise quantities in value-like data
D += off_hb * seqlen_q_rounded
LSE += off_hb * seqlen_q_rounded
if not SEQUENCE_PARALLEL:
num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
for start_n in range(0, num_block_n):
_bwd_kernel_one_col_block(
start_n,
Q, K, V, Bias,
DO, DQ, DK, DV,
LSE, D,
softmax_scale,
stride_qm, stride_kn, stride_vn, stride_bm,
stride_dom, stride_dqm, stride_dkn, stride_dvn,
seqlen_q, seqlen_k, headdim,
ATOMIC_ADD=False,
BIAS_TYPE=BIAS_TYPE,
IS_CAUSAL=IS_CAUSAL,
BLOCK_HEADDIM=BLOCK_HEADDIM,
EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
)
else:
start_n = tl.program_id(0)
_bwd_kernel_one_col_block(
start_n,
Q, K, V, Bias,
DO, DQ, DK, DV,
LSE, D,
softmax_scale,
stride_qm, stride_kn, stride_vn, stride_bm,
stride_dom, stride_dqm, stride_dkn, stride_dvn,
seqlen_q, seqlen_k, headdim,
ATOMIC_ADD=True,
BIAS_TYPE=BIAS_TYPE,
IS_CAUSAL=IS_CAUSAL,
BLOCK_HEADDIM=BLOCK_HEADDIM,
EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
)
def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
# shape constraints
batch, seqlen_q, nheads, d = q.shape
_, seqlen_k, _, _ = k.shape
assert k.shape == (batch, seqlen_k, nheads, d)
assert v.shape == (batch, seqlen_k, nheads, d)
assert d <= 128, 'FlashAttention only support head dimensions up to 128'
assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
assert q.is_cuda and k.is_cuda and v.is_cuda
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
has_bias = bias is not None
bias_type = 'none'
if has_bias:
assert bias.dtype in [q.dtype, torch.float]
assert bias.is_cuda
assert bias.dim() == 4
if bias.stride(-1) != 1:
bias = bias.contiguous()
if bias.shape[2:] == (1, seqlen_k):
bias_type = 'vector'
elif bias.shape[2:] == (seqlen_q, seqlen_k):
bias_type = 'matrix'
else:
raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
' or (seqlen_q, seqlen_k)')
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
o = torch.empty_like(q)
BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
BLOCK = 128
num_warps = 4 if d <= 64 else 8
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
_fwd_kernel[grid](
q, k, v, bias, o,
lse, tmp,
softmax_scale,
q.stride(0), q.stride(2), q.stride(1),
k.stride(0), k.stride(2), k.stride(1),
v.stride(0), v.stride(2), v.stride(1),
*bias_strides,
o.stride(0), o.stride(2), o.stride(1),
nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations)
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
bias_type, causal, BLOCK_HEADDIM,
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return o, lse, softmax_scale # softmax_scale could have been updated
def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
# Make sure that the last dimension is contiguous
if do.stride(-1) != 1:
do = do.contiguous()
batch, seqlen_q, nheads, d = q.shape
_, seqlen_k, _, _ = k.shape
# assert d in {16, 32, 64, 128}
assert d <= 128
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
assert lse.shape == (batch, nheads, seqlen_q_rounded)
assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
# dq_accum = torch.zeros_like(q, dtype=torch.float32)
dq_accum = torch.empty_like(q, dtype=torch.float32)
delta = torch.empty_like(lse)
# delta = torch.zeros_like(lse)
BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
_bwd_preprocess_do_o_dot[grid](
o, do, delta,
o.stride(0), o.stride(2), o.stride(1),
do.stride(0), do.stride(2), do.stride(1),
nheads, seqlen_q, seqlen_q_rounded, d,
BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
)
has_bias = bias is not None
bias_type = 'none'
if has_bias:
assert bias.dtype in [q.dtype, torch.float]
assert bias.is_cuda
assert bias.dim() == 4
assert bias.stride(-1) == 1
if bias.shape[2:] == (1, seqlen_k):
bias_type = 'vector'
elif bias.shape[2:] == (seqlen_q, seqlen_k):
bias_type = 'matrix'
else:
raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
' or (seqlen_q, seqlen_k)')
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
# BLOCK_M = 128
# BLOCK_N = 64
# num_warps = 4
grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
batch * nheads)
_bwd_kernel[grid](
q, k, v, bias,
do, dq_accum, dk, dv,
lse, delta,
softmax_scale,
q.stride(0), q.stride(2), q.stride(1),
k.stride(0), k.stride(2), k.stride(1),
v.stride(0), v.stride(2), v.stride(1),
*bias_strides,
do.stride(0), do.stride(2), do.stride(1),
dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),
dk.stride(0), dk.stride(2), dk.stride(1),
dv.stride(0), dv.stride(2), dv.stride(1),
nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations)
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
bias_type, causal, BLOCK_HEADDIM,
# SEQUENCE_PARALLEL=False,
# BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
# num_warps=num_warps,
# num_stages=1,
)
dq.copy_(dq_accum)
class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
"""
qkv: (batch, seqlen, 3, nheads, headdim)
bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
"""
# Make sure that the last dimension is contiguous
if qkv.stride(-1) != 1:
qkv = qkv.contiguous()
o, lse, ctx.softmax_scale = _flash_attn_forward(
qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal,
softmax_scale=softmax_scale
)
ctx.save_for_backward(qkv, o, lse, bias)
ctx.causal = causal
return o
@staticmethod
def backward(ctx, do):
qkv, o, lse, bias = ctx.saved_tensors
assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet'
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with torch.inference_mode():
dqkv = torch.empty_like(qkv)
_flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
return dqkv, None, None, None
flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
class FlashAttnKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
"""
q: (batch, seqlen_q, nheads, headdim)
kv: (batch, seqlen_k, 2, nheads, headdim)
bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
"""
# Make sure that the last dimension is contiguous
q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
o, lse, ctx.softmax_scale = _flash_attn_forward(
q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale
)
ctx.save_for_backward(q, kv, o, lse, bias)
ctx.causal = causal
return o
@staticmethod
def backward(ctx, do):
q, kv, o, lse, bias = ctx.saved_tensors
if len(ctx.needs_input_grad) >= 3:
assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet'
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with torch.inference_mode():
dq = torch.empty_like(q)
dkv = torch.empty_like(kv)
_flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse,
dq, dkv[:, :, 0], dkv[:, :, 1],
bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
return dq, dkv, None, None, None
flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
"""
q: (batch_size, seqlen_q, nheads, headdim)
k, v: (batch_size, seqlen_k, nheads, headdim)
bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
"""
# Make sure that the last dimension is contiguous
q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
o, lse, ctx.softmax_scale = _flash_attn_forward(
q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
)
ctx.save_for_backward(q, k, v, o, lse, bias)
ctx.causal = causal
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse, bias = ctx.saved_tensors
assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet'
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with torch.inference_mode():
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
_flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
return dq, dk, dv, None, None, None
flash_attn_func = FlashAttnFunc.apply

View File

@@ -0,0 +1,276 @@
# [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
# for benchmarking.
# We fixed a few dtype cast to make it work for bf16
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
"""
import pytest
import torch
import triton
import triton.language as tl
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
t_ptrs = TMP + off_hz * N_CTX + offs_m
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs)
# loop over k, v and update accumulator
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kn)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
qk *= sm_scale
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
tl.store(t_ptrs, acc_scale)
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + start_n * stride_vk)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# write back l and m
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(l_ptrs, l_i)
tl.store(m_ptrs, m_i)
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_DMODEL)
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
@triton.jit
def _bwd_preprocess(
Out, DO, L,
NewDO, Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
# load
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
denom = tl.load(L + off_m).to(tl.float32)
# compute
do = do / denom[:, None]
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
tl.store(Delta + off_m, delta)
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
L, M,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
# offset pointers for batch/head
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_qz + off_h * stride_qh
V += off_z * stride_qz + off_h * stride_qh
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
DK += off_z * stride_qz + off_h * stride_qh
DV += off_z * stride_qz + off_h * stride_qh
for start_n in range(0, num_block):
lo = start_n * BLOCK_M
# initialize row/col offsets
offs_qm = lo + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
offs_m = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_DMODEL)
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
m_ptrs = M + off_hz * N_CTX
# initialize dv amd dk
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# k and v stay in SRAM throughout
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
# loop over rows
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, k, trans_b=True)
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(p.to(do.dtype), do, trans_a=True)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, v, trans_b=True)
# compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
dk += tl.dot(ds.to(q.dtype), q, trans_a=True)
# # compute dq
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
dq += tl.dot(ds.to(k.dtype), k)
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
# # increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale):
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
q, k, v, sm_scale,
tmp, L, m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk, num_warps=num_warps,
num_stages=1,
)
ctx.save_for_backward(q, k, v, o, L, m)
ctx.BLOCK = BLOCK
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, l, m = ctx.saved_tensors
do = do.contiguous()
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, l,
do_scaled, delta,
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
# NOTE: kernel currently buggy for other values of `num_warps`
num_warps = 8
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
ctx.grid[0],
BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
num_stages=1,
)
return dq.to(q.dtype), dk, dv, None
attention = _attention.apply

View File

@@ -0,0 +1,136 @@
import math
import torch
import torch.nn as nn
from einops import rearrange
import hydra
from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
from flash_attn.flash_blocksparse_attn_interface import convert_blockmask
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
class FlashBlocksparseAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_temp: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.1)
"""
def __init__(self, sparsity_config, softmax_temp=None, attention_dropout=0.0,
max_seq_length=2048, device=None, dtype=None):
super().__init__()
self.sparsity_config = hydra.utils.instantiate(sparsity_config)
self.softmax_temp = softmax_temp
self.dropout_p = attention_dropout
# initialize sparse layout and register as buffer
max_seq_length = ((max_seq_length + 256 - 1) // 256) * 256
layout = self.sparsity_config.make_layout(max_seq_length)
self.register_buffer("layout", layout)
blockmask_converted = convert_blockmask(self.layout, causal=False)
self.register_buffer("blockmask_converted", blockmask_converted)
# logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}')
def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
max_s=None, need_weights=False, convert_mask=True):
"""Implements the multihead softmax attention.
Arguments
---------
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
attn_mask: An implementation of BaseMask that encodes where each
query can attend to
key_padding_mask: An implementation of BaseMask that encodes how
many query each sequence in the batch consists of
"""
assert not need_weights
assert attn_mask is None
assert qkv.dtype == torch.float16
assert qkv.is_cuda
if cu_seqlens is None:
batch_size = qkv.shape[0]
seqlen = qkv.shape[1]
# Convert mask to take a subset
seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1]
blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
if key_padding_mask is None:
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
max_s = seqlen
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=qkv.device)
output = flash_blocksparse_attn_func(
qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
else:
key_padding_mask_bool = key_padding_mask.bool_matrix
nheads = qkv.shape[-2]
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
output_unpad = flash_blocksparse_attn_func(
x_unpad, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal
)
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
indices, batch_size, seqlen),
'b s (h d) -> b s h d', h=nheads)
else:
assert max_s is not None
seqlen = max_s
# Convert mask to take a subset
seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1]
blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
if convert_mask:
output = flash_blocksparse_attn_func(
qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal
)
else:
output = flash_blocksparse_attn_func(
qkv, cu_seqlens, self.blockmask_converted, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal,
convert_mask=False,
)
return output, None
class FlashBlocksparseMHA(nn.Module):
def __init__(self, embed_dim, num_heads, sparsity_config, bias=True, batch_first=True,
attention_dropout=0.0, causal=False, max_seq_length=2048,
device=None, dtype=None, **kwargs) -> None:
assert batch_first
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.embed_dim = embed_dim
self.causal = causal
self.num_heads = num_heads
assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
self.head_dim = self.embed_dim // num_heads
assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"
self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
self.inner_attn = FlashBlocksparseAttention(
sparsity_config, attention_dropout=attention_dropout,
max_seq_length=max_seq_length, **factory_kwargs
)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
need_weights=False):
qkv = self.Wqkv(x)
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
need_weights=need_weights, causal=self.causal)
return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights

View File

@@ -0,0 +1,142 @@
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
import torch
import torch.nn as nn
import flash_attn_cuda
def convert_blockmask(blockmask, causal):
"""Convert from the 0-1 format to the format used by the CUDA code.
0 means the block is skipped.
nonzero means the block is not skipped.
Argument:
blockmask: (row, col): a 0-1 tensor
Return:
blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row
indices of the nonzero blocks, padded with -1 to reach length @row.
The indices are multiplied by 4, with the smallest bit used to encode whether
it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is
the last nonzero in its row..
"""
assert not causal
# TD [2022-05-13]: The indexing and sorting is very tricky
nrow, ncol = blockmask.shape
# Sort does not support bool on CUDA
blockmask = blockmask.to(dtype=torch.uint8)
nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True)
nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0)
last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1]
last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row
]
first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0]
first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row
]
nonzero_idx = nonzero_sorted_rowidx * 4
nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2
nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1
nonzero_idx[nonzero_val == 0] = -1
return nonzero_idx.T.contiguous().to(dtype=torch.int32)
def _flash_blocksparse_attn_forward(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale,
causal, return_softmax):
context, softmax_lse, *rest = flash_attn_cuda.fwd_block(qkv, cu_seqlens, blockmask, dropout_p,
max_s, softmax_scale, causal,
return_softmax, None)
# if context.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
S_dmask = rest[0] if return_softmax else None
return context, softmax_lse, S_dmask
def _flash_blocksparse_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask,
dropout_p, max_s, softmax_scale, causal):
dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens,
blockmask, dropout_p, softmax_scale, max_s,
causal, None)
# if dqkv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
return dqkv
class FlashBlocksparseAttnFun(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
return_softmax=False
)
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
ctx.dropout_p = dropout_p
ctx.max_s = max_s
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return context
@staticmethod
def backward(ctx, dout):
qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
# S_dmask is None, temporarily use another tensor just to get it running
dqkv = _flash_blocksparse_attn_backward(
dout, qkv, context, context, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
ctx.max_s, ctx.softmax_scale, ctx.causal
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None, None
# We duplicate code to return both the output and the softmax for testing
# Returning both makes backward a bit slower, so we want to keep using the other version for speed.
class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
# Save rng_state because the backward pass is gonna regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
return_softmax=True
)
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
ctx.dropout_p = dropout_p
ctx.max_s = max_s
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return context, S_dmask, softmax_lse
@staticmethod
def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):
qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dqkv = _flash_blocksparse_attn_backward(
dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
ctx.max_s, ctx.softmax_scale, ctx.causal
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None
def flash_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale=None,
causal=False, return_attn_probs=False, convert_mask=True):
"""dropout_p should be set to 0.0 during evaluation
"""
func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS
if convert_mask:
blockmask = convert_blockmask(blockmask, causal=causal)
return func.apply(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal)

View File

@@ -0,0 +1,205 @@
# [2022-10-23] Copied from https://github.com/NVIDIA/apex/blob/master/apex/transformer/functional/fused_softmax.py
# for benchmarking.
# We added support for seqlen=2k and seqlen=4k
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex._autocast_utils import _cast_if_autocast_enabled
from apex.transformer.enums import AttnMaskType
from fused_softmax_lib import scaled_masked_softmax_forward, scaled_masked_softmax_backward
from fused_softmax_lib import scaled_masked_softmax_get_batch_per_block
from fused_softmax_lib import scaled_upper_triang_masked_softmax_forward, scaled_upper_triang_masked_softmax_backward
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_forward(
inputs, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax_backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None
def scaled_upper_triang_masked_softmax(inputs, _, scale):
b, np, sq, sk = inputs.size()
assert sq == sk, "causal mask is only for self attention"
# Reshaping input to 3D tensor (attn_batches, sq, sk)
inputs = inputs.view(-1, sq, sk)
args = _cast_if_autocast_enabled(inputs, scale)
with torch.cuda.amp.autocast(enabled=False):
probs = ScaledUpperTriangMaskedSoftmax.apply(*args)
return probs.view(b, np, sq, sk)
# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`.
# Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context.
# So I needed to manually write two `torch.autograd.Function` inheritances.
# Fused operation which performs following three operations in sequence
# 1. Scale the tensor.
# 2. Apply the mask.
# 3. Perform softmax.
class ScaledMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, mask, scale):
scale_t = torch.tensor([scale])
softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax_backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None, None
def scaled_masked_softmax(inputs, mask, scale):
# input is 4D tensor (b, np, sq, sk)
args = _cast_if_autocast_enabled(inputs, mask, scale)
with torch.cuda.amp.autocast(enabled=False):
return ScaledMaskedSoftmax.apply(*args)
class FusedScaleMaskSoftmax(torch.nn.Module):
"""
fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: flag to indicate if input in fp16 data format.
input_in_bf16: flag to indicate if input in bf16 data format.
attn_mask_type: attention mask type (pad or causal)
scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax in performed at fp32 precision.
scale: scaling factor used in input tensor scaling.
"""
def __init__(
self,
input_in_fp16,
input_in_bf16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super().__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
if self.input_in_fp16 and self.input_in_bf16:
raise RuntimeError(
"both fp16 and bf16 flags cannot be active at the same time."
)
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
if not (self.scale is None or softmax_in_fp32):
raise RuntimeError("softmax should be in fp32 when scaled")
if self.scaled_masked_softmax_fusion:
if self.attn_mask_type == AttnMaskType.causal:
self.fused_softmax_func = scaled_upper_triang_masked_softmax
elif self.attn_mask_type == AttnMaskType.padding:
self.fused_softmax_func = scaled_masked_softmax
else:
raise ValueError("Invalid attn_mask_type.")
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
else:
return self.forward_torch_softmax(input, mask)
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (
self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16
and (
self.attn_mask_type == AttnMaskType.causal
or (self.attn_mask_type == AttnMaskType.padding and mask is not None)
)
and 16 < sk <= 8192 # sk must be 16 ~ 8192
and sq % 4 == 0 # sq must be divisor of 4
and sk % 4 == 0 # sk must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
):
if 0 <= sk <= 8192:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type == AttnMaskType.causal:
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
def forward_fused_softmax(self, input, mask):
# input.shape = [b, np, sq, sk]
scale = self.scale if self.scale is not None else 1.0
return self.fused_softmax_func(input, mask, scale)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
@staticmethod
def get_batch_per_block(sq, sk, b, np):
return scaled_masked_softmax_get_batch_per_block(sq, sk, b, np)

View File

@@ -0,0 +1,56 @@
# We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py
# But we use nn.Linear instead of Conv2d and it's about 8x faster.
from functools import partial
import torch.nn as nn
from torch import _assert
from torch.nn.modules.utils import _pair
from einops import rearrange
try:
from flash_attn.ops.fused_dense import FusedDense
except ImportError:
FusedDense = None
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
fused_bias_fc=False,
):
super().__init__()
img_size = _pair(img_size)
patch_size = _pair(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
_, _, H, W = x.shape
_assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
x = self.proj(rearrange(x, 'b c (h p1) (w p2) -> b h w (c p1 p2)',
p1=self.patch_size[0], p2=self.patch_size[1]))
if self.flatten:
x = rearrange(x, 'b h w c -> b (h w) c')
x = self.norm(x)
return x

View File

@@ -0,0 +1,336 @@
# Copyright (c) 2023, Tri Dao.
from typing import Tuple, Optional
import math
import torch
from einops import rearrange, repeat
import rotary_emb
def rotate_half(x, interleaved=False):
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(cos, 's d -> s 1 (2 d)')
sin = repeat(sin, 's d -> s 1 (2 d)')
return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:]], dim=-1)
class ApplyRotaryEmb(torch.autograd.Function):
@staticmethod
def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
batch, seqlen, nheads, headdim = x.shape
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
x_ro = x[..., :rotary_dim]
x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
out = torch.empty_like(x) if not inplace else x
out_ro = out[..., :rotary_dim]
if inplace:
o1, o2 = x1, x2
else:
o1, o2 = (out_ro.chunk(2, dim=-1) if not interleaved
else (out_ro[..., ::2], out_ro[..., 1::2]))
rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
if not inplace and rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin)
ctx.interleaved = interleaved
ctx.inplace = inplace
return out if not inplace else x
@staticmethod
def backward(ctx, do):
cos, sin = ctx.saved_tensors
_, seqlen, _, headdim = do.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
inplace = ctx.inplace
do_ro = do[..., :rotary_dim]
do1, do2 = (do_ro.chunk(2, dim=-1) if not ctx.interleaved
else (do_ro[..., ::2], do_ro[..., 1::2]))
dx = torch.empty_like(do) if not inplace else do
if inplace:
dx1, dx2 = do1, do2
else:
dx_ro = dx[..., :rotary_dim]
dx1, dx2 = (dx_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dx_ro[..., ::2], dx_ro[..., 1::2]))
rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True)
if not inplace and rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None, None
apply_rotary_emb_func = ApplyRotaryEmb.apply
class ApplyRotaryEmbQKV_(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
"""
qkv: (batch_size, seqlen, 3, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
cos_k, sin_k: (seqlen, rotary_dim / 2), optional
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
Apply rotary embedding *inplace* to the first rotary_dim of q and k.
"""
batch, seqlen, three, nheads, headdim = qkv.shape
assert three == 3
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
q_ro = qkv[:, :, 0, :, :rotary_dim]
q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
k_ro = qkv[:, :, 1, :, :rotary_dim]
k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
ctx.save_for_backward(cos, sin, cos_k, sin_k)
ctx.interleaved = interleaved
return qkv
@staticmethod
def backward(ctx, dqkv):
cos, sin, cos_k, sin_k = ctx.saved_tensors
_, seqlen, _, _, headdim = dqkv.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
dq_ro = dqkv[:, :, 0, :, :rotary_dim]
dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dq_ro[..., ::2], dq_ro[..., 1::2]))
rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
dk_ro = dqkv[:, :, 1, :, :rotary_dim]
dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dk_ro[..., ::2], dk_ro[..., 1::2]))
rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
return dqkv, None, None, None, None, None
apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
class ApplyRotaryEmbKV_(torch.autograd.Function):
@staticmethod
def forward(ctx, kv, cos, sin, interleaved=False):
"""
kv: (batch_size, seqlen, 2, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
Apply rotary embedding *inplace* to the first rotary_dim of k.
"""
batch, seqlen, two, nheads, headdim = kv.shape
assert two == 2
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
k_ro = kv[:, :, 0, :, :rotary_dim]
k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2,
False) # conj=False since this is the forward pass
ctx.save_for_backward(cos, sin)
ctx.interleaved = interleaved
return kv
@staticmethod
def backward(ctx, dkv):
cos, sin = ctx.saved_tensors
_, seqlen, _, _, headdim = dkv.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
dk_ro = dkv[:, :, 0, :, :rotary_dim]
dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dk_ro[..., ::2], dk_ro[..., 1::2]))
rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2,
True) # conj=True since this is the backward pass
return dkv, None, None, None
apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
class RotaryEmbedding(torch.nn.Module):
"""
The rotary position embeddings from RoFormer_ (Su et. al).
A crucial insight from the method is that the query and keys are
transformed by rotation matrices which depend on the relative positions.
Other implementations are available in the Rotary Transformer repo_ and in
GPT-NeoX_, GPT-NeoX was an inspiration
.. _RoFormer: https://arxiv.org/abs/2104.09864
.. _repo: https://github.com/ZhuiyiTechnology/roformer
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
"""
def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None,
pos_idx_in_fp32=True, device=None):
"""
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
otherwise they might be in lower precision.
This option was added because previously (before 2023-07-02), when we construct
the position indices, we use the dtype of self.inv_freq. In most cases this would
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
self.inv_freq would be bf16, and the position indices are also in bf16.
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
embeddings for some positions will coincide.
To maintain compatibility with models previously trained in pure bf16,
we add this option.
"""
super().__init__()
self.dim = dim
self.base = float(base)
self.pos_idx_in_fp32 = pos_idx_in_fp32
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = self._compute_inv_freq(device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.interleaved = interleaved
self.scale_base = scale_base
scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
/ (1.4 * dim) if scale_base is not None else None)
self.register_buffer("scale", scale, persistent=False)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
def _compute_inv_freq(self, device=None):
return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device,
dtype=torch.float32) / self.dim))
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
# Reset the tables if the sequence length has changed,
# if we're on a new device (possibly due to tracing for instance),
# or if we're switching from inference mode to training
if (seqlen > self._seq_len_cached or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
or (self.training and self._cos_cached.is_inference())):
self._seq_len_cached = seqlen
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
if self.pos_idx_in_fp32:
t = torch.arange(seqlen, device=device, dtype=torch.float32)
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
# will be large. Having it in bf16 will lose a lot of precision and cause the
# cos & sin output to change significantly.
# We want to recompute self.inv_freq if it was not loaded in fp32
if self.inv_freq.dtype != torch.float32:
inv_freq = self._compute_inv_freq(device=device)
else:
inv_freq = self.inv_freq
else:
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
inv_freq = self.inv_freq
# Don't do einsum, it converts fp32 to fp16 under AMP
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, inv_freq)
if self.scale is None:
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
else:
power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
- seqlen // 2) / self.scale_base)
scale = self.scale.to(device=power.device) ** rearrange(power, 's -> s 1')
# We want the multiplication by scale to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
def forward(self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None,
seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
"""
qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
else it's just q of shape (batch, seqlen, nheads, headdim)
kv: (batch, seqlen, 2, nheads, headdim)
seqlen_offset: can be used in generation where the qkv being passed in is only the last
token in the batch.
"""
seqlen = qkv.shape[1]
self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
if kv is None:
if self.scale is None:
return apply_rotary_emb_qkv_(
qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
None, None, self.interleaved
)
else:
return apply_rotary_emb_qkv_(
qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
self.interleaved
)
else:
q = qkv
q = apply_rotary_emb_func(
q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
self.interleaved, True
)
if self.scale is None:
kv = apply_rotary_emb_kv_(
kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
self.interleaved
)
else:
kv = apply_rotary_emb_kv_(
kv, self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
self.interleaved
)
return q, kv

View File

@@ -0,0 +1,129 @@
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
import torch
import torch.nn as nn
import xentropy_cuda_lib
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 2 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False,
process_group=None):
"""
logits: (batch, vocab_size)
labels: (batch,)
If process_group is not None, we're doing Tensor Parallel: each process is responsible for
one part of the vocab. The loss needs to be aggregated across processes.
"""
batch, vocab_size = logits.shape
assert labels.shape == (batch,)
world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
ctx.total_classes = world_size * vocab_size
if world_size == 1:
losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
losses.masked_fill_(labels==ignored_index, 0)
labels_local = labels
else:
rank = torch.distributed.get_rank(process_group)
vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
# Create a mask of valid vocab ids (1 means it needs to be masked).
labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
ignored_mask = labels == ignored_index
labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
# For tensor parallel cross entropy with smoothing, we want to pass in the total number
# of classes so that smoothing can be applied correctly. If total_classes=-1, use the
# last dimension of the input tensor.
losses, lse_local = xentropy_cuda_lib.forward(logits, labels_local, smoothing,
world_size * vocab_size)
assert lse_local.shape == (batch,)
assert losses.shape == (batch,)
losses.masked_fill_(ignored_mask, 0)
# For labels == ignored_index, the loss is always 0.
# If there's no smoothing, if labels are in the vocab of this partition, losses contains
# lse_local - predicted logit, and 0 otherwise.
# If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
# 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
# For labels not in the vocab of this partition, losses contains
# 0.1 * (lse_local - sum logit / total_classes).
lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
device=lse_local.device)
torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
group=process_group)
handle_losses = torch.distributed.all_reduce(
losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
)
lse = torch.logsumexp(lse_allgather, dim=0)
# If there's no smoothing, the total losses are lse_local - predicted_logit,
# we just have to subtract the lse_local and add the lse (global).
# If there's smoothing=0.1, the total losses are
# 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
# We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
rank_per_sample = torch.div(labels, vocab_size, rounding_mode='floor')
lse_local = lse_allgather[rank_per_sample,
torch.arange(batch, device=lse_allgather.device)]
handle_losses.wait()
if smoothing == 0.0:
losses += lse - lse_local
else:
losses += ((1 - smoothing) * (lse - lse_local)
+ smoothing * (lse - lse_allgather.sum(dim=0)))
losses.masked_fill_(ignored_mask, 0)
ctx.save_for_backward(logits, lse, labels_local)
ctx.smoothing = smoothing
ctx.ignored_index = ignored_index
ctx.inplace_backward = inplace_backward
return losses
@staticmethod
def backward(ctx, grad_loss):
logits, lse, labels = ctx.saved_tensors
grad_loss = grad_loss.contiguous()
grad_loss.masked_fill_(labels==ctx.ignored_index, 0)
grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels,
ctx.smoothing, ctx.inplace_backward,
ctx.total_classes)
return grad_logits, None, None, None, None, None, None
class CrossEntropyLoss(nn.Module):
def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
inplace_backward=False, process_group=None):
super().__init__()
if reduction not in ['mean', 'none']:
raise NotImplementedError("Only support reduction = 'mean' or 'none'")
self.ignore_index = ignore_index
self.reduction = reduction
self.label_smoothing = label_smoothing
self.inplace_backward = inplace_backward
self.process_group = process_group
def forward(self, input, target):
assert input.is_cuda and target.is_cuda
# SoftmaxCrossEntropyLoss implicitly casts to float
loss = SoftmaxCrossEntropyLossFn.apply(
input, target, self.label_smoothing, self.ignore_index, self.inplace_backward,
self.process_group
)
if self.reduction == 'mean':
return loss.sum() / (target != self.ignore_index).sum()
else:
return loss

View File

@@ -0,0 +1,531 @@
# Copyright (c) 2022, Tri Dao.
# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
import re
import logging
from functools import partial
from collections.abc import Sequence
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.models.bert.modeling_bert import BertForPreTrainingOutput
from einops import rearrange
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedMLP
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import BertEmbeddings
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.bert_padding import index_first_axis, index_first_axis_residual
from flash_attn.utils.pretrained import state_dict_from_pretrained
try:
from flash_attn.ops.fused_dense import FusedDense
except ImportError:
FusedDense = None
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
except ImportError:
dropout_add_layer_norm, layer_norm = None, None
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss
except ImportError:
CrossEntropyLoss = None
logger = logging.getLogger(__name__)
def create_mixer_cls(config, cross_attn=False, return_residual=False):
use_flash_attn = getattr(config, 'use_flash_attn', False)
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
rotary_kwargs = {}
if config.position_embedding_type == "rotary":
rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
mixer_cls = partial(MHA, num_heads=config.num_attention_heads, cross_attn=cross_attn,
dropout=config.attention_probs_dropout_prob, causal=False,
fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn,
return_residual=return_residual, **rotary_kwargs)
return mixer_cls
def create_mlp_cls(config, layer_idx=None, return_residual=False):
inner_dim = config.intermediate_size
fused_mlp = getattr(config, 'fused_mlp', False)
if fused_mlp:
assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_mlp only '
'supports approximate gelu')
if not fused_mlp:
approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
mlp_cls = partial(Mlp, hidden_features=inner_dim,
activation=partial(F.gelu, approximate=approximate),
return_residual=return_residual)
else:
if FusedMLP is None:
raise ImportError('fused_dense is not installed')
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
if isinstance(mlp_checkpoint_lvl, Sequence):
assert layer_idx is not None
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
mlp_cls = partial(FusedMLP, hidden_features=inner_dim,
checkpoint_lvl=mlp_checkpoint_lvl, return_residual=return_residual)
return mlp_cls
def create_block(config, layer_idx=None):
last_layer_subset = getattr(config, 'last_layer_subset', False)
cross_attn=last_layer_subset and layer_idx == config.num_hidden_layers - 1
# TD [2022-12-19]: For cross attention (last layer), we actually want to return the
# residual x_kv, not residual x. But it's annoying to change the API (and it only affects
# one layer) so we just choose not to return residual in this case.
return_residual = not cross_attn
mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
prenorm=False, resid_dropout1=config.hidden_dropout_prob,
resid_dropout2=config.hidden_dropout_prob,
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
return_residual=return_residual)
return block
# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
def _init_weights(module, initializer_range=0.02):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, std=initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
if module.padding_idx is not None:
nn.init.zeros_(module.weight[module.padding_idx])
class BertEncoder(nn.Module):
def __init__(self, config: BertConfig):
super().__init__()
self.use_flash_attn = getattr(config, 'use_flash_attn', False)
self.layers = nn.ModuleList([create_block(config, layer_idx=i)
for i in range(config.num_hidden_layers)])
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
"""If subset_mask is not None, we only want output for the subset of the sequence.
This means that we only compute the last layer output for these tokens.
subset_mask: (batch, seqlen), dtype=torch.bool
"""
if key_padding_mask is None or not self.use_flash_attn:
mixer_kwargs = ({'key_padding_mask': key_padding_mask}
if key_padding_mask is not None else None)
for layer in self.layers:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
if subset_mask is not None:
hidden_states = hidden_states[subset_mask]
else:
batch, seqlen = hidden_states.shape[:2]
hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
hidden_states, key_padding_mask
)
mixer_kwargs = {'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen_in_batch}
if subset_mask is None:
for layer in self.layers:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
else:
for layer in self.layers[:-1]:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
if key_padding_mask is not None:
subset_idx = torch.nonzero(subset_mask[key_padding_mask], as_tuple=False).flatten()
subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0,
dtype=torch.torch.int32), (1, 0))
else:
subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0,
dtype=torch.torch.int32), (1, 0))
hidden_states_subset, hidden_states = index_first_axis_residual(
hidden_states, subset_idx
)
# It's ok to set max_seqlen_q to be much larger
mixer_kwargs = {'x_kv': hidden_states,
'cu_seqlens': subset_cu_seqlens, 'max_seqlen': max_seqlen_in_batch,
'cu_seqlens_k': cu_seqlens, 'max_seqlen_k': max_seqlen_in_batch}
hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
return hidden_states
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.dense = linear_cls(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states, pool=True):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
if self.fused_dropout_add_ln and layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.dense = linear_cls(config.hidden_size, config.hidden_size)
approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
self.transform_act_fn = nn.GELU(approximate=approximate)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
if not self.fused_dropout_add_ln:
hidden_states = self.layer_norm(hidden_states)
else:
hidden_states = layer_norm(hidden_states, self.layer_norm.weight, self.layer_norm.bias,
self.layer_norm.eps)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class BertPreTrainingHeads(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class BertPreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super().__init__()
if not isinstance(config, BertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
@classmethod
def from_pretrained(cls, model_name, config, *inputs, **kwargs):
"""
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
"""
# Instantiate model.
model = cls(config, *inputs, **kwargs)
load_return = model.load_state_dict(remap_state_dict(state_dict_from_pretrained(model_name),
config), strict=False)
logger.info(load_return)
return model
class BertModel(BertPreTrainedModel):
def __init__(self, config: BertConfig, add_pooling_layer=True):
super().__init__(config)
self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
if config.vocab_size % self.pad_vocab_size_multiple != 0:
config.vocab_size += (self.pad_vocab_size_multiple
- (config.vocab_size % self.pad_vocab_size_multiple))
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
if self.fused_dropout_add_ln and layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')
assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']
self.embeddings = BertEmbeddings(config.hidden_size, config.vocab_size,
config.max_position_embeddings, config.type_vocab_size,
padding_idx=config.pad_token_id)
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config) if add_pooling_layer else None
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
masked_tokens_mask=None):
"""If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
we only want the output for the masked tokens. This means that we only compute the last
layer output for these tokens.
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
"""
hidden_states = self.embeddings(input_ids, position_ids=position_ids,
token_type_ids=token_type_ids)
# TD [2022-12:18]: Don't need to force residual in fp32
# BERT puts embedding LayerNorm before embedding dropout.
if not self.fused_dropout_add_ln:
hidden_states = self.emb_ln(hidden_states)
else:
hidden_states = layer_norm(hidden_states, self.emb_ln.weight, self.emb_ln.bias,
self.emb_ln.eps)
hidden_states = self.emb_drop(hidden_states)
if masked_tokens_mask is not None:
batch_size, seqlen = input_ids.shape[:2]
# We also need the first column for the CLS token
first_col_mask = torch.zeros(batch_size, seqlen, dtype=torch.bool,
device=input_ids.device)
first_col_mask[:, 0] = True
subset_mask = masked_tokens_mask | first_col_mask
else:
subset_mask = None
sequence_output = self.encoder(hidden_states, key_padding_mask=attention_mask,
subset_mask=subset_mask)
if masked_tokens_mask is None:
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
else:
# TD [2022-03-01]: the indexing here is very tricky.
if attention_mask is not None:
subset_idx = subset_mask[attention_mask]
pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
else:
pool_input = sequence_output[first_col_mask[subset_mask]]
sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
pooled_output = (self.pooler(pool_input, pool=False)
if self.pooler is not None else None)
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
)
class BertForPreTraining(BertPreTrainedModel):
def __init__(self, config: BertConfig):
super().__init__(config)
# If dense_seq_output, we only need to pass the hidden states for the masked out tokens
# (around 15%) to the classifier heads.
self.dense_seq_output = getattr(config, 'dense_seq_output', False)
# If last_layer_subset, we only need the compute the last layer for a subset of tokens
# (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
self.last_layer_subset = getattr(config, 'last_layer_subset', False)
if self.last_layer_subset:
assert self.dense_seq_output, 'last_layer_subset requires dense_seq_output'
use_xentropy = getattr(config, 'use_xentropy', False)
if use_xentropy and CrossEntropyLoss is None:
raise ImportError('xentropy_cuda is not installed')
loss_cls = (nn.CrossEntropyLoss if not use_xentropy
else partial(CrossEntropyLoss, inplace_backward=True))
self.bert = BertModel(config)
self.cls = BertPreTrainingHeads(config)
self.mlm_loss = loss_cls(ignore_index=0)
self.nsp_loss = loss_cls(ignore_index=-1)
# Initialize weights and apply final processing
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
self.tie_weights()
def tie_weights(self):
self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
labels=None, next_sentence_label=None):
"""
If labels are provided, they must be 0 for masked out tokens (as specified in the attention
mask).
Outputs:
if `labels` and `next_sentence_label` are not `None`:
Outputs the total_loss which is the sum of the masked language modeling loss and the next
sentence classification loss.
if `labels` or `next_sentence_label` is `None`:
Outputs a tuple comprising
- the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
- the next sentence classification logits of shape [batch_size, 2].
"""
masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
outputs = self.bert(
input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask.bool() if attention_mask is not None else None,
masked_tokens_mask=masked_tokens_mask
)
sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
if self.dense_seq_output and labels is not None:
masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
if not self.last_layer_subset:
sequence_output = index_first_axis(rearrange(sequence_output, 'b s d -> (b s) d'),
masked_token_idx)
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
total_loss = None
if labels is not None and next_sentence_label is not None:
if self.dense_seq_output and labels is not None: # prediction_scores are already flattened
masked_lm_loss = self.mlm_loss(prediction_scores,
labels.flatten()[masked_token_idx])
else:
masked_lm_loss = self.mlm_loss(rearrange(prediction_scores, '... v -> (...) v'),
rearrange(labels, '... -> (...)'))
next_sentence_loss = self.nsp_loss(rearrange(seq_relationship_score, '... t -> (...) t'),
rearrange(next_sentence_label, '... -> (...)'))
total_loss = masked_lm_loss.float() + next_sentence_loss.float()
return BertForPreTrainingOutput(
loss=total_loss,
prediction_logits=prediction_scores,
seq_relationship_logits=seq_relationship_score,
)
def remap_state_dict(state_dict, config):
# LayerNorm
def key_mapping_ln_gamma_beta(key):
key = re.sub(r'LayerNorm.gamma$', 'LayerNorm.weight', key)
key = re.sub(r'LayerNorm.beta$', 'LayerNorm.bias', key)
return key
state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
# Layers
def key_mapping_layers(key):
return re.sub(r'^bert.encoder.layer.', 'bert.encoder.layers.', key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^bert.embeddings.LayerNorm.', 'bert.emb_ln.', key)
key = re.sub(r'^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)',
r'bert.encoder.layers.\1.norm1.\2', key)
key = re.sub(r'^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)',
r'bert.encoder.layers.\1.norm2.\2', key)
key = re.sub(r'^cls.predictions.transform.LayerNorm.(weight|bias)',
r'cls.predictions.transform.layer_norm.\1', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)',
r'bert.encoder.layers.\1.mlp.fc1.\2', key)
key = re.sub(r'^bert.encoder.layers.(\d+).output.dense.(weight|bias)',
r'bert.encoder.layers.\1.mlp.fc2.\2', key)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
last_layer_subset = getattr(config, 'last_layer_subset', False)
for d in range(config.num_hidden_layers):
Wq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.weight')
Wk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.weight')
Wv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.weight')
bq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.bias')
bk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.bias')
bv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.bias')
if not (last_layer_subset and d == config.num_hidden_layers - 1):
state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.weight'] = torch.cat(
[Wq, Wk, Wv], dim=0
)
state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)
else:
state_dict[f'bert.encoder.layers.{d}.mixer.Wq.weight'] = Wq
state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.weight'] = torch.cat(
[Wk, Wv], dim=0
)
state_dict[f'bert.encoder.layers.{d}.mixer.Wq.bias'] = bq
state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.bias'] = torch.cat([bk, bv], dim=0)
def key_mapping_attn(key):
return re.sub(r'^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)',
r'bert.encoder.layers.\1.mixer.out_proj.\2', key)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
def key_mapping_decoder_bias(key):
return re.sub(r'^cls.predictions.bias', 'cls.predictions.decoder.bias', key)
state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
# Word embedding
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
if pad_vocab_size_multiple > 1:
word_embeddings = state_dict['bert.embeddings.word_embeddings.weight']
state_dict['bert.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
)
decoder_weight = state_dict['cls.predictions.decoder.weight']
state_dict['cls.predictions.decoder.weight'] = F.pad(
decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
)
# If the vocab was padded, we want to set the decoder bias for those padded indices to be
# strongly negative (i.e. the decoder shouldn't predict those indices).
# TD [2022-05-09]: I don't think it affects the MLPerf training.
decoder_bias = state_dict['cls.predictions.decoder.bias']
state_dict['cls.predictions.decoder.bias'] = F.pad(
decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
)
return state_dict

View File

@@ -0,0 +1,122 @@
# Copyright (c) 2023, Tri Dao.
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config, FalconConfig
def remap_state_dict_hf_falcon(state_dict, config):
def key_mapping_layers(key):
return re.sub(r'^transformer.h.', 'transformer.layers.', key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.word_embeddings.', 'transformer.embeddings.word_embeddings.', key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
else:
output_embeddings = state_dict.pop('lm_head.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
output_embeddings_bias = state_dict.pop('lm_head.bias')
state_dict['lm_head.bias'] = F.pad(
output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.',
r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.',
r'transformer.layers.\1.norm2.', key)
key = re.sub(r'^transformer.layers.(\d+).ln_attn.', r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).ln_mlp.', r'transformer.layers.\1.norm2.', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.',
r'transformer.layers.\1.mlp.fc1.', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.',
r'transformer.layers.\1.mlp.fc2.', key)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
def key_mapping_attn(key):
key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.',
r'transformer.layers.\1.mixer.Wqkv.', key)
key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.',
r'transformer.layers.\1.mixer.out_proj.', key)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
n_head = config.n_head
n_head_kv = getattr(config, "n_head_kv", 1)
headdim = config.hidden_size // n_head
for l in range(config.n_layer):
# The weights are stored in a different layout compared to our implementation
Wqkv = rearrange(state_dict.pop(f'transformer.layers.{l}.mixer.Wqkv.weight'),
"(group ratio headdim) ... -> group ratio headdim ...",
ratio=n_head // n_head_kv + 2, headdim=headdim)
Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...")
Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...")
Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...")
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
return state_dict
def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config:
# The 40b config uses "n_head_kv" instead of "num_kv_heads"
n_head_kv = getattr(falcon_config, "n_head_kv",
1 if getattr(falcon_config, "multi_query", False)
else falcon_config.n_head)
# HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config.
# So we have to infer it from the number of heads in the key/value block
parallel_block_tied_norm = n_head_kv == 1
return GPT2Config(
vocab_size=falcon_config.vocab_size,
n_positions=0, # No absolute position embedding
n_embd=falcon_config.hidden_size,
n_layer=falcon_config.n_layer,
n_head=falcon_config.n_head,
n_inner=falcon_config.hidden_size * 4,
activation_function="gelu",
resid_pdrop=falcon_config.hidden_dropout,
embd_pdrop=0.0, # There doesn't seem to be any embedding dropout
attn_pdrop=falcon_config.attention_dropout,
layer_norm_epsilon=falcon_config.layer_norm_epsilon,
initializer_range=falcon_config.initializer_range,
bos_token_id=falcon_config.bos_token_id,
eos_token_id=falcon_config.eos_token_id,
# These are new arguments not in the original GPT2Config
parallel_block=falcon_config.parallel_attn,
n_head_kv=n_head_kv,
parallel_block_tied_norm=parallel_block_tied_norm,
rotary_emb_fraction=1.0,
rotary_emb_interleaved=False,
tie_word_embeddings=True,
qkv_proj_bias=falcon_config.bias,
out_proj_bias=falcon_config.bias,
mlp_fc1_bias=falcon_config.bias,
mlp_fc2_bias=falcon_config.bias,
lm_head_bias=False,
)

View File

@@ -0,0 +1,740 @@
# Copyright (c) 2023, Tri Dao.
import logging
import math
import re
from functools import partial
from collections import namedtuple, OrderedDict
from collections.abc import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config
from einops import rearrange
from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import Mlp, GatedMlp, ParallelMLP, FusedMLP, ParallelFusedMLP
from flash_attn.modules.block import Block, ParallelBlock
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import GenerationMixin
from flash_attn.models.opt import remap_state_dict_hf_opt
from flash_attn.models.gptj import remap_state_dict_hf_gptj
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.falcon import remap_state_dict_hf_falcon
try:
from flash_attn.ops.fused_dense import ColumnParallelLinear
except ImportError:
ColumnParallelLinear = None
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
dropout_add_layer_norm = None
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
except ImportError:
dropout_add_layer_norm_parallel_residual = None
try:
from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
except ImportError:
RMSNorm, dropout_add_rms_norm = None, None
try:
from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
except ImportError:
dropout_add_rms_norm_parallel_residual = None
try:
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
except ImportError:
FusedDenseSqreluDense = None
logger = logging.getLogger(__name__)
def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
if config.scale_attn_by_inverse_layer_idx:
assert layer_idx is not None
softmax_scale /= float(layer_idx + 1)
dwconv = getattr(config, 'attn_dwconv', False)
if dwconv:
assert process_group is None, 'TensorParallel MHA does not support dwconv yet'
qkv_proj_bias = getattr(config, 'qkv_proj_bias', True)
out_proj_bias = getattr(config, 'out_proj_bias', True)
rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
rotary_emb_base = getattr(config, 'rotary_emb_base', 10000.0)
rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', None)
rotary_emb_interleaved = getattr(config, 'rotary_emb_interleaved', False)
use_flash_attn = getattr(config, 'use_flash_attn', False)
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if not fused_bias_fc:
assert process_group is None, 'TensorParallel MHA requires fused_bias_fc'
mha_cls = MHA if process_group is None else ParallelMHA
serial_kwargs = ({'fused_bias_fc': fused_bias_fc, 'dwconv': dwconv}
if process_group is None else {})
parallel_kwargs = ({'process_group': process_group,
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
if process_group is not None else {})
num_heads_kv = getattr(config, "n_head_kv", None)
mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads,
num_heads_kv=num_heads_kv,
qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias,
dropout=config.attn_pdrop,
softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
rotary_emb_dim=rotary_emb_dim, rotary_emb_base=rotary_emb_base,
rotary_emb_scale_base=rotary_emb_scale_base,
rotary_emb_interleaved=rotary_emb_interleaved,
use_flash_attn=use_flash_attn,
**serial_kwargs, **parallel_kwargs, **factory_kwargs)
return mixer_cls
def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
mlp_fc1_bias = getattr(config, 'mlp_fc1_bias', True)
mlp_fc2_bias = getattr(config, 'mlp_fc2_bias', True)
fused_mlp = getattr(config, 'fused_mlp', False)
if fused_mlp:
assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
if fused_dense_sqrelu_dense:
assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
'supports approximate activation_function sqrelu')
assert not (fused_dense_sqrelu_dense and fused_mlp)
if not fused_mlp and not fused_dense_sqrelu_dense:
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
'sqrelu', 'glu', 'swiglu', 'geglu']
if config.activation_function in ['glu', 'swiglu', 'geglu']:
activation = (F.sigmoid if config.activation_function == 'glu'
else (F.silu if config.activation_function == 'swiglu'
else F.gelu))
mlp_cls = partial(GatedMlp, hidden_features=config.n_inner, activation=activation,
bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **factory_kwargs)
else:
if config.activation_function == 'relu':
activation = partial(F.relu, inplace=True)
elif config.activation_function == 'sqrelu':
activation = sqrelu_fwd
else:
approximate = ('tanh' if config.activation_function
in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
activation=partial(F.gelu, approximate=approximate)
mlp_cls = Mlp if process_group is None else ParallelMLP
parallel_kwargs = ({'process_group': process_group,
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
if process_group is not None else {})
mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
**parallel_kwargs, **factory_kwargs)
else:
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
if isinstance(mlp_checkpoint_lvl, Sequence):
assert layer_idx is not None
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
if fused_mlp:
if FusedMLP is None:
raise ImportError('fused_dense is not installed')
activation = ('gelu_approx' if config.activation_function
in ['gelu_new', 'gelu_fast', 'gelu_approx'] else config.activation_function)
mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
parallel_kwargs = ({'process_group': process_group,
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
if process_group is not None else {})
mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
checkpoint_lvl=mlp_checkpoint_lvl,
bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
**parallel_kwargs, **factory_kwargs)
elif fused_dense_sqrelu_dense:
assert FusedDenseSqreluDense is not None
mlp_cls = partial(FusedDenseSqreluDense, hidden_features=config.n_inner,
checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs)
else:
raise RuntimeError('MLP type not supported')
return mlp_cls
def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
sequence_parallel = getattr(config, 'sequence_parallel', True)
mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
use_rms_norm = getattr(config, 'rms_norm', False)
norm_cls = partial(nn.LayerNorm if not use_rms_norm else RMSNorm,
eps=config.layer_norm_epsilon, **factory_kwargs)
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
prenorm = getattr(config, 'prenorm', True)
parallel_block = getattr(config, 'parallel_block', False)
if not parallel_block:
block = Block(
config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
residual_in_fp32=residual_in_fp32,
sequence_parallel=sequence_parallel and process_group is not None,
mark_shared_params=process_group is not None
)
else:
assert prenorm
block = ParallelBlock(
config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
tied_norm=getattr(config, 'parallel_block_tied_norm', False),
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
residual_in_fp32=residual_in_fp32,
sequence_parallel=sequence_parallel and process_group is not None,
mark_shared_params=process_group is not None
)
block.layer_idx = layer_idx
return block
class GPTPreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super().__init__()
if not isinstance(config, GPT2Config):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
@classmethod
def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dtype=None,
world_size=1, rank=0, **kwargs):
"""
Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
"""
# Instantiate model.
model = cls(config, *args, device=device, dtype=dtype, **kwargs)
# Load state_dict in cpu because we already initialized the model in GPU, and we don't
# want extra stuff taking up more GPU memory
state_dict = state_dict_from_pretrained(
model_name, device='cpu', dtype=dtype
)
if model_name.startswith('gpt2'):
state_dict = remap_state_dict_hf_gpt2(state_dict, config)
elif model_name.startswith('facebook/opt'):
state_dict = remap_state_dict_hf_opt(state_dict, config)
elif model_name.startswith('EleutherAI/gpt-j-'):
state_dict = remap_state_dict_hf_gptj(state_dict, config)
elif model_name.startswith('EleutherAI/gpt-neox-'):
state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
elif model_name.startswith('tiiuae/falcon-'):
state_dict = remap_state_dict_hf_falcon(state_dict, config)
else:
raise NotImplementedError(f'Model {model_name} not supported')
if world_size > 1:
state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
load_return = model.load_state_dict(state_dict, strict=strict)
logger.info(load_return)
return model
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, std=initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight", "fc2.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
class GPTModel(GPTPreTrainedModel):
def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
super().__init__(config)
factory_kwargs = {'device': device, 'dtype': dtype}
self.process_group = process_group
self.sequence_parallel = getattr(config, 'sequence_parallel', True)
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
'relu', 'sqrelu', 'glu', 'swiglu', 'geglu']
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
self.residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
# These 2 options are for OPT-350m
self.prenorm = getattr(config, 'prenorm', True)
use_rms_norm = getattr(config, 'rms_norm', False)
word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
# For GPT-J, GPT-NeoX
self.parallel_block = getattr(config, 'parallel_block', False)
if process_group is None:
self.embeddings = GPT2Embeddings(
config.hidden_size, vocab_size, config.max_position_embeddings,
word_embed_proj_dim=word_embed_proj_dim, **factory_kwargs
)
else:
self.embeddings = ParallelGPT2Embeddings(
config.hidden_size, vocab_size, config.max_position_embeddings,
process_group=process_group, sequence_parallel=self.sequence_parallel,
**factory_kwargs
)
# We change the order of dropout, residual and layer norm:
# Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
# Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
# the main branch (output of MLP). The model definition is unchanged, but the mapping of the
# nn.Dropout probabilities are changed.
# This is for performance reason: we can fuse dropout + add + layer_norm.
self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
**factory_kwargs)
for i in range(config.num_hidden_layers)])
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
if self.fused_dropout_add_ln:
if ((not self.parallel_block and dropout_add_layer_norm is None)
or (self.parallel_block and dropout_add_layer_norm_parallel_residual is None)):
raise ImportError('dropout_layer_norm is not installed')
if self.prenorm:
self.drop_f = nn.Dropout(config.resid_pdrop)
norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
self.ln_f = norm_cls(config.hidden_size, eps=config.layer_norm_epsilon,
**factory_kwargs)
if process_group is not None:
for p in self.ln_f.parameters():
# Mark the norm parameters as "shared_params" so that we sync their values at init.
p._shared_params = True
# Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
if self.sequence_parallel:
p._sequence_parallel = True
self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
initializer_range=config.initializer_range))
self.tie_weights()
def tie_weights(self):
if self.process_group is not None:
sync_shared_params(self, self.process_group)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return {i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
for i, layer in enumerate(self.layers)}
def forward(self, input_ids, position_ids=None, inference_params=None):
# If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
# dimensions so that we can split on it easily, in case of small batch size.
# Only the attention layers need to know the seqlen.
embedding_kwargs = ({'combine_batch_seqlen_dim': True}
if self.process_group is not None and self.sequence_parallel else {})
hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
if self.parallel_block:
hidden_states2 = None
residual = None
mixer_kwargs = ({'seqlen': input_ids.shape[1]}
if self.process_group is not None and self.sequence_parallel else {})
if inference_params is not None:
mixer_kwargs['inference_params'] = inference_params
for layer in self.layers:
if self.prenorm:
if not self.parallel_block:
hidden_states, residual = layer(hidden_states, residual,
mixer_kwargs=mixer_kwargs)
else:
hidden_states, hidden_states2, residual = layer(
hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
)
else:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
if self.prenorm:
if not self.fused_dropout_add_ln:
dropped = self.drop_f(hidden_states)
if not self.parallel_block:
residual = (dropped + residual) if residual is not None else dropped
else:
dropped2 = self.drop_f(hidden_states2)
residual = ((residual + dropped + dropped2)
if residual is not None else dropped + dropped2)
hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
else:
# Set prenorm=False here since we don't need the residual
if not self.parallel_block:
fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.ln_f, RMSNorm)
else dropout_add_layer_norm)
hidden_states = fused_add_norm_fn(
hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
residual_in_fp32=self.residual_in_fp32
)
else:
fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
if isinstance(self.ln_f, RMSNorm)
else dropout_add_layer_norm_parallel_residual)
hidden_states, _ = fused_add_norm_fn(
hidden_states, hidden_states2, residual, self.ln_f.weight, self.ln_f.bias,
None, None, self.drop_f.p if self.training else 0.0, self.ln_f.eps,
prenorm=False, residual_in_fp32=self.residual_in_fp32
)
return hidden_states
class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(config)
self.process_group = process_group
self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
self.tie_word_embeddings = getattr(config, 'tie_word_embeddings', True)
lm_head_bias = getattr(config, 'lm_head_bias', False)
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
# This option is for OPT-350m
word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
if word_embed_proj_dim is not None:
self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
else:
self.project_out = None
if process_group is None:
self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
else:
if ColumnParallelLinear is None:
raise ImportError('fused_dense_lib is not installed')
self.lm_head = ColumnParallelLinear(
embed_dim, vocab_size, process_group, bias=lm_head_bias,
sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
)
# Initialize weights and apply final processing
self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
initializer_range=config.initializer_range))
self.tie_weights()
def tie_weights(self):
if self.tie_word_embeddings:
self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
if self.process_group is not None:
sync_shared_params(self, self.process_group)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.transformer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype,
**kwargs)
def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
"""
inference_params: for generation. Adapted from Megatron-LM (and Apex)
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
last_token_only: whether to return the logit for the last token only,
of shape (batch_size, vocab_size)
"""
hidden_states = self.transformer(input_ids, position_ids=position_ids,
inference_params=inference_params)
if last_token_only:
hidden_states = hidden_states[:, -1]
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
lm_logits = self.lm_head(hidden_states)
# During inference, we want the full logit for sampling
if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
lm_logits = rearrange(lm_logits, '(n b) ... d -> b ... (n d)', b=hidden_states.shape[0])
CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
return CausalLMOutput(logits=lm_logits)
def load_state_dict(self, state_dict, strict=True):
# Remapping from our checkpoints that used a different ordering of layers in the block
# Previous: Attn / MLP -> Dropout -> Add -> LN
# Current: Dropout -> Add -> LN -> Attn / MLP
if 'transformer.ln_0.weight' in state_dict:
n_layers = len(self.transformer.layers)
ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
state_dict['transformer.ln_f.weight'] = ln_weight
state_dict['transformer.ln_f.bias'] = ln_bias
for l in reversed(range(n_layers)):
ln_weight = state_dict.pop(f'transformer.layers.{l}.norm1.weight')
ln_bias = state_dict.pop(f'transformer.layers.{l}.norm1.bias')
state_dict[f'transformer.layers.{l}.norm2.weight'] = ln_weight
state_dict[f'transformer.layers.{l}.norm2.bias'] = ln_bias
if l > 0:
ln_weight = state_dict.pop(f'transformer.layers.{l - 1}.norm2.weight')
ln_bias = state_dict.pop(f'transformer.layers.{l - 1}.norm2.bias')
state_dict[f'transformer.layers.{l}.norm1.weight'] = ln_weight
state_dict[f'transformer.layers.{l}.norm1.bias'] = ln_bias
ln_weight = state_dict.pop('transformer.ln_0.weight')
ln_bias = state_dict.pop('transformer.ln_0.bias')
state_dict[f'transformer.layers.0.norm1.weight'] = ln_weight
state_dict[f'transformer.layers.0.norm1.bias'] = ln_bias
return super().load_state_dict(state_dict, strict=strict)
def shard_state_dict_tp(state_dict, config, world_size, rank):
"""Convert the state_dict of a standard GPT model to the state_dict of a GPT model
with tensor parallel.
"""
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
assert vocab_size % world_size == 0
assert config.hidden_size % world_size == 0
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
assert inner_dim % world_size == 0
def shard_first_dim(state_dict, key):
if key in state_dict:
x = state_dict[key]
dim = x.shape[0] // world_size
state_dict[key] = x[rank * dim:(rank + 1) * dim]
def shard_last_dim(state_dict, key):
if key in state_dict:
x = state_dict[key]
dim = x.shape[-1] // world_size
state_dict[key] = x[..., rank * dim:(rank + 1) * dim]
def shard_qkv_headdim(state_dict, key):
if key in state_dict:
n_head = config.n_head
n_head_kv = getattr(config, 'n_head_kv', n_head)
assert n_head % world_size == 0 and n_head_kv % world_size == 0
if n_head_kv == n_head:
x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
dim = x.shape[1] // world_size
state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
'three d ... -> (three d) ...')
else:
n_head_per_rank = n_head // world_size
n_head_kv_per_rank = n_head_kv // world_size
x = rearrange(state_dict[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
nheadqkv=n_head + 2 * n_head_kv)
state_dict[key] = rearrange(torch.cat([
x[rank * n_head_per_rank:(rank + 1) * n_head_per_rank],
x[n_head + rank * n_head_kv_per_rank:n_head + (rank + 1) * n_head_kv_per_rank],
x[n_head + n_head_kv + rank * n_head_kv_per_rank:n_head + n_head_kv + (rank + 1) * n_head_kv_per_rank],
], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...")
shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
if 'lm_head.weight' in state_dict:
shard_first_dim(state_dict, 'lm_head.weight')
if 'transformer.embeddings.position_embeddings.weight' in state_dict:
shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight')
for i in range(config.num_hidden_layers):
shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
if rank != 0:
state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias', None)
shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
if rank != 0:
state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias', None)
return state_dict
def combine_state_dicts_tp(state_dicts, config):
"""Convert the state_dict of a standard GPT model to the state_dict of a GPT model
with tensor parallel.
"""
world_size = len(state_dicts)
keys = state_dicts[0].keys()
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
assert vocab_size % world_size == 0
assert config.hidden_size % world_size == 0
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
assert inner_dim % world_size == 0
# Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
# vocab_size // world_size coordinates are nonzero.
def combine_word_embeddings(state_dicts, state_dict, key):
dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1
state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)
def combine_dim(state_dicts, state_dict, key, dim=-1):
if key in state_dict:
state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)
def combine_qkv_headdim(state_dicts, state_dict, key):
n_head = config.n_head
n_head_kv = getattr(config, 'n_head_kv', n_head)
assert n_head % world_size == 0 and n_head_kv % world_size == 0
n_head_per_rank = n_head // world_size
n_head_kv_per_rank = n_head_kv // world_size
if key in state_dict:
if n_head_kv == n_head:
xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3) for s in state_dicts]
state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')
else:
xs = [rearrange(s[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
nheadqkv=n_head + 2 * n_head_kv) for s in state_dicts]
state_dict[key] = rearrange(torch.cat([
torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
torch.cat([x[n_head_per_rank:n_head_per_rank + n_head_kv_per_rank] for x in xs], dim=0),
torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0),
], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...")
def combine_gated_mlp(state_dicts, state_dict, key):
if key in state_dict:
xs = [rearrange(s[key], '(two d) ... -> two d ...', two=2) for s in state_dicts]
state_dict[key] = rearrange(torch.cat(xs, dim=1), 'two d ... -> (two d) ...')
state_dict = state_dicts[0].copy() # don't modify state_dict[0] inplace
combine_word_embeddings(state_dicts, state_dict, 'transformer.embeddings.word_embeddings.weight')
if 'lm_head.weight' in state_dict:
combine_word_embeddings(state_dicts, state_dict, 'lm_head.weight')
if 'transformer.embeddings.position_embeddings.weight' in state_dict:
combine_dim(state_dicts, state_dict, 'transformer.embeddings.position_embeddings.weight', -1)
mlp_combine_fn = (combine_gated_mlp if config.activation_function in ['glu', 'swiglu', 'geglu']
else partial(combine_dim, dim=0))
for i in range(config.num_hidden_layers):
combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.out_proj.weight', -1)
mlp_combine_fn(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.bias', 0)
combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc2.weight', -1)
return state_dict
def remap_state_dict_hf_gpt2(state_dict, config):
# Word embedding and position embedding
def key_mapping_pos_emb(key):
return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('wte.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^ln_f.(weight|bias)', r'transformer.ln_f.\1', key)
key = re.sub(r'^h.(\d+).ln_(1|2).(weight|bias)', r'transformer.layers.\1.norm\2.\3', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
for d in range(config.num_hidden_layers):
W1 = state_dict.pop(f'h.{d}.mlp.c_fc.weight')
state_dict[f'transformer.layers.{d}.mlp.fc1.weight'] = W1.t()
W2 = state_dict.pop(f'h.{d}.mlp.c_proj.weight')
state_dict[f'transformer.layers.{d}.mlp.fc2.weight'] = W2.t()
def key_mapping_mlp(key):
key = re.sub(r'^h.(\d+).mlp.c_fc.bias', r'transformer.layers.\1.mlp.fc1.bias', key)
key = re.sub(r'^h.(\d+).mlp.c_proj.bias', r'transformer.layers.\1.mlp.fc2.bias', key)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for d in range(config.num_hidden_layers):
state_dict.pop(f'h.{d}.attn.bias') # We don't store this bias
Wqkv = state_dict.pop(f'h.{d}.attn.c_attn.weight')
state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = Wqkv.t()
Wout = state_dict.pop(f'h.{d}.attn.c_proj.weight')
state_dict[f'transformer.layers.{d}.mixer.out_proj.weight'] = Wout.t()
def key_mapping_attn(key):
key = re.sub(r'^h.(\d+).attn.c_attn.bias', r'transformer.layers.\1.mixer.Wqkv.bias', key)
key = re.sub(r'^h.(\d+).attn.c_proj.bias', r'transformer.layers.\1.mixer.out_proj.bias', key)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
def remap_state_dict_megatron(state_dict, config):
def key_mapping_transformer(key):
key = re.sub(r'^language_model.encoder.', 'transformer.', key)
key = re.sub(r'^language_model.', 'transformer.', key)
return key
state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())
# Word embedding and position embedding
def key_mapping_pos_emb(key):
return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embedding.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.final_layernorm.(weight|bias)', r'transformer.ln_f.\1', key)
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.(weight|bias)',
r'transformer.layers.\1.norm1.\2', key)
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)',
r'transformer.layers.\1.norm2.\2', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)',
r'transformer.layers.\1.mlp.fc1.\2', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)',
r'transformer.layers.\1.mlp.fc2.\2', key)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
def key_mapping_attn(key):
key = re.sub(r'^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq',
r'transformer.layers.\1.mixer.rotary_emb.inv_freq', key)
key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)',
r'transformer.layers.\1.mixer.Wqkv.\2', key)
key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.(weight|bias)',
r'transformer.layers.\1.mixer.out_proj.\2', key)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
# Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
# while we store Wqkv as ((3 nheads headdim), hidden_dim)
headdim = config.hidden_size // config.num_attention_heads
for d in range(config.num_hidden_layers):
Wqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.weight')
state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = rearrange(
Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
three=3, headdim=headdim
)
bqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.bias')
state_dict[f'transformer.layers.{d}.mixer.Wqkv.bias'] = rearrange(
bqkv, '(nheads three headdim) -> (three nheads headdim)',
three=3, headdim=headdim
)
return state_dict

View File

@@ -0,0 +1,107 @@
# Copyright (c) 2023, Tri Dao.
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config, GPTNeoXConfig
def remap_state_dict_hf_gpt_neox(state_dict, config):
def key_mapping_layers(key):
return re.sub(r'^gpt_neox.', 'transformer.', key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.embed_in.', 'transformer.embeddings.word_embeddings.', key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
else:
output_embeddings = state_dict.pop('embed_out.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.', r'transformer.layers.\1.mlp.fc1.', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.', r'transformer.layers.\1.mlp.fc2.', key)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
# We don't store these biases
state_dict.pop(f'transformer.layers.{l}.attention.bias')
state_dict.pop(f'transformer.layers.{l}.attention.masked_bias')
# GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
# while we store Wqkv as ((3 nheads headdim), hidden_dim)
headdim = config.hidden_size // config.num_attention_heads
Wqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.weight')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = rearrange(
Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
three=3, headdim=headdim
)
bqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.bias')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = rearrange(
bqkv, '(nheads three headdim) -> (three nheads headdim)',
three=3, headdim=headdim
)
def key_mapping_attn(key):
key = re.sub(r'^transformer.layers.(\d+).attention.dense.',
r'transformer.layers.\1.mixer.out_proj.', key)
key = re.sub(r'^transformer.layers.(\d+).attention.rotary_emb.',
r'transformer.layers.\1.mixer.rotary_emb.', key)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GPT2Config:
assert gpt_neox_config.rotary_emb_base == 10000
return GPT2Config(
vocab_size=gpt_neox_config.vocab_size,
n_positions=0, # No absolute position embedding
n_embd=gpt_neox_config.hidden_size,
n_layer=gpt_neox_config.num_hidden_layers,
n_head=gpt_neox_config.num_attention_heads,
n_inner=gpt_neox_config.intermediate_size,
activation_function=gpt_neox_config.hidden_act,
resid_pdrop=0.0, # No dropout
embd_pdrop=0.0,
attn_pdrop=0.0,
layer_norm_epsilon=gpt_neox_config.layer_norm_eps,
initializer_range=gpt_neox_config.initializer_range,
bos_token_id=gpt_neox_config.bos_token_id,
eos_token_id=gpt_neox_config.eos_token_id,
# These are new arguments not in the original GPT2Config
prenorm=True,
parallel_block=gpt_neox_config.use_parallel_residual,
parallel_block_tied_norm=False,
rotary_emb_fraction=gpt_neox_config.rotary_pct,
tie_word_embeddings=gpt_neox_config.tie_word_embeddings,
)

View File

@@ -0,0 +1,98 @@
# Copyright (c) 2023, Tri Dao.
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPTJConfig
def remap_state_dict_hf_gptj(state_dict, config):
def key_mapping_layers(key):
return re.sub(r'^transformer.h.', 'transformer.layers.', key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.wte.', 'transformer.embeddings.word_embeddings.', key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
else:
output_embeddings = state_dict.pop('lm_head.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
output_embeddings_bias = state_dict.pop('lm_head.bias')
state_dict['lm_head.bias'] = F.pad(
output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
return re.sub(r'^transformer.layers.(\d+).ln_1.', r'transformer.layers.\1.norm1.', key)
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.fc_in.', r'transformer.layers.\1.mlp.fc1.', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.fc_out.', r'transformer.layers.\1.mlp.fc2.', key)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.attn.q_proj.weight')
Wk = state_dict.pop(f'transformer.layers.{l}.attn.k_proj.weight')
Wv = state_dict.pop(f'transformer.layers.{l}.attn.v_proj.weight')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
# We don't store these biases
state_dict.pop(f'transformer.layers.{l}.attn.bias')
state_dict.pop(f'transformer.layers.{l}.attn.masked_bias')
def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).attn.out_proj.',
r'transformer.layers.\1.mixer.out_proj.', key)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
headdim = gptj_config.n_embd // gptj_config.n_head
return GPT2Config(
vocab_size=gptj_config.vocab_size,
n_positions=0, # No absolute position embedding
n_embd=gptj_config.n_embd,
n_layer=gptj_config.n_layer,
n_head=gptj_config.n_head,
n_inner=gptj_config.n_inner,
activation_function=gptj_config.activation_function,
resid_pdrop=gptj_config.resid_pdrop,
embd_pdrop=gptj_config.embd_pdrop,
attn_pdrop=gptj_config.attn_pdrop,
layer_norm_epsilon=gptj_config.layer_norm_epsilon,
initializer_range=gptj_config.initializer_range,
bos_token_id=gptj_config.bos_token_id,
eos_token_id=gptj_config.eos_token_id,
# These are new arguments not in the original GPT2Config
prenorm=True,
parallel_block=True,
parallel_block_tied_norm=True,
rotary_emb_fraction=gptj_config.rotary_dim / headdim,
rotary_emb_interleaved=True,
tie_word_embeddings=False,
qkv_proj_bias=False,
out_proj_bias=False,
lm_head_bias=True,
)

View File

@@ -0,0 +1,124 @@
# Copyright (c) 2023, Tri Dao.
import math
import json
import re
from pathlib import Path
from collections import OrderedDict
import torch
import torch.nn.functional as F
from transformers import GPT2Config, LlamaConfig
def remap_state_dict_meta_llama(state_dict, config):
def key_mapping_layers(key):
return f'transformer.{key}' if not key.startswith('output.') else key
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.tok_embeddings.', 'transformer.embeddings.word_embeddings.', key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
else:
output_embeddings = state_dict.pop('output.weight')
# Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
# differently.
vocab_size = (math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.norm.', r'transformer.ln_f.', key)
key = re.sub(r'^transformer.layers.(\d+).attention_norm.', r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).ffn_norm.', r'transformer.layers.\1.norm2.', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
for l in range(config.n_layer):
w1 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w1.weight')
w3 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w3.weight')
# Our ordering is different
state_dict[f'transformer.layers.{l}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0)
def key_mapping_mlp(key):
return re.sub(r'^transformer.layers.(\d+).feed_forward.w2.',
r'transformer.layers.\1.mlp.fc2.', key)
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.attention.wq.weight')
Wk = state_dict.pop(f'transformer.layers.{l}.attention.wk.weight')
Wv = state_dict.pop(f'transformer.layers.{l}.attention.wv.weight')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
# We don't store these
state_dict.pop(f'transformer.layers.{l}.attention.inner_attention.rope.freqs', None)
def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).attention.wo.',
r'transformer.layers.\1.mixer.out_proj.', key)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
def config_from_checkpoint(checkpoint_path: str, model_name: str) -> LlamaConfig:
"""Load a LlamaConfig from a checkpoint path."""
with open(Path(checkpoint_path) / model_name / 'params.json') as f:
params = json.load(f)
config = LlamaConfig(hidden_size=params['dim'], intermediate_size=None,
num_attention_heads=params['n_heads'],
num_hidden_layers=params['n_layers'],
rms_norm_eps=params['norm_eps'])
return config
def state_dicts_from_checkpoint(checkpoint_path: str, model_name: str) -> dict:
# Need to sort, otherwise we mess up the ordering and the weights are wrong
return [torch.load(path, map_location='cpu')
for path in sorted((Path(checkpoint_path) / model_name).glob('consolidated.*.pth'))]
def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
return GPT2Config(
vocab_size=llama_config.vocab_size,
n_positions=0, # No absolute position embedding
n_embd=llama_config.hidden_size,
n_layer=llama_config.num_hidden_layers,
n_head=llama_config.num_attention_heads,
n_inner=llama_config.intermediate_size,
activation_function='swiglu', # Hardcode since HF calls it 'silu'
# Llama doesn't have dropout, idk if it's because they only release the inference code
resid_pdrop=0.0,
embd_pdrop=0.0,
attn_pdrop=0.0,
layer_norm_epsilon=llama_config.rms_norm_eps,
initializer_range=llama_config.initializer_range,
bos_token_id=llama_config.bos_token_id,
eos_token_id=llama_config.eos_token_id,
# These are new arguments not in the original GPT2Config
pad_token_id=llama_config.pad_token_id, # Idk if this does anything
rms_norm=True,
rotary_emb_fraction=1.0,
rotary_emb_interleaved=True,
tie_word_embeddings=False,
qkv_proj_bias=False,
out_proj_bias=False,
mlp_fc1_bias=False,
mlp_fc2_bias=False,
)

View File

@@ -0,0 +1,102 @@
# Copyright (c) 2023, Tri Dao.
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from transformers import GPT2Config, OPTConfig
def remap_state_dict_hf_opt(state_dict, config):
def key_mapping_model(key):
key = re.sub(r'^model.decoder.', 'transformer.', key)
# The OPT-350m model uses '^decoder' instead of '^model.decoder'
key = re.sub(r'^decoder.', 'transformer.', key)
return key
state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())
# Word embedding and position embedding
def key_mapping_emb(key):
key = re.sub(r'^transformer.embed_tokens.', 'transformer.embeddings.word_embeddings.', key)
# The OPT-350m model uses has project_in and project_out
key = re.sub(r'^transformer.project_in.', 'transformer.embeddings.project_in.', key)
key = re.sub(r'^transformer.project_out.', 'project_out.', key)
key = re.sub(r'^transformer.embed_positions.',
'transformer.embeddings.position_embeddings.', key)
return key
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
# OPT uses the first 2 indices of pos_emb for padding tokens
pos_embeddings = state_dict.pop('transformer.embeddings.position_embeddings.weight')
state_dict['transformer.embeddings.position_embeddings.weight'] = pos_embeddings[2:]
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
# The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.',
r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.',
r'transformer.layers.\1.norm2.', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
return re.sub(r'^transformer.layers.(\d+).fc(1|2).',
r'transformer.layers.\1.mlp.fc\2.', key)
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.weight')
Wk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.weight')
Wv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.weight')
bq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.bias')
bk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.bias')
bv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.bias')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)
def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).self_attn.out_proj.',
r'transformer.layers.\1.mixer.out_proj.', key)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
assert opt_config.layerdrop == 0.0
assert opt_config.layer_norm_elementwise_affine
word_embed_proj_dim = (None if opt_config.word_embed_proj_dim == opt_config.hidden_size
else opt_config.word_embed_proj_dim)
return GPT2Config(
vocab_size=opt_config.vocab_size,
n_positions=opt_config.max_position_embeddings,
n_embd=opt_config.hidden_size,
n_layer=opt_config.num_hidden_layers,
n_head=opt_config.num_attention_heads,
n_inner=opt_config.ffn_dim,
activation_function=opt_config.activation_function,
resid_pdrop=opt_config.dropout,
# HF's implementation of OPT doesn't seem to have embedding dropout
embd_pdrop=opt_config.dropout,
attn_pdrop=opt_config.attention_dropout,
initializer_range=opt_config.init_std,
bos_token_id=opt_config.bos_token_id,
eos_token_id=opt_config.eos_token_id,
# These are new arguments not in the original GPT2Config
prenorm=opt_config.do_layer_norm_before,
word_embed_proj_dim=word_embed_proj_dim
)

View File

@@ -0,0 +1,304 @@
# Copyright (c) 2022, Tri Dao.
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
import math
import re
from functools import partial
from copy import deepcopy
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from torchvision.ops import StochasticDepth
from einops import rearrange
from timm.models.helpers import named_apply
from flash_attn.layers.patch_embed import PatchEmbed
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedMLP
from flash_attn.modules.block import Block
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
dropout_add_layer_norm = None
def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc,
cross_attn=False):
mixer_cls = partial(MHA, num_heads=num_heads, cross_attn=cross_attn, bias=qkv_bias,
dropout=attn_drop, fused_bias_fc=fused_bias_fc,
use_flash_attn=use_flash_attn)
return mixer_cls
def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
inner_dim = int(embed_dim * mlp_ratio)
if not fused_mlp:
mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
else:
mlp_cls = partial(FusedMLP, hidden_features=inner_dim)
return mlp_cls
def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
drop_path1, drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc,
fused_mlp, fused_dropout_add_ln, layer_idx=None, n_layer=None,
last_layer_subset=False):
mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc,
cross_attn=(last_layer_subset and layer_idx == n_layer - 1))
mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
# TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
prenorm=True, resid_dropout1=drop_rate, resid_dropout2=drop_rate,
drop_path1=drop_path1, drop_path2=drop_path2,
fused_dropout_add_ln=fused_dropout_add_ln, residual_in_fp32=True)
return block
class VisionTransformer(nn.Module):
""" Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
global_pool='token',
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
init_values=None,
class_token=True,
no_embed_class=False,
pre_norm=False,
fc_norm=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
weight_init='',
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
use_flash_attn=False,
fused_bias_fc=False,
fused_mlp=False,
fused_dropout_add_ln=False,
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
global_pool (str): type of global pooling for final sequence (default: 'token')
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
init_values: (float): layer-scale init values
class_token (bool): use class token
fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
weight_init (str): weight init scheme
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
act_layer: (nn.Module): MLP activation layer
"""
super().__init__()
assert global_pool == 'token', 'Only support pooling with CLS token'
assert class_token
assert init_values is None, 'LayerScale is not supported yet'
assert weight_init == ''
assert fc_norm is None
# pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
assert not pre_norm
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_prefix_tokens = 1 if class_token else 0
self.no_embed_class = no_embed_class
patch_embed_extra_kwargs = ({'fused_bias_fc': fused_bias_fc} if embed_layer is PatchEmbed
else {})
self.patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
**patch_embed_extra_kwargs
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
# We change the order of dropout, residual and layer norm:
# Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
# Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
# the main branch (output of MLP). The model definition is unchanged, but the mapping of the
# nn.Dropout probabilities are changed.
# This is for performance reason: we can fuse dropout + add + layer_norm.
self.blocks = nn.ModuleList([create_block(
embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
drop_path1=dpr[i-1] if i > 0 else 0., drop_path2=dpr[i],
norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
fused_bias_fc=fused_bias_fc, fused_mlp=fused_mlp,
fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth,
last_layer_subset=(global_pool == 'token')
) for i in range(depth)])
self.dropout = nn.Dropout(p=drop_rate)
self.drop_path = StochasticDepth(p=dpr[-1], mode='row')
self.norm = norm_layer(embed_dim)
self.fused_dropout_add_ln = fused_dropout_add_ln
if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')
# Classifier Head
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.init_weights(weight_init)
def init_weights(self, mode=''):
assert mode == ''
trunc_normal_(self.pos_embed, std=.02)
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(init_weights_vit_timm, self)
def _init_weights(self, m):
# this fn left here for compat with downstream users
init_weights_vit_timm(m)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def _pos_embed(self, x):
if self.no_embed_class:
# deit-3, updated JAX (big vision)
# position embedding does not overlap with class token, add then concat
x = x + self.pos_embed
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
else:
# original timm, JAX, and deit vit impl
# pos_embed has entry for class token, concat then add
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.pos_embed
return x
def forward_features(self, x, all_tokens=True):
"""
If all_tokens==False and self.global_pool == 'token', we only return the features for the
cls token.
"""
x = self.patch_embed(x)
hidden_states = self._pos_embed(x)
residual = None
if self.global_pool != 'token' or all_tokens:
# if True:
for block in self.blocks:
hidden_states, residual = block(hidden_states, residual)
else:
for block in self.blocks[:-1]:
hidden_states, residual = block(hidden_states, residual)
# For the last layer, we only want the 1st token of the output. So we do cross-attention
# where the query is the 1st token and the key/value is the whole sequence.
hidden_states, residual = self.blocks[-1](hidden_states, residual,
mixer_subset=slice(0, 1))
if not self.fused_dropout_add_ln:
residual = self.drop_path(self.dropout(hidden_states)) + residual
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
else:
if self.drop_path.p == 0 or not self.training:
rowscale = None
else:
rowscale = self.drop_path(torch.ones(
hidden_states.shape[:-1], device=hidden_states.device,
dtype=hidden_states.dtype)
)
# Set prenorm=False here since we don't need to the residual
hidden_states = dropout_add_layer_norm(
hidden_states, residual, self.norm.weight, self.norm.bias,
self.dropout.p if self.training else 0.0, self.norm.eps, rowscale=rowscale,
prenorm=False, residual_in_fp32=True
)
return hidden_states
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x, all_tokens=False)
x = self.forward_head(x)
return x
def load_state_dict(self, state_dict, strict=True):
patch_embed_weight = state_dict['patch_embed.proj.weight']
if patch_embed_weight.dim() == 4:
# convert from Conv2d to Linear
state_dict['patch_embed.proj.weight'] = rearrange(patch_embed_weight,
'o c h w -> o (c h w)')
def key_mapping_attn(key):
key = re.sub(r'^blocks.(\d+).attn.qkv.', r'blocks.\1.mixer.Wqkv.', key)
key = re.sub(r'^blocks.(\d+).attn.proj.', r'blocks.\1.mixer.out_proj.', key)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
n_layer = len(self.blocks)
# Convert from Wqkv to Wq and Wkv for cross attention (last layer)
if (self.blocks[-1].mixer.cross_attn
and f'blocks.{n_layer - 1}.mixer.Wqkv.weight' in state_dict):
Wqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.weight')
bqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.bias')
state_dict[f'blocks.{n_layer - 1}.mixer.Wq.weight'] = Wqkv[:self.embed_dim]
state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.weight'] = Wqkv[self.embed_dim:]
state_dict[f'blocks.{n_layer - 1}.mixer.Wq.bias'] = bqkv[:self.embed_dim]
state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.bias'] = bqkv[self.embed_dim:]
return super().load_state_dict(state_dict, strict=strict)
def init_weights_vit_timm(module: nn.Module, name: str = ''):
""" ViT weight initialization, original timm impl (for reproducibility) """
if isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
def vit_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
assert not pretrained
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = VisionTransformer(**model_kwargs)
return model

View File

@@ -0,0 +1,324 @@
# Copyright (c) 2022, Tri Dao.
from typing import Optional
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import StochasticDepth
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
dropout_add_layer_norm = None
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
except ImportError:
dropout_add_layer_norm_parallel_residual = None
try:
from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
except ImportError:
RMSNorm, dropout_add_rms_norm = None, None
try:
from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
except ImportError:
dropout_add_rms_norm_parallel_residual = None
class Block(nn.Module):
def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0.,
drop_path1=0., drop_path2=0., fused_dropout_add_ln=False, return_residual=False,
residual_in_fp32=False, sequence_parallel=False, mark_shared_params=False):
"""
For prenorm=True, this Block has a slightly different structure compared to a regular
prenorm Transformer block.
The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
[Ref: https://arxiv.org/abs/2002.04745]
Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
the hidden_states (output of the MLP) and the residual.
This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
The residual needs to be provided (except for the very first block).
For prenorm=False, this Block has the same structure as a regular postnorm Transformer
block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
This is for performance reason: for post-norm architecture, returning the input allows us
to fuse the backward of nn.Linear with the residual connection.
"""
super().__init__()
self.prenorm = prenorm
self.fused_dropout_add_ln = fused_dropout_add_ln
self.return_residual = return_residual
self.residual_in_fp32 = residual_in_fp32
if self.residual_in_fp32:
assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True'
if mixer_cls is None:
mixer_cls = partial(MHA, num_heads=dim // 64)
if mlp_cls is None:
mlp_cls = partial(Mlp, hidden_features=4 * dim)
self.mixer = mixer_cls(dim)
self.dropout1 = dropout_cls(resid_dropout1)
self.drop_path1 = StochasticDepth(drop_path1, mode='row')
self.norm1 = norm_cls(dim)
self.mlp = mlp_cls(dim)
if not isinstance(self.mlp, nn.Identity):
self.dropout2 = dropout_cls(resid_dropout2)
self.drop_path2 = StochasticDepth(drop_path2, mode='row')
self.norm2 = norm_cls(dim)
if self.fused_dropout_add_ln:
assert dropout_add_layer_norm is not None, 'dropout_layer_norm is not installed'
assert dropout_add_rms_norm is not None, 'dropout_layer_norm is not installed'
assert (isinstance(self.norm1, (nn.LayerNorm, RMSNorm))
and isinstance(self.dropout1, nn.Dropout))
# TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
# then the input to each worker in the tensor parallel group will be different.
# This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
# For now this is not an issue because we always use sequence_parallel=True during training
# and only use sequence_parallel=False during inference.
# Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
if sequence_parallel:
for p in self.norm1.parameters():
p._sequence_parallel = True
if hasattr(self, 'norm2'):
for p in self.norm2.parameters():
p._sequence_parallel = True
# Mark the norm parameters as "shared_params" so that we sync their values at init.
if mark_shared_params:
for p in self.norm1.parameters():
p._shared_params = True
if hasattr(self, 'norm2'):
for p in self.norm2.parameters():
p._shared_params = True
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
mixer_subset=None, mixer_kwargs=None):
r"""Pass the input through the encoder layer.
Args:
hidden_states: the sequence to the encoder layer (required).
residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
mixer_subset: for cross-attention only. If not None, will take a subset of x
before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer.
"""
fused_add_norm_fn = (dropout_add_rms_norm if RMSNorm and isinstance(self.norm1, RMSNorm)
else dropout_add_layer_norm)
if self.prenorm:
if not self.fused_dropout_add_ln:
dropped = self.drop_path1(self.dropout1(hidden_states))
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
if self.drop_path1.p == 0 or not self.training:
rowscale1 = None
else:
rowscale1 = self.drop_path1(torch.ones(
hidden_states.shape[:-1], device=hidden_states.device,
dtype=hidden_states.dtype)
)
hidden_states, residual = fused_add_norm_fn(
hidden_states, residual, self.norm1.weight, self.norm1.bias,
self.dropout1.p if self.training else 0.0, self.norm1.eps,
rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
)
if mixer_kwargs is None:
mixer_kwargs = {}
if mixer_subset is not None:
mixer_kwargs['mixer_subset'] = mixer_subset
hidden_states = self.mixer(hidden_states, **mixer_kwargs)
if mixer_subset is not None:
residual = residual[:, mixer_subset]
if not isinstance(self.mlp, nn.Identity):
if not self.fused_dropout_add_ln:
dropped = self.drop_path2(self.dropout2(hidden_states))
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
if self.drop_path2.p == 0 or not self.training:
rowscale2 = None
else:
rowscale2 = self.drop_path2(torch.ones(
hidden_states.shape[:-1], device=hidden_states.device,
dtype=hidden_states.dtype)
)
hidden_states, residual = fused_add_norm_fn(
hidden_states, residual, self.norm2.weight, self.norm2.bias,
self.dropout2.p if self.training else 0.0, self.norm2.eps,
rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32
)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
else:
assert residual is None
mixer_out = self.mixer(
hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
)
if self.return_residual: # mixer out is actually a pair here
mixer_out, hidden_states = mixer_out
if not self.fused_dropout_add_ln:
hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out))
+ hidden_states).to(dtype=self.norm1.weight.dtype))
else:
if self.drop_path1.p == 0 or not self.training:
rowscale1 = None
else:
rowscale1 = self.drop_path1(torch.ones(
mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype)
)
hidden_states = fused_add_norm_fn(
mixer_out, hidden_states, self.norm1.weight, self.norm1.bias,
self.dropout1.p if self.training else 0.0, self.norm1.eps,
rowscale=rowscale1, prenorm=False
)
if not isinstance(self.mlp, nn.Identity):
mlp_out = self.mlp(hidden_states)
if self.return_residual: # mlp out is actually a pair here
mlp_out, hidden_states = mlp_out
if not self.fused_dropout_add_ln:
hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out))
+ hidden_states).to(dtype=self.norm2.weight.dtype))
else:
if self.drop_path2.p == 0 or not self.training:
rowscale2 = None
else:
rowscale2 = self.drop_path2(torch.ones(
mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype)
)
hidden_states = fused_add_norm_fn(
mlp_out, hidden_states, self.norm2.weight, self.norm2.bias,
self.dropout2.p if self.training else 0.0, self.norm2.eps,
rowscale=rowscale2, prenorm=False
)
return hidden_states
class ParallelBlock(nn.Module):
"""The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
and PaLM.
"""
def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
dropout_cls=nn.Dropout, resid_dropout1=0., resid_dropout2=0.,
tied_norm=False, fused_dropout_add_ln=False, residual_in_fp32=False,
sequence_parallel=False, mark_shared_params=False):
"""
This Block has a slightly different structure compared to a regular
prenorm Transformer block.
The standard block is: LN -> MHA / MLP -> Dropout -> Add.
[Ref: https://arxiv.org/abs/2002.04745]
Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
the hidden_states (output1 of the MHA / MLP) and the residual.
This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
The residual needs to be provided (except for the very first block).
"""
super().__init__()
self.tied_norm = tied_norm
self.fused_dropout_add_ln = fused_dropout_add_ln
self.residual_in_fp32 = residual_in_fp32
if mixer_cls is None:
mixer_cls = partial(MHA, num_heads=dim // 64)
if mlp_cls is None:
mlp_cls = partial(Mlp, hidden_features=4 * dim)
self.mixer = mixer_cls(dim)
self.dropout1 = dropout_cls(resid_dropout1)
self.norm1 = norm_cls(dim)
self.mlp = mlp_cls(dim)
self.dropout2 = dropout_cls(resid_dropout2)
if not self.tied_norm:
self.norm2 = norm_cls(dim)
if self.fused_dropout_add_ln:
assert dropout_add_layer_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
assert dropout_add_rms_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
assert (isinstance(self.norm1, (nn.LayerNorm, RMSNorm))
and isinstance(self.dropout1, nn.Dropout))
# TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
# then the input to each worker in the tensor parallel group will be different.
# This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
# For now this is not an issue because we always use sequence_parallel=True during training
# and only use sequence_parallel=False during inference.
# Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
if sequence_parallel:
for p in self.norm1.parameters():
p._sequence_parallel = True
if hasattr(self, 'norm2'):
for p in self.norm2.parameters():
p._sequence_parallel = True
# Mark the norm parameters as "shared_params" so that we sync their values at init.
if mark_shared_params:
for p in self.norm1.parameters():
p._shared_params = True
if hasattr(self, 'norm2'):
for p in self.norm2.parameters():
p._shared_params = True
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
residual: Optional[Tensor] = None, mixer_kwargs=None):
r"""Pass the input through the encoder layer.
Args:
hidden_states1: the output of the previous attention (mixer) or embedding layer.
hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
residual.
"""
# TODO: Ideally we should only do the allgather / allreduce once for
# the Linear to MLP & Attention
fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
if isinstance(self.norm1, RMSNorm)
else dropout_add_layer_norm_parallel_residual)
if not self.fused_dropout_add_ln:
dropped1 = self.dropout1(hidden_states1)
# For the very 1st block, we only want 1 dropout, not two different dropouts
if hidden_states2 is not None:
dropped2 = self.dropout2(hidden_states2)
residual = ((residual + dropped1 + dropped2)
if residual is not None else dropped1 + dropped2)
else:
residual = (residual + dropped1) if residual is not None else dropped1
hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
if not self.tied_norm else hidden_states1)
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
weight2, bias2 = ((self.norm2.weight, self.norm2.bias)
if not self.tied_norm else (None, None))
hidden_states1, hidden_states2, residual = fused_add_norm_fn(
hidden_states1, hidden_states2, residual, self.norm1.weight, self.norm1.bias,
weight2, bias2, self.dropout1.p if self.training else 0.0, self.norm1.eps,
prenorm=True, residual_in_fp32=self.residual_in_fp32
)
if self.tied_norm:
hidden_states2 = hidden_states1
if mixer_kwargs is None:
mixer_kwargs = {}
hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
hidden_states2 = self.mlp(hidden_states2)
return hidden_states1, hidden_states2, residual

View File

@@ -0,0 +1,183 @@
# Copyright (c) 2022, Tri Dao.
import torch
import torch.nn as nn
from torch import Tensor
from einops import rearrange
from flash_attn.utils.distributed import reduce_scatter, all_reduce
class GPT2Embeddings(nn.Module):
def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
word_embed_proj_dim=None, device=None, dtype=None):
"""
If max_position_embeddings <= 0, there's no position embeddings
If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
the project up to embed_dim
"""
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
if word_embed_proj_dim is None:
self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
**factory_kwargs)
self.project_in = None
else:
self.word_embeddings = nn.Embedding(vocab_size, word_embed_proj_dim,
padding_idx=padding_idx, **factory_kwargs)
self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False,
**factory_kwargs)
self.max_position_embeddings = max_position_embeddings
if self.max_position_embeddings > 0:
self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
**factory_kwargs)
def forward(self, input_ids, position_ids=None):
"""
input_ids: (batch, seqlen)
position_ids: (batch, seqlen)
"""
batch_size, seqlen = input_ids.shape
embeddings = self.word_embeddings(input_ids)
if self.project_in is not None:
embeddings = self.project_in(embeddings)
if self.max_position_embeddings > 0:
if position_ids is None:
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
position_embeddings = self.position_embeddings(position_ids)
embeddings = embeddings + position_embeddings
return embeddings
class BertEmbeddings(nn.Module):
def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size,
padding_idx=None, device=None, dtype=None):
"""
If max_position_embeddings <= 0, there's no position embeddings
If type_vocab_size <= 0, there's no token type embeddings
"""
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
**factory_kwargs)
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
if self.max_position_embeddings > 0:
self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
**factory_kwargs)
if self.type_vocab_size > 0:
self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim,
**factory_kwargs)
def forward(self, input_ids, position_ids=None, token_type_ids=None):
"""
input_ids: (batch, seqlen)
position_ids: (batch, seqlen)
token_type_ids: (batch, seqlen)
"""
batch_size, seqlen = input_ids.shape
embeddings = self.word_embeddings(input_ids)
if self.max_position_embeddings > 0:
if position_ids is None:
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
position_embeddings = self.position_embeddings(position_ids)
embeddings = embeddings + position_embeddings
if self.type_vocab_size > 0:
if token_type_ids is None:
token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = embeddings + token_type_embeddings
return embeddings
class VocabParallelEmbedding(nn.Embedding):
def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
self.process_group = process_group
if process_group is not None:
world_size = torch.distributed.get_world_size(process_group)
if num_embeddings % world_size != 0:
raise ValueError(f'num_embeddings ({num_embeddings}) must be divisible by '
f'world_size ({world_size})')
if world_size > 1 and padding_idx is not None:
raise RuntimeError('ParallelEmbedding does not support padding_idx')
else:
world_size = 1
super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
def forward(self, input: Tensor) -> Tensor:
if self.process_group is None:
return super().forward(input)
else:
rank = torch.distributed.get_rank(self.process_group)
vocab_size = self.num_embeddings
vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
# Create a mask of valid vocab ids (1 means it needs to be masked).
input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
input = input - vocab_start_index
input[input_ids_mask] = 0
embeddings = super().forward(input)
embeddings[input_ids_mask] = 0.0
return embeddings
class ColumnParallelEmbedding(nn.Embedding):
def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
self.process_group = process_group
if process_group is not None:
world_size = torch.distributed.get_world_size(process_group)
if embedding_dim % world_size != 0:
raise ValueError(f'embedding_dim ({embedding_dim}) must be divisible by '
f'world_size ({world_size})')
else:
world_size = 1
super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
class ParallelGPT2Embeddings(nn.Module):
def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
padding_idx=None, sequence_parallel=True, device=None, dtype=None):
"""
If max_position_embeddings <= 0, there's no position embeddings
"""
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.process_group = process_group
self.sequence_parallel = sequence_parallel
self.word_embeddings = VocabParallelEmbedding(
vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group,
**factory_kwargs
)
self.max_position_embeddings = max_position_embeddings
if self.max_position_embeddings > 0:
self.position_embeddings = ColumnParallelEmbedding(
max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
)
def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
"""
input_ids: (batch, seqlen)
position_ids: (batch, seqlen)
"""
batch_size, seqlen = input_ids.shape
world_size = torch.distributed.get_world_size(self.process_group)
embeddings = self.word_embeddings(input_ids)
if self.max_position_embeddings > 0:
if position_ids is None:
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
position_embeddings = self.position_embeddings(position_ids)
if world_size <= 1:
embeddings = embeddings + position_embeddings
else:
partition_dim = self.position_embeddings.embedding_dim
rank = torch.distributed.get_rank(self.process_group)
embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
if combine_batch_seqlen_dim:
embeddings = rearrange(embeddings, 'b s d -> (b s) d')
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)

View File

@@ -0,0 +1,711 @@
# Copyright (c) 2022, Tri Dao.
import math
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
try:
from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func
from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
except ImportError:
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
try:
from flash_attn.ops.fused_dense import FusedDense, ColumnParallelLinear, RowParallelLinear
except ImportError:
FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
try:
from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
RotaryEmbedding = None
try:
import ft_attention
except ImportError:
ft_attention = None
class FlashSelfAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
super().__init__()
assert flash_attn_varlen_qkvpacked_func is not None, 'FlashAttention is not installed'
assert flash_attn_qkvpacked_func is not None, 'FlashAttention is not installed'
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
"""Implements the multihead softmax attention.
Arguments
---------
qkv: The tensor containing the query, key, and value.
If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
(total, 3, H, D), where total is the sum of the sequence lengths in the batch.
causal: if passed, will override self.causal
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into qkv.
max_seqlen: int. Maximum sequence length in the batch.
Returns:
--------
out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
else (B, S, H, D).
"""
assert qkv.dtype in [torch.float16, torch.bfloat16]
assert qkv.is_cuda
causal = self.causal if causal is None else causal
unpadded = cu_seqlens is not None
if unpadded:
assert cu_seqlens.dtype == torch.int32
assert max_seqlen is not None
assert isinstance(max_seqlen, int)
return flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
else:
return flash_attn_qkvpacked_func(qkv, self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal)
class FlashCrossAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
super().__init__()
assert flash_attn_varlen_kvpacked_func is not None, 'FlashAttention is not installed'
assert flash_attn_kvpacked_func is not None, 'FlashAttention is not installed'
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None,
cu_seqlens_k=None, max_seqlen_k=None):
"""Implements the multihead softmax attention.
Arguments
---------
q: The tensor containing the query. (B, Sq, H, D)
kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
causal: if passed, will override self.causal
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
max_seqlen: int. Maximum sequence length in the batch of q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_k: int. Maximum sequence length in the batch of k and v.
"""
assert q.dtype in [torch.float16, torch.bfloat16]
assert q.is_cuda and kv.is_cuda
causal = self.causal if causal is None else causal
unpadded = cu_seqlens is not None
if unpadded:
assert cu_seqlens.dtype == torch.int32
assert max_seqlen is not None
assert isinstance(max_seqlen, int)
assert cu_seqlens_k is not None
assert cu_seqlens_k.dtype == torch.int32
assert max_seqlen_k is not None
assert isinstance(max_seqlen, int)
return flash_attn_varlen_kvpacked_func(
q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
else:
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = kv.shape[1]
assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
return flash_attn_kvpacked_func(q, kv, self.drop.p if self.training else 0.0,
causal=causal, softmax_scale=self.softmax_scale)
class SelfAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
super().__init__()
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
def forward(self, qkv, causal=None, key_padding_mask=None):
"""Implements the multihead softmax attention.
Arguments
---------
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
causal: if passed, will override self.causal
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
False means to mask out. (B, S)
"""
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
causal = self.causal if causal is None else causal
q, k, v = qkv.unbind(dim=2)
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
if key_padding_mask is not None:
padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype,
device=scores.device)
padding_mask.masked_fill_(key_padding_mask, 0.0)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
if causal:
# "triu_tril_cuda_template" not implemented for 'BFloat16'
# So we have to construct the mask in float
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + causal_mask.to(dtype=scores.dtype)
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
attention_drop = self.drop(attention)
output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
return output
class CrossAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
super().__init__()
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
def forward(self, q, kv, causal=None, key_padding_mask=None):
"""Implements the multihead softmax attention.
Arguments
---------
q: The tensor containing the query. (B, Sq, H, D)
kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
causal: if passed, will override self.causal
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
False means to mask out. (B, Sk)
"""
batch_size, seqlen_q = q.shape[0], q.shape[1]
causal = self.causal if causal is None else causal
seqlen_k = kv.shape[1]
assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
if kv.shape[3] != q.shape[2]: # MQA/GQA
kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
k, v = kv.unbind(dim=2)
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
if key_padding_mask is not None:
padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype,
device=scores.device)
padding_mask.masked_fill_(key_padding_mask, 0.0)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
if causal:
# "triu_tril_cuda_template" not implemented for 'BFloat16'
# So we have to construct the mask in float
causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0,
device=scores.device), 1)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + causal_mask.to(dtype=scores.dtype)
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
attention_drop = self.drop(attention)
output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
return output
class LinearResidual(nn.Linear):
"""Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return super().forward(input), input
def _update_kv_cache(kv, inference_params, layer_idx):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
"""
# Pre-allocate memory for key-values for inference.
num_heads, head_dim = kv.shape[-2:]
if layer_idx not in inference_params.key_value_memory_dict:
kv_cache = torch.empty(
inference_params.max_batch_size, inference_params.max_sequence_len, 2,
num_heads, head_dim, dtype=kv.dtype, device=kv.device
)
inference_params.key_value_memory_dict[layer_idx] = kv_cache
else:
if not inference_params.fused_ft_kernel:
kv_cache = inference_params.key_value_memory_dict[layer_idx]
else:
# For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
# where packsize = 4 if fp32, 8 if fp16 or bf16.
# v_cache has shape (b, h, s, headdim)
k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
kv_cache = None
# Adjust key and value for inference
batch_start = inference_params.batch_size_offset
batch_end = batch_start + kv.shape[0]
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + kv.shape[1]
assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
# Copy key and values.
if not inference_params.fused_ft_kernel:
assert kv_cache is not None
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
return kv
else:
assert inference_params.sequence_len_offset == 0
# FT kernel requires different layouts for the k_cache and v_cache.
assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if kv.dtype == torch.float32 else 8
if kv_cache is not None:
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
packsize=packsize).contiguous()
v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
else:
k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize
)
v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
kv[:, :, 1], 'b s h d -> b h s d'
)
return kv
def _apply_rotary_single_query_attention(qkv, inference_params, layer_idx, rotary_emb_dim,
rotary_emb_base, kv=None, rotary_emb_interleaved=False):
"""
qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
q of shape (batch_size, 1, nheads, head_dim)
kv: (batch_size, 1, 2, nheads_kv, head_dim)
"""
assert inference_params.fused_ft_kernel
assert ft_attention is not None
if kv is None:
q, k, v = rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1)
else:
q = rearrange(qkv, 'b 1 h d -> b h d')
k, v = rearrange(kv, 'b 1 two h d -> b two h d').unbind(dim=1)
batch_start = inference_params.batch_size_offset
batch_end = batch_start + q.shape[0]
k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
if inference_params.lengths_per_sample is not None else None)
context = ft_attention.single_query_attention(
q, k, v,
k_cache[batch_start:batch_end],
v_cache[batch_start:batch_end],
lengths_per_sample,
None, # rotary_cos_
None, # rotary_sin_
None, # nnz_head_idx
inference_params.sequence_len_offset,
rotary_emb_dim, rotary_emb_base,
not rotary_emb_interleaved # neox_rotary_style
)
return rearrange(context, 'b h d -> b 1 h d')
class MHA(nn.Module):
"""Multi-head self-attention and cross-attention
"""
def __init__(self, embed_dim, num_heads, num_heads_kv=None, cross_attn=False,
qkv_proj_bias=True, out_proj_bias=True,
dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, dwconv=False,
rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
rotary_emb_interleaved=False, fused_bias_fc=False, use_flash_attn=False,
return_residual=False, checkpointing=False, device=None, dtype=None) -> None:
"""
num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
return_residual: whether to return the input x along with the output. This is for
performance reason: for post-norm architecture, returning the input allows us
to fuse the backward of nn.Linear with the residual connection.
"""
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.embed_dim = embed_dim
self.cross_attn = cross_attn
self.causal = causal
self.layer_idx = layer_idx
self.dwconv = dwconv
self.rotary_emb_dim = rotary_emb_dim
self.use_flash_attn = use_flash_attn
self.return_residual = return_residual
self.checkpointing = checkpointing
self.num_heads = num_heads
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv"
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.head_dim = self.embed_dim // num_heads
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
kv_dim = 2 * self.head_dim * self.num_heads_kv
if self.rotary_emb_dim > 0:
assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
assert RotaryEmbedding is not None, 'rotary_emb is not installed'
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
scale_base=rotary_emb_scale_base,
interleaved=rotary_emb_interleaved, device=device)
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
linear_resid_cls = (LinearResidual if not fused_bias_fc
else partial(FusedDense, return_residual=True))
wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
if not self.cross_attn:
self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
else:
self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
if self.dwconv:
if self.num_heads_kv == self.num_heads:
self.dwconv_qkv = nn.Conv1d(qkv_dim, qkv_dim, kernel_size=3, padding=2,
groups=qkv_dim)
else:
self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2,
groups=embed_dim)
self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2,
groups=kv_dim)
self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
attention_dropout=dropout)
self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
attention_dropout=dropout)
self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
dtype = self.out_proj.weight.dtype if dtype is None else dtype
device = self.out_proj.weight.device
if not fused_ft_kernel:
return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim,
dtype=dtype, device=device)
else:
assert dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if dtype == torch.float32 else 8
assert self.head_dim % packsize == 0
k_cache = torch.empty(batch_size, self.num_heads_kv, self.head_dim // packsize,
max_seqlen, packsize, dtype=dtype, device=device)
v_cache = torch.empty(batch_size, self.num_heads_kv, max_seqlen, self.head_dim,
dtype=dtype, device=device)
return k_cache, v_cache
def _update_kv_cache(self, kv, inference_params):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
"""
assert not self.dwconv, 'Generation does not support dwconv yet'
assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
return _update_kv_cache(kv, inference_params, self.layer_idx)
def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
"""
qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
q of shape (batch_size, 1, nheads, head_dim)
kv: (batch_size, 1, 2, nheads_kv, head_dim)
"""
rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
return _apply_rotary_single_query_attention(
qkv, inference_params, self.layer_idx, self.rotary_emb_dim, rotary_emb_base, kv=kv,
rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
)
def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
mixer_subset=None, inference_params=None, **kwargs):
"""
Arguments:
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
is the is the sum of the sequence lengths in the batch.
x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into x. Only applicable when using
FlashAttention.
max_seqlen: int. Maximum sequence length in the batch.
key_padding_mask: boolean mask, True means to keep, False means to mask out.
(batch, seqlen). Only applicable when not using FlashAttention.
mixer_subset: for cross-attention only. If not None, will take a subset of x
before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer.
inference_params: for generation. Adapted from Megatron-LM (and Apex)
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
"""
if cu_seqlens is not None:
assert max_seqlen is not None
assert key_padding_mask is None
assert self.use_flash_attn
assert not self.dwconv
assert self.rotary_emb_dim == 0
if key_padding_mask is not None:
assert cu_seqlens is None
assert max_seqlen is None
assert not self.use_flash_attn
if inference_params is not None:
assert key_padding_mask is None
assert cu_seqlens is None and max_seqlen is None
assert not self.dwconv
kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
if not self.cross_attn and self.num_heads_kv == self.num_heads:
assert x_kv is None and mixer_subset is None
if not self.return_residual:
qkv = self.Wqkv(x)
else:
qkv, x = self.Wqkv(x)
if self.dwconv:
qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
'b d s -> b s d').contiguous()
qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
if (inference_params is None or inference_params.sequence_len_offset == 0
or not inference_params.fused_ft_kernel):
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
if inference_params is None:
if not self.checkpointing:
context = self.inner_attn(qkv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv,
**kwargs)
else:
q = qkv[:, :, 0]
kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
# If we're processing the prompt, causal=None (use self.causal).
# If we're decoding, then causal=False.
causal = None if inference_params.sequence_len_offset == 0 else False
context = self.inner_cross_attn(q, kv, causal=causal)
else:
context = self._apply_rotary_single_query_attention(qkv, inference_params)
else:
if self.cross_attn:
if not self.return_residual:
q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
kv = self.Wkv(x_kv if x_kv is not None else x)
else:
if x_kv is not None:
kv, x_kv = self.Wkv(x_kv)
else:
kv, x = self.Wkv(x)
q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
else:
assert self.num_heads_kv != self.num_heads
if not self.return_residual:
qkv = self.Wqkv(x)
else:
qkv, x = self.Wqkv(x)
q = qkv[..., :self.num_heads * self.head_dim]
kv = qkv[..., self.num_heads * self.head_dim:]
q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
kv = rearrange(kv, '... (two hkv d) -> ... two hkv d', two=2, d=self.head_dim)
if self.dwconv:
q = rearrange(self.dwconv_q(rearrange(q, 'b s d -> b d s'))[..., :-2],
'b d s -> b s d').contiguous()
kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2],
'b d s -> b s d').contiguous()
if (inference_params is None or inference_params.sequence_len_offset == 0
or not inference_params.fused_ft_kernel):
if self.rotary_emb_dim > 0:
q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
if inference_params is None:
if not self.checkpointing:
context = self.inner_cross_attn(q, kv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv,
**kwargs)
else:
kv = self._update_kv_cache(kv, inference_params)
# If we're processing the prompt, causal=None (use self.causal).
# If we're decoding, then causal=False.
causal = None if inference_params.sequence_len_offset == 0 else False
context = self.inner_cross_attn(q, kv, causal=causal)
else:
context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
return out if not self.return_residual else (out, x)
class ParallelMHA(nn.Module):
"""Multi-head self-attention and cross-attention
"""
def __init__(self, embed_dim, num_heads, process_group, num_heads_kv=None,
qkv_proj_bias=True, out_proj_bias=True,
dropout=0.0, softmax_scale=None, causal=False, layer_idx=None,
rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
rotary_emb_interleaved=False, use_flash_attn=False, checkpointing=False,
sequence_parallel=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.embed_dim = embed_dim
self.causal = causal
self.layer_idx = layer_idx
self.rotary_emb_dim = rotary_emb_dim
self.use_flash_attn = use_flash_attn
self.checkpointing = checkpointing
self.process_group = process_group
self.world_size = process_group.size() if process_group is not None else 1
self.num_heads = num_heads
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
self.num_heads_per_rank = num_heads // self.world_size
self.num_heads_kv_per_rank = self.num_heads_kv // self.world_size
assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv"
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
assert self.num_heads_kv % self.world_size == 0, "num_heads_kv must be divisible by world_size"
self.head_dim = self.embed_dim // num_heads
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
kv_dim = 2 * self.head_dim * self.num_heads_kv
if self.rotary_emb_dim > 0:
assert RotaryEmbedding is not None, 'rotary_emb is not installed'
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
scale_base=rotary_emb_scale_base,
interleaved=rotary_emb_interleaved, device=device)
if ColumnParallelLinear is None or RowParallelLinear is None:
raise ImportError('fused_dense is not installed')
self.Wqkv = ColumnParallelLinear(embed_dim, qkv_dim, process_group,
bias=qkv_proj_bias,
sequence_parallel=sequence_parallel, **factory_kwargs)
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
attention_dropout=dropout)
self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
attention_dropout=dropout)
self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group,
bias=out_proj_bias,
sequence_parallel=sequence_parallel, **factory_kwargs)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
dtype = self.out_proj.weight.dtype if dtype is None else dtype
device = self.out_proj.weight.device
if not fused_ft_kernel:
return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv_per_rank,
self.head_dim, dtype=dtype, device=device)
else:
assert dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if dtype == torch.float32 else 8
assert self.head_dim % packsize == 0
k_cache = torch.empty(batch_size, self.num_heads_kv_per_rank,
self.head_dim // packsize,
max_seqlen, packsize, dtype=dtype, device=device)
v_cache = torch.empty(batch_size, self.num_heads_kv_per_rank, max_seqlen,
self.head_dim, dtype=dtype, device=device)
return k_cache, v_cache
def _update_kv_cache(self, kv, inference_params):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
"""
assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
return _update_kv_cache(kv, inference_params, self.layer_idx)
def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
"""
qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
q of shape (batch_size, 1, nheads, head_dim)
kv: (batch_size, 1, 2, nheads_kv, head_dim)
"""
rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
return _apply_rotary_single_query_attention(
qkv, inference_params, self.layer_idx, self.rotary_emb_dim, rotary_emb_base, kv=kv,
rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
)
def forward(self, x, seqlen=None, inference_params=None, **kwargs):
"""
Arguments:
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
split x during sequence parallel, we split the batch * seqlen dimension
(in case batch is small).
"""
qkv = self.Wqkv(x)
if seqlen is not None:
qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
if self.num_heads_kv == self.num_heads:
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, d=self.head_dim)
if (inference_params is None or inference_params.sequence_len_offset == 0
or not inference_params.fused_ft_kernel):
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
if inference_params is None:
if not self.checkpointing:
context = self.inner_attn(qkv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
else:
q = qkv[:, :, 0]
kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
# If we're processing the prompt, causal=None (use self.causal).
# If we're decoding, then causal=False.
causal = None if inference_params.sequence_len_offset == 0 else False
context = self.inner_cross_attn(q, kv, causal=causal)
else:
context = self._apply_rotary_single_query_attention(qkv, inference_params)
else:
q = rearrange(qkv[..., :self.num_heads_per_rank * self.head_dim],
"... (h d) -> ... h d", d=self.head_dim)
kv = rearrange(qkv[..., self.num_heads_per_rank * self.head_dim:],
"... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
if (inference_params is None or inference_params.sequence_len_offset == 0
or not inference_params.fused_ft_kernel):
if self.rotary_emb_dim > 0:
q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
if inference_params is None:
if not self.checkpointing:
context = self.inner_cross_attn(q, kv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv,
**kwargs)
else:
kv = self._update_kv_cache(kv, inference_params)
# If we're processing the prompt, causal=None (use self.causal).
# If we're decoding, then causal=False.
causal = None if inference_params.sequence_len_offset == 0 else False
context = self.inner_cross_attn(q, kv, causal=causal)
else:
context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
context = rearrange(context, 'b s h d -> b s (h d)')
if seqlen is not None:
context = rearrange(context, 'b s d -> (b s) d')
out = self.out_proj(context)
return out

View File

@@ -0,0 +1,86 @@
# Copyright (c) 2022, Tri Dao.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup
try:
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
except ImportError:
ColumnParallelLinear, RowParallelLinear = None, None
try:
from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
except ImportError:
FusedMLP, ParallelFusedMLP = None, None
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
bias1=True, bias2=True, return_residual=False, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features * 4
self.return_residual = return_residual
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
self.activation = activation
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
def forward(self, x):
y = self.fc1(x)
y = self.activation(y)
y = self.fc2(y)
return y if not self.return_residual else (y, x)
class ParallelMLP(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
process_group: ProcessGroup = None, sequence_parallel=True,
bias1=True, bias2=True, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
assert ColumnParallelLinear is not None, "Need to install fused_dense"
assert RowParallelLinear is not None, "Need to install fused_dense"
out_features = out_features or in_features
hidden_features = hidden_features or in_features * 4
self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, bias=bias1,
sequence_parallel=sequence_parallel, **factory_kwargs)
self.activation = activation
self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
sequence_parallel=sequence_parallel, **factory_kwargs)
def forward(self, x):
y = self.fc1(x)
y = self.activation(y)
y = self.fc2(y)
return y
class GatedMlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
bias1=True, bias2=True, multiple_of=256, return_residual=False,
device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or int(8 * in_features / 3)
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
self.return_residual = return_residual
self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
self.activation = activation
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias1, **factory_kwargs)
def forward(self, x):
y = self.fc1(x)
if self.activation == F.sigmoid: # Special case for GLU
y = F.glu(y, dim=-1)
else:
y, gate = y.chunk(2, dim=-1)
y = y * self.activation(gate)
y = self.fc2(y)
return y if not self.return_residual else (y, x)

View File

@@ -0,0 +1,99 @@
# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(y, bias):
x = bias + y
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, y, bias):
"""Assume that y has shape (B, D) and bias has shape (D)
"""
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
grad_y = ff * g
return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(input, bias)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, input, bias)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def gelu_fwd(x):
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def gelu_bwd(g, x):
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return (ff * g).to(dtype=x.dtype)
class FastGeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input):
ctx.save_for_backward(input)
return gelu_fwd(input)
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
tmp = gelu_bwd(grad_output, input)
return tmp
fast_gelu_impl = FastGeLUFunction.apply
@torch.jit.script
def relu_bwd(g, x):
return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
@torch.jit.script
def sqrelu_fwd(x):
r = F.relu(x)
return (r * r).to(dtype=x.dtype)
@torch.jit.script
def sqrelu_bwd(g, x):
return (2.0 * g * F.relu(x)).to(dtype=x.dtype)

View File

@@ -0,0 +1,527 @@
# Copyright (c) 2023, Tri Dao.
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
# We make it work with pytorch amp and with bfloat16.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.cuda.amp import custom_bwd, custom_fwd
# import fused_dense_cuda # from apex
import fused_dense_lib as fused_dense_cuda
from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_fwd, sqrelu_bwd
from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
from flash_attn.utils.distributed import reduce_scatter, all_reduce
class FusedDenseFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight, bias, return_residual=False, process_group=None,
sequence_parallel=True):
"""
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
"""
ctx.compute_weight_gradient = weight.requires_grad
ctx.return_residual = return_residual
ctx.process_group = process_group
ctx.sequence_parallel = sequence_parallel
if torch.is_autocast_enabled():
x = x.to(dtype=torch.get_autocast_gpu_dtype())
x = x.contiguous()
if process_group is not None and sequence_parallel:
# We want to kick off the all_gather early, before weight dtype conversion
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
else:
total_x = x
if torch.is_autocast_enabled():
weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
weight = weight.contiguous()
if process_group is not None and sequence_parallel:
handle_x.wait()
batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
batch_dim = batch_shape.numel()
# https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
if min(batch_dim, n, *weight.shape) > 65535 * 32:
raise RuntimeError('fused_dense only supports matrix dims <= 2M')
output = F.linear(total_x, weight, bias)
if ctx.compute_weight_gradient:
ctx.save_for_backward(x, weight)
else:
ctx.save_for_backward(weight)
return output if not return_residual else (output, x)
@staticmethod
@custom_bwd
def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous()
if ctx.return_residual:
grad_input, = args
grad_input = grad_input.contiguous()
process_group = ctx.process_group
sequence_parallel = ctx.sequence_parallel
if ctx.compute_weight_gradient:
x, weight = ctx.saved_tensors
if process_group is not None and sequence_parallel:
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
else:
total_x = x
else:
weight, = ctx.saved_tensors
total_x = None
batch_shape = grad_output.shape[:-1]
batch_dim = batch_shape.numel()
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
if ctx.needs_input_grad[0]:
if not ctx.return_residual:
grad_input = F.linear(grad_output, weight.t())
else:
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
grad_output, weight)
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
if process_group is not None:
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
else:
grad_input = None
if ctx.needs_input_grad[1]:
assert ctx.compute_weight_gradient
if process_group is not None and sequence_parallel:
handle_x.wait()
grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
)
else:
grad_weight = None
grad_bias = grad_output if ctx.needs_input_grad[2] else None
if process_group is not None and ctx.needs_input_grad[0]:
handle_grad_input.wait()
return grad_input, grad_weight, grad_bias, None, None, None
def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
return_residual: bool = False, process_group: Optional[ProcessGroup] = None,
sequence_parallel: bool = True):
dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group,
sequence_parallel)
else:
assert process_group is None
out = F.linear(x, weight, bias)
return out if not return_residual else (out, x)
class FusedDense(nn.Linear):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
return_residual: bool = False, device=None, dtype=None) -> None:
super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
self.return_residual = return_residual
def forward(self, x, process_group=None):
"""
If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul.
"""
return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual,
process_group=process_group)
class ColumnParallelLinear(nn.Linear):
def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
world_size = torch.distributed.get_world_size(process_group)
if out_features % world_size != 0:
raise ValueError(f'out_features ({out_features}) must be divisible by '
f'world_size ({world_size})')
super().__init__(in_features, out_features // world_size, bias=bias,
device=device, dtype=dtype)
self.process_group = process_group
self.sequence_parallel = sequence_parallel
def forward(self, x):
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group,
sequence_parallel=self.sequence_parallel)
class RowParallelLinear(nn.Linear):
def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
world_size = torch.distributed.get_world_size(process_group)
rank = torch.distributed.get_rank(process_group)
if in_features % world_size != 0:
raise ValueError(f'in_features ({in_features}) must be divisible by '
f'world_size ({world_size})')
# Only rank 0 will have bias
super().__init__(in_features // world_size, out_features, bias=bias and rank == 0,
device=device, dtype=dtype)
self.process_group = process_group
self.sequence_parallel = sequence_parallel
def forward(self, x):
"""
We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
a reduce_scatter of the result.
"""
out = fused_dense_func(x, self.weight, self.bias)
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
return reduce_fn(out, self.process_group)
class FusedMLPFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight1, bias1, weight2, bias2, activation='gelu_approx', save_pre_act=True,
return_residual=False, checkpoint_lvl=0, heuristic=0, process_group=None,
sequence_parallel=True):
"""
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
with sequence parallelism: we do an all_gather of x before doing the matmul.
If sequence_parallel=False, then the input is already gathered.
checkpoint_lvl:
0: no recomputation in the bwd
1: recompute gelu_out / relu_out in the bwd
2: recompute pre_act and gelu_out / relu_out in the bwd
"""
assert -1 <= heuristic <= 4
assert activation in ['gelu_approx', 'relu', 'sqrelu']
if activation == 'sqrelu':
assert heuristic == -1
if not save_pre_act:
checkpoint_lvl = 2
assert checkpoint_lvl in [0, 1, 2]
ctx.return_residual = return_residual
ctx.process_group = process_group
ctx.sequence_parallel = sequence_parallel
ctx.checkpoint_lvl = checkpoint_lvl
ctx.activation = activation
ctx.heuristic = heuristic
if torch.is_autocast_enabled():
x = x.to(dtype=torch.get_autocast_gpu_dtype())
x = x.contiguous()
if process_group is not None and sequence_parallel:
# We want to kick off the all_gather early, before weight dtype conversion
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
else:
total_x = x
if torch.is_autocast_enabled():
dtype = torch.get_autocast_gpu_dtype()
weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]]
bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
weight1 = weight1.contiguous()
bias1 = bias1.contiguous() if bias1 is not None else None
weight2 = weight2.contiguous()
bias2 = bias2.contiguous() if bias2 is not None else None
if process_group is not None and sequence_parallel:
handle_x.wait()
batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
batch_dim = batch_shape.numel()
# https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
raise RuntimeError('fused_dense only supports matrix dims <= 2M')
if heuristic == -1:
pre_act = F.linear(total_x, weight1, bias1)
activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
else (sqrelu_fwd if activation == 'sqrelu' else F.relu))
with torch.jit.fuser('fuser2'):
output1 = activation_fn(pre_act)
# This is before adding bias1
# pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
# with torch.jit.fuser('fuser2'):
# output1 = bias_gelu(pre_act, bias1)
else:
is_gelu = activation == 'gelu_approx'
output1, *rest = fused_dense_cuda.linear_act_forward(
total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic
)
if save_pre_act:
pre_act = rest[0]
output2 = F.linear(output1, weight2, bias2)
if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'):
# For RELU the pre_act is very small (just a bit-mask) so we just save it
ctx.save_for_backward(x, weight1, weight2, pre_act, output1)
elif checkpoint_lvl == 1:
ctx.save_for_backward(x, weight1, weight2, pre_act)
elif checkpoint_lvl == 2:
ctx.save_for_backward(x, weight1, weight2, bias1)
output2 = output2.reshape(*batch_shape, output2.shape[-1])
return output2 if not return_residual else (output2, x)
@staticmethod
@custom_bwd
def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous()
checkpoint_lvl = ctx.checkpoint_lvl
activation = ctx.activation
activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
else (sqrelu_fwd if activation == 'sqrelu' else F.relu))
if ctx.return_residual:
grad_input, = args
grad_input = grad_input.contiguous()
process_group = ctx.process_group
sequence_parallel = ctx.sequence_parallel
x, weight1, weight2, *rest = ctx.saved_tensors
if process_group is None or not sequence_parallel:
total_x = x
batch_shape = grad_output.shape[:-1]
batch_dim = batch_shape.numel()
if checkpoint_lvl in [0, 1]:
if process_group is not None and sequence_parallel:
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'):
pre_act, output1 = rest
elif checkpoint_lvl == 1:
pre_act, = rest
with torch.jit.fuser('fuser2'):
output1 = activation_fn(pre_act)
elif checkpoint_lvl == 2:
bias1, = rest
if process_group is not None and sequence_parallel:
total_x, _ = all_gather_raw(x, process_group)
if ctx.heuristic == -1:
pre_act = F.linear(total_x, weight1, bias1)
with torch.jit.fuser('fuser2'):
output1 = activation_fn(pre_act)
else:
output1, pre_act = fused_dense_cuda.linear_act_forward(
total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1,
activation == 'gelu_approx', True, ctx.heuristic
)
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
output1 = output1.reshape(batch_dim, output1.shape[-1])
pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1])
if ctx.needs_input_grad[3]:
grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
output1, grad_output, ctx.needs_input_grad[4]
)
else:
grad_weight2 = None
grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
if ctx.heuristic == -1:
# grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
grad_output1 = F.linear(grad_output, weight2.t())
activation_grad_fn = (gelu_bwd if activation == 'gelu_approx'
else (sqrelu_bwd if activation == 'sqrelu' else relu_bwd))
with torch.jit.fuser('fuser2'):
grad_pre_act = activation_grad_fn(grad_output1, pre_act)
else:
# The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't
# just compute gelu/relu grad
grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad(
weight2, grad_output, pre_act, activation == 'gelu_approx', ctx.heuristic
)
if not ctx.needs_input_grad[2]:
grad_bias1 = None
if ctx.needs_input_grad[0]:
if not ctx.return_residual:
grad_input = F.linear(grad_pre_act, weight1.t())
else:
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
grad_pre_act, weight1)
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
if process_group is not None:
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
else:
grad_input = None
if ctx.heuristic == -1:
if ctx.needs_input_grad[1]:
if process_group is not None and sequence_parallel:
handle_x.wait()
grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
total_x.reshape(batch_dim, total_x.shape[-1]), grad_pre_act,
ctx.needs_input_grad[2]
)
else:
grad_weight1 = None
grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
else:
if ctx.needs_input_grad[1]:
if process_group is not None and sequence_parallel:
handle_x.wait()
grad_weight1 = F.linear(grad_pre_act.t(),
total_x.reshape(batch_dim, total_x.shape[-1]).t())
else:
grad_weight1 = None
if process_group is not None and ctx.needs_input_grad[0]:
handle_grad_input.wait()
return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2,
None, None, None, None, None, None, None)
def fused_mlp_func(
x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
bias2: Optional[Tensor] = None, activation: str = 'gelu_approx',
save_pre_act: bool = True, return_residual: bool = False,
checkpoint_lvl: int = 0, heuristic: int = 0,
process_group: Optional[ProcessGroup] = None,
sequence_parallel: bool = True
):
assert activation in ['gelu_approx', 'relu', 'sqrelu']
dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
# If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == 'relu' else 8) == 0)
if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
and (bias2 is None or bias2.is_cuda) and dtype_eligible and dim_eligible):
return FusedMLPFunc.apply(
x, weight1, bias1, weight2, bias2, activation, save_pre_act, return_residual,
checkpoint_lvl, heuristic, process_group, sequence_parallel
)
else:
assert process_group is None
pre_act = F.linear(x, weight1, bias1)
activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
else partial(F.relu, inplace=True))
output1 = activation_fn(pre_act)
output2 = F.linear(output1, weight2, bias2)
return output2 if not return_residual else (output2, x)
class FusedMLP(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, bias1=True,
bias2=True, activation='gelu_approx', return_residual=False,
checkpoint_lvl=0, heuristic='auto', device=None, dtype=None):
"""
If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul, gelu, then matmul.
Finally we do a reduce_scatter of the output.
checkpoint_lvl (increasing lvl means slower but more memory saving):
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute pre_act and gelu_out in the bwd
heuristic:
-1: don't fuse gemm + gelu (separate kernel)
0..4: use this heuristic for the algo section in the fused gemm + gelu
'auto': heuristic will be picked automatically:
For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBlasLt implementation
is slower than the unfused version.
return_residual: whether to return the input x along with the output. This is for
performance reason: for post-norm architecture, returning the input allows us
to fuse the backward of nn.Linear with the residual connection.
"""
assert checkpoint_lvl in [0, 1, 2]
assert activation in ['gelu_approx', 'relu', 'sqrelu']
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features * 4
self.activation = activation
self.return_residual = return_residual
self.checkpoint_lvl = checkpoint_lvl
self.heuristic = heuristic if activation != 'sqrelu' else -1
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
def forward(self, x, process_group=None):
dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
if self.heuristic == 'auto':
if self.activation == 'gelu_approx':
if torch.cuda.get_device_capability('cuda') == (9, 0):
heuristic = -1
else:
cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
else:
heuristic = 0
else:
heuristic = self.heuristic
out = fused_mlp_func(
x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
activation=self.activation, save_pre_act=self.training,
return_residual=self.return_residual, checkpoint_lvl=self.checkpoint_lvl,
heuristic=heuristic, process_group=process_group
)
if self.return_residual:
out, x = out
if process_group is not None:
out = reduce_scatter(out, process_group)
return out if not self.return_residual else (out, x)
class ParallelFusedMLP(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None,
activation='gelu_approx', process_group: ProcessGroup = None,
bias1=True, bias2=True, sequence_parallel=True, checkpoint_lvl=0, heuristic='auto',
device=None, dtype=None):
"""
process_group is required. We're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul, gelu, then matmul.
Finally we do a reduce_scatter of the output.
checkpoint_lvl (increasing lvl means slower but more memory saving):
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute pre_act and gelu_out in the bwd
heuristic:
-1: don't fuse gemm + gelu (separate kernel)
0..4: use this heuristic for the algo section in the fused gemm + gelu
'auto': heuristic will be picked automatically:
For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
"""
assert checkpoint_lvl in [0, 1, 2]
assert activation in ['gelu_approx', 'relu', 'sqrelu']
assert process_group is not None
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features * 4
self.activation = activation
self.process_group = process_group
self.sequence_parallel = sequence_parallel
self.checkpoint_lvl = checkpoint_lvl
self.heuristic = heuristic if activation != 'sqrelu' else -1
self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group,
bias=bias1, **factory_kwargs)
self.fc2 = RowParallelLinear(hidden_features, out_features, process_group,
bias=bias2, **factory_kwargs)
def forward(self, x):
dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
if self.heuristic == 'auto':
if self.activation == 'gelu_approx':
cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
else:
heuristic = 0
else:
heuristic = self.heuristic
out = fused_mlp_func(
x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
activation=self.activation, save_pre_act=self.training,
checkpoint_lvl=self.checkpoint_lvl, heuristic=heuristic,
process_group=self.process_group,
sequence_parallel=self.sequence_parallel
)
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
return reduce_fn(out, self.process_group)

View File

@@ -0,0 +1,375 @@
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init
import dropout_layer_norm
def maybe_align(x, alignment_in_bytes=16):
"""Assume that x already has last dim divisible by alignment_in_bytes
"""
# TD [2023-07-04] I'm not 100% sure that clone will align the memory
# https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
epsilon, residual_in_fp32=False, is_rms_norm=False):
""" Assume that arguments are contiguous and aligned to 16 bytes
"""
hidden_size = gamma.numel()
x0mat = x0.view((-1, hidden_size))
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
rowscale = rowscale.view(-1) if rowscale is not None else None
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
1.0, 0, None, residual_in_fp32, is_rms_norm
)
# dmask is None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
dropout_p, has_residual, is_rms_norm=False):
""" Assume that arguments are contiguous and aligned to 16 bytes
dx == None means that it was a post-norm architecture
(x = drop(x0) + residual was not returned in the fwd).
x0 must not be None if we have colscale.
"""
hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size))
dzmat = dz.view(xmat.shape)
dxmat = dx.view(xmat.shape) if dx is not None else None
x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
rowscale = rowscale.view(-1) if rowscale is not None else None
if colscale is not None:
assert x0 is not None, 'x0 is required to compute the gradient of colscale'
dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
dropout_p, 1.0, 0, has_residual, is_rms_norm
)
# dresidualmat is None if not has_residual
if colscale is None:
return dx0mat, dresidualmat, dgamma, dbeta
else:
dcolscale = rest[0]
return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
out_subset, dropout_p, epsilon, rowscale_const,
out_numrows, residual_in_fp32=False, is_rms_norm=False):
""" Assume that arguments are contiguous and aligned to 16 bytes
"""
hidden_size = gamma.numel()
x0mat = x0.view((-1, hidden_size))
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
x0_subset = x0_subset.view(-1) if x0_subset is not None else None
out_subset = out_subset.view(-1) if out_subset is not None else None
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
)
# dmask is None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
x0_subset, out_subset, dropout_p, rowscale_const,
x0_numrows, has_residual, is_rms_norm=False):
""" Assume that arguments are contiguous and aligned to 16 bytes
dx == None means that it was a post-norm architecture
(x = drop(x0) + residual was not returned in the fwd).
x0 must not be None if we have colscale.
"""
hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size))
dzmat = dz.view(-1, hidden_size)
dxmat = dx.view(xmat.shape) if dx is not None else None
x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
x0_subset = x0_subset.view(-1) if x0_subset is not None else None
out_subset = out_subset.view(-1) if out_subset is not None else None
if colscale is not None:
assert x0 is not None, 'x0 is required to compute the gradient of colscale'
dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
)
# dresidualmat is None if not has_residual
if colscale is None:
return dx0mat, dresidualmat, dgamma, dbeta
else:
dcolscale = rest[0]
return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
def _dropout_add_layer_norm_parallel_residual_forward(
x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p,
epsilon, residual_in_fp32=False, is_rms_norm=False
):
""" Assume that arguments are contiguous and aligned to 16 bytes
"""
hidden_size = gamma0.numel()
x0mat = x0.view((-1, hidden_size))
x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
x0mat, x1mat, residualmat, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
None, residual_in_fp32, is_rms_norm
)
# dmask0 and dmask1 are None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma
def _dropout_add_layer_norm_parallel_residual_backward(
dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
dropout_p, has_x1, has_residual, is_rms_norm=False
):
""" Assume that arguments are contiguous and aligned to 16 bytes
dx == None means that it was a post-norm architecture
(x = drop(x0) + residual was not returned in the fwd).
"""
hidden_size = gamma0.numel()
xmat = x.view((-1, hidden_size))
dz0mat = dz0.view(xmat.shape)
dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
dxmat = dx.view(xmat.shape) if dx is not None else None
dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, *rest = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
dz0mat, dz1mat, dxmat, xmat, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
dropout_p, has_x1, has_residual, is_rms_norm
)
# dresidualmat is None if not has_residual
return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
class DropoutAddLayerNormFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
x0 = maybe_align(x0.contiguous(), 16)
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
gamma = maybe_align(gamma.contiguous(), 16)
beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
residual_in_fp32, is_rms_norm
)
# Only need to save x0 if we need to compute gradient wrt colscale
x0_saved = x0 if colscale is not None else None
ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale)
ctx.prenorm = prenorm
ctx.dropout_p = dropout_p
ctx.has_residual = residual is not None
ctx.is_rms_norm = is_rms_norm
ctx.has_beta = beta is not None
if not return_dmask:
return (zmat.view(x0.shape) if not prenorm
else (zmat.view(x0.shape), xmat.view(x0.shape)))
else:
dmask = (dmask.view(x0.shape) if dropout_p > 0.
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
ctx.mark_non_differentiable(dmask)
return ((zmat.view(x0.shape), dmask) if not prenorm
else (zmat.view(x0.shape), xmat.view(x0.shape), dmask))
@staticmethod
def backward(ctx, dz, *args):
# assert dz.is_contiguous()
dz = maybe_align(dz.contiguous(), 16) # this happens!
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
# x0 is None if colscale is None
dropout_p = ctx.dropout_p
has_residual = ctx.has_residual
dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
ctx.is_rms_norm
)
dx0 = dx0mat.view(x.shape)
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
dcolscale = rest[0] if colscale is not None else None
return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None,
None, None, None, None, None)
class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32=False,
prenorm=False, is_rms_norm=False, return_dmask=False):
x0 = maybe_align(x0.contiguous(), 16)
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
gamma = maybe_align(gamma.contiguous(), 16)
beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
)
# Only need to save x0 if we need to compute gradient wrt colscale
x0_saved = x0 if colscale is not None else None
x_shape = (-1, *x0.shape[1:])
ctx.save_for_backward(xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale,
x0_subset, out_subset)
ctx.prenorm = prenorm
ctx.dropout_p = dropout_p
ctx.rowscale_const = rowscale_const
ctx.x0_numrows = x0.shape[:-1].numel()
ctx.has_residual = residual is not None
ctx.is_rms_norm = is_rms_norm
ctx.has_beta = beta is not None
z_shape = (-1, *x0.shape[1:])
if not return_dmask:
return (zmat.view(z_shape) if not prenorm
else (zmat.view(z_shape), xmat.view(x0.shape)))
else:
z = zmat.view(z_shape)
dmask = (dmask.view(x0.shape) if dropout_p > 0.
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
ctx.mark_non_differentiable(dmask)
return ((z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask))
@staticmethod
def backward(ctx, dz, *args):
# assert dz.is_contiguous()
dz = maybe_align(dz.contiguous(), 16) # this happens!
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
# x0 is None if colscale is None
dropout_p = ctx.dropout_p
has_residual = ctx.has_residual
dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
)
dx0 = dx0mat.view(-1, *x.shape[1:])
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
dcolscale = rest[0] if colscale is not None else None
return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None,
None, None, None, None, None, None, None, None)
class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
x0 = maybe_align(x0.contiguous(), 16)
x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
gamma0 = maybe_align(gamma0.contiguous(), 16)
beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward(
x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
residual_in_fp32, is_rms_norm
)
ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
ctx.prenorm = prenorm
ctx.dropout_p = dropout_p
ctx.has_x1 = x1 is not None
ctx.has_residual = residual is not None
ctx.is_rms_norm = is_rms_norm
ctx.has_beta = beta0 is not None
z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
if not return_dmask:
return z if not prenorm else (*z, xmat.view(x0.shape))
else:
dmask0 = (dmask0.view(x0.shape) if dropout_p > 0.
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
dmask1 = (dmask1.view(x0.shape) if dropout_p > 0. and x1 is not None
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
ctx.mark_non_differentiable(dmask0)
ctx.mark_non_differentiable(dmask1)
return (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
@staticmethod
def backward(ctx, dz0, dz1, *args):
dz0 = maybe_align(dz0.contiguous(), 16) # this happens!
dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
dropout_p = ctx.dropout_p
has_x1 = ctx.has_x1
has_residual = ctx.has_residual
dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 = _dropout_add_layer_norm_parallel_residual_backward(
dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p, has_x1,
has_residual, ctx.is_rms_norm
)
dx0 = dx0mat.view(x.shape)
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
return (dx0, dx1, dresidual, dgamma0, dbeta0 if ctx.has_beta else None, dgamma1,
dbeta1 if ctx.has_beta else None, None, None, None, None, None, None)
def layer_norm(x, weight, bias, epsilon):
return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
def dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
layerscale=None, prenorm=False, residual_in_fp32=False,
return_dropout_mask=False):
"""residual_in_fp32 only has an effect if residual is None.
Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormFn.apply(
x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
False, return_dropout_mask
)
def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
x0_subset=None, out_subset=None, rowscale_const=1.0,
out_numrows=0, prenorm=False, residual_in_fp32=False,
return_dropout_mask=False):
"""residual_in_fp32 only has an effect if residual is None.
Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormSubsetFn.apply(
x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
)
def dropout_add_layer_norm_parallel_residual(
x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, prenorm=False,
residual_in_fp32=False, return_dropout_mask=False
):
"""residual_in_fp32 only has an effect if residual is None.
Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormParallelResidualFn.apply(
x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm,
False, return_dropout_mask
)
class DropoutAddLayerNorm(torch.nn.Module):
def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.prenorm = prenorm
self.p = p
self.eps = eps
self.residual_in_fp32 = residual_in_fp32
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.reset_parameters()
def reset_parameters(self):
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, x0, residual=None):
return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
self.p if self.training else 0.0, self.eps,
prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)

View File

@@ -0,0 +1,89 @@
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init
from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn
from flash_attn.ops.layer_norm import DropoutAddLayerNormParallelResidualFn
def rms_norm(x, weight, epsilon):
return DropoutAddLayerNormFn.apply(x, None, weight, None, None, None, 0.0, epsilon, False,
False, True)
def dropout_add_rms_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
layerscale=None, prenorm=False, residual_in_fp32=False,
return_dropout_mask=False):
"""residual_in_fp32 only has an effect if residual is None.
Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormFn.apply(
x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
True, return_dropout_mask
)
def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
x0_subset=None, out_subset=None, rowscale_const=1.0,
out_numrows=0, prenorm=False, residual_in_fp32=False,
return_dropout_mask=False):
"""residual_in_fp32 only has an effect if residual is None.
Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormSubsetFn.apply(
x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask
)
def dropout_add_rms_norm_parallel_residual(
x0, x1, residual, weight0, bias0, weight1, bias1,
dropout_p, epsilon, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
):
"""residual_in_fp32 only has an effect if residual is None.
Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormParallelResidualFn.apply(
x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm,
True, return_dropout_mask
)
class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
init.ones_(self.weight)
def forward(self, x):
return rms_norm(x, self.weight, self.eps)
class DropoutAddRMSNorm(torch.nn.Module):
def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.prenorm = prenorm
self.p = p
self.eps = eps
self.residual_in_fp32 = residual_in_fp32
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
init.ones_(self.weight)
def forward(self, x0, residual=None):
return dropout_add_rms_norm(x0, residual, self.weight, None,
self.p if self.training else 0.0, self.eps,
prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)

View File

@@ -0,0 +1,146 @@
# Copyright (c) 2022, Tri Dao.
""" Useful functions for writing test code. """
import torch
import torch.utils.benchmark as benchmark
def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False,
amp_dtype=torch.float16, **kwinputs):
""" Use Pytorch Benchmark on the forward pass of an arbitrary function. """
if verbose:
print(desc, '- Forward pass')
def fn_amp(*inputs, **kwinputs):
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
fn(*inputs, **kwinputs)
for _ in range(repeats): # warmup
fn_amp(*inputs, **kwinputs)
t = benchmark.Timer(
stmt='fn_amp(*inputs, **kwinputs)',
globals={'fn_amp': fn_amp, 'inputs': inputs, 'kwinputs': kwinputs},
num_threads=torch.get_num_threads(),
)
m = t.timeit(repeats)
if verbose:
print(m)
return t, m
def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
amp_dtype=torch.float16, **kwinputs):
""" Use Pytorch Benchmark on the backward pass of an arbitrary function. """
if verbose:
print(desc, '- Backward pass')
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
y = fn(*inputs, **kwinputs)
if type(y) is tuple:
y = y[0]
if grad is None:
grad = torch.randn_like(y)
else:
if grad.shape != y.shape:
raise RuntimeError('Grad shape does not match output shape')
for _ in range(repeats): # warmup
y.backward(grad, retain_graph=True)
t = benchmark.Timer(
stmt='y.backward(grad, retain_graph=True)',
globals={'y': y, 'grad': grad},
num_threads=torch.get_num_threads(),
)
m = t.timeit(repeats)
if verbose:
print(m)
return t, m
def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
amp_dtype=torch.float16, **kwinputs):
""" Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
if verbose:
print(desc, '- Forward + Backward pass')
def f(grad, *inputs, **kwinputs):
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
y = fn(*inputs, **kwinputs)
if type(y) is tuple:
y = y[0]
if grad is None:
grad = torch.randn_like(y)
else:
if grad.shape != y.shape:
raise RuntimeError('Grad shape does not match output shape')
y.backward(grad, retain_graph=True)
for _ in range(repeats): # warmup
f(grad, *inputs, **kwinputs)
t = benchmark.Timer(
stmt='f(grad, *inputs, **kwinputs)',
globals={'f': f, 'fn': fn, 'inputs': inputs, 'grad': grad, 'kwinputs': kwinputs},
num_threads=torch.get_num_threads(),
)
m = t.timeit(repeats)
if verbose:
print(m)
return t, m
def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
amp_dtype=torch.float16, **kwinputs):
""" Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
return (
benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose,
amp=amp, amp_dtype=amp_dtype, **kwinputs),
benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
amp=amp, amp_dtype=amp_dtype, **kwinputs),
benchmark_combined(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
amp=amp, amp_dtype=amp_dtype, **kwinputs),
)
def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False,
amp_dtype=torch.float16, cpu=False, verbose=True, **kwinputs):
""" Wrap benchmark functions in Pytorch profiler to see CUDA information. """
if backward:
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
g = torch.randn_like(fn(*inputs, **kwinputs))
for _ in range(30): # Warm up
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
if backward:
for x in inputs:
if isinstance(x, torch.Tensor):
x.grad = None
# fn(*inputs, **kwinputs) if not backward else fn(*inputs, **kwinputs).backward(g)
out = fn(*inputs, **kwinputs)
# Backward should be done outside autocast
if backward:
out.backward(g)
activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [torch.profiler.ProfilerActivity.CUDA]
with torch.profiler.profile(
activities=activities,
record_shapes=True,
# profile_memory=True,
with_stack=True,
) as prof:
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
if backward:
for x in inputs:
if isinstance(x, torch.Tensor):
x.grad = None
out = fn(*inputs, **kwinputs)
if backward: out.backward(g)
if verbose:
# print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
print(prof.key_averages().table(row_limit=50))
if trace_filename is not None:
prof.export_chrome_trace(trace_filename)
def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs):
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
fn(*inputs, **kwinputs)
torch.cuda.synchronize()
mem = torch.cuda.max_memory_allocated() / ((2 ** 20) * 1000)
if verbose:
print(f'{desc} max memory: {mem}GB')
torch.cuda.empty_cache()
return mem

View File

@@ -0,0 +1,127 @@
from typing import Optional
import torch
from torch import Tensor
from torch.distributed import ProcessGroup
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group)
output = torch.empty(world_size * input_.shape[0], *input_.shape[1:],
dtype=input_.dtype, device=input_.device)
handle = torch.distributed.all_gather_into_tensor(output, input_.contiguous(),
group=process_group, async_op=async_op)
return output, handle
# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group)
assert input_.shape[0] % world_size == 0
output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:],
dtype=input_.dtype, device=input_.device)
handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(),
group=process_group,
async_op=async_op)
return output, handle
# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
input_ = input_.contiguous()
handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
return input_, handle
class AllGatherFunc(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatenate."""
@staticmethod
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
ctx.process_group = process_group
output, _ = all_gather_raw(input_, process_group)
return output
@staticmethod
def backward(ctx, grad_output: Tensor):
grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
return grad_input, None
# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply
class ReduceScatterFunc(torch.autograd.Function):
"""Reduce scatter the input from the sequence parallel region and concatenate."""
@staticmethod
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
ctx.process_group = process_group
output, _ = reduce_scatter_raw(input_, process_group)
return output
@staticmethod
def backward(ctx, grad_output: Tensor):
grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
return grad_input, None
# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply
class AllReduceFunc(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatenate."""
@staticmethod
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
ctx.process_group = process_group
output, _ = all_reduce_raw(input_, process_group)
return output
@staticmethod
def backward(ctx, grad_output: Tensor):
return grad_output, None
# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
# We want to iterate over parameters with _shared_params=True in the same order,
# as different ranks might have different number of parameters (e.g., only rank 0 has bias).
pamams_shared = {name: p for name, p in model.named_parameters()
if getattr(p, '_shared_params', False)}
for _, p in sorted(pamams_shared.items()):
with torch.no_grad():
# Broadcast needs src to be global rank, not group rank
torch.distributed.broadcast(
p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
)
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
# We want to iterate over parameters with _sequence_parallel=True in the same order,
# as different ranks might have different number of parameters (e.g., only rank 0 has bias).
params_seqparallel = {name: p for name, p in model.named_parameters()
if getattr(p, '_sequence_parallel', False)}
grads = [p.grad for _, p in sorted(params_seqparallel.items())]
if grads:
with torch.no_grad():
coalesced = torch._utils._flatten_dense_tensors(grads)
torch.distributed.all_reduce(coalesced, group=process_group)
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)

View File

@@ -0,0 +1,302 @@
# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from typing import Optional, Union, Sequence, Callable
import gc
import time
from dataclasses import dataclass, field
from collections import namedtuple
import torch
from torch import Tensor
from torch.profiler import profile, record_function, ProfilerActivity
from einops import rearrange
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
@dataclass
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficienly calculate and store the context during inference."""
max_sequence_len: int
max_batch_size: int
sequence_len_offset: int = 0
batch_size_offset: int = 0
key_value_memory_dict: dict = field(default_factory=dict)
fused_ft_kernel: bool = False
lengths_per_sample: Optional[Tensor] = None
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
"""Set the logits for none top-p values to -inf."""
if top_p <= 0.0:
return
# First sort and calculate cumulative sum of probabilities.
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(indices_to_remove, float('-inf'))
def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
"""Sample from top-k logits.
Arguments:
logits: Tensor of shape (batch_size, vocab_size)
"""
if top_k == 1: # Short-circuit for greedy decoding
return logits.argmax(dim=-1)
else:
if top_p > 0.0:
assert top_p <= 1.0, 'top-p should be in (0, 1].'
if top_k > 0:
top_k = min(top_k, logits.size(-1)) # Safety check
logits_top, indices = torch.topk(logits, top_k, dim=-1)
logits_top /= temperature
modify_logits_for_top_p_filtering(logits_top, top_p)
return indices[
torch.arange(indices.shape[0], device=indices.device),
torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
]
else:
logits_top = logits / temperature
modify_logits_for_top_p_filtering(logits_top, top_p)
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
eos_token_id=None, teacher_outputs=None, vocab_size=None, tensor_parallel=1,
fused_ft_kernel=False, cg=False, timing=False):
"""Decoding, either greedy or with top-k or top-p sampling.
If top-k = 0, don't limit the number of candidates (pure sampling).
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
then top-p.
We assume that all sequences in the same batch have the same length.
Arguments:
input_ids: (batch, seq_len)
max_length: int
teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
logits, the next token is taken from the teacher_outputs. Useful for testing.
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
sequences: (batch, max_length)
scores: tuples of (batch, vocab_size)
"""
batch_size, seqlen_og = input_ids.shape
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
if cg:
assert fused_ft_kernel
if not hasattr(model, '_decoding_cache'):
model._decoding_cache = None
model._decoding_cache = update_graph_cache(
model, model._decoding_cache, batch_size, seqlen_og, max_length,
tensor_parallel=tensor_parallel
)
inference_params = model._decoding_cache.inference_params
inference_params.max_sequence_len = max_length
inference_params.max_batch_size = batch_size
inference_params.sequence_len_offset = 0
else:
inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
fused_ft_kernel=fused_ft_kernel)
scores = []
with torch.inference_mode():
if timing:
if tensor_parallel > 1:
torch.distributed.barrier()
torch.cuda.synchronize()
start = time.time()
logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
if vocab_size is not None:
logits = logits[..., :vocab_size]
scores.append(logits if not cg else logits.clone())
if teacher_outputs is None or teacher_output_len <= seqlen_og:
next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
else:
next_token = teacher_outputs[:, seqlen_og]
sequences = [next_token]
inference_params.sequence_len_offset = seqlen_og
while True:
position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
dtype=torch.long, device=input_ids.device)
if not cg:
logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
inference_params=inference_params, last_token_only=True).logits
else:
logits = model._decoding_cache.run(rearrange(next_token, 'b -> b 1'), position_ids,
inference_params.sequence_len_offset)
if vocab_size is not None:
logits = logits[..., :vocab_size]
scores.append(logits if not cg else logits.clone())
if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1:
next_token = sample(logits, top_k=top_k, temperature=temperature)
else:
next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
sequences.append(next_token)
inference_params.sequence_len_offset += 1
if eos_token_id is not None and (next_token == eos_token_id).all():
break
if inference_params.sequence_len_offset >= max_length - 1:
break
if timing:
if tensor_parallel > 1:
torch.distributed.barrier()
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
return output_cls(
sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
scores=tuple(scores)
)
class GenerationMixin:
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
raise NotImplementedError
def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
return_dict_in_generate=False, output_scores=False, **kwargs):
output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
temperature=temperature, **kwargs)
if not output_scores:
output.scores = None
return output if return_dict_in_generate else output.sequences
def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence],
device, dtype=torch.float16):
assert dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if dtype == torch.float32 else 8
assert headdim % packsize == 0
k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
if isinstance(layers, int):
layers = range(layers)
return {i: (torch.empty(k_cache_shape, device=device, dtype=dtype),
torch.empty(v_cache_shape, device=device, dtype=dtype))
for i in layers}
def seqlen_to_seqlen_type(seqlen: int) -> int:
"""Convert sequence length to a seqlen_type.
This is used to determine which cuda graph to use.
Arguments:
seqlen: int
"""
return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)
def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
assert seqlen_type in [0, 1, 2]
return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2**32)
@dataclass
class DecodingCGCache:
max_batch_size: int = 0
max_seqlen: int = 0
device = None
dtype = None
callables: dict = field(default_factory=dict)
mempool = None
inference_params: Optional[InferenceParams] = None
run: Optional[Callable] = None
@torch.inference_mode()
def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
dtype=None, n_warmups=2):
if cache is None:
cache = DecodingCGCache()
param_example = next(iter(model.parameters()))
device = param_example.device
if dtype is None:
dtype = param_example.dtype
if ((device, dtype) != (cache.device, cache.dtype) or batch_size > cache.max_batch_size
or max_seqlen > cache.max_seqlen): # Invalidate the cache
cache.callables = {}
cache.mempool = None
cache.inference_params = None
gc.collect()
cache.device, cache.dtype = device, dtype
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
if hasattr(model, 'allocate_inference_cache'):
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
else:
headdim = getattr(model.config, 'head_dim',
model.config.hidden_size // model.config.num_attention_heads)
inf_cache = allocate_inference_cache(
batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
model.config.num_hidden_layers, device, dtype
)
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
cache.inference_params = InferenceParams(
max_sequence_len=max_seqlen, max_batch_size=batch_size,
sequence_len_offset=seqlen_og, key_value_memory_dict=inf_cache, fused_ft_kernel=True,
lengths_per_sample=lengths_per_sample
)
cache.mempool = torch.cuda.graphs.graph_pool_handle()
for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
if (batch_size, s_type) not in cache.callables:
max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
cache.callables[batch_size, s_type] = capture_graph(
model, cache.inference_params, batch_size, max_seqlen_, mempool=cache.mempool,
n_warmups=n_warmups
)
def dispatch(input_ids, position_ids, seqlen):
batch_size = input_ids.shape[0]
return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)
cache.run = dispatch
cache.inference_params.sequence_len_offset = 0 # Reset so it's not confusing
return cache
def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, n_warmups=2):
device = next(iter(model.parameters())).device
input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
sequence_len_offset_og = inference_params.sequence_len_offset
# TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
# used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
inference_params.sequence_len_offset = max_seqlen - 1
inference_params.lengths_per_sample[:] = max_seqlen - 1
# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(n_warmups):
logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
last_token_only=True).logits
s.synchronize()
# This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
# which requires that graph launch and non-captured launch to not overlap (I think,
# that's how I interpret the documentation). I'm not sure if this is required.
if torch.distributed.is_initialized():
torch.distributed.barrier()
torch.cuda.current_stream().wait_stream(s)
# Captures the graph
# To allow capture, automatically sets a side stream as the current stream in the context
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=mempool):
logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
last_token_only=True).logits
def run(new_input_ids, new_position_ids, seqlen):
inference_params.lengths_per_sample[:] = seqlen
input_ids.copy_(new_input_ids)
position_ids.copy_(new_position_ids)
graph.replay()
return logits
inference_params.sequence_len_offset = sequence_len_offset_og
return run

View File

@@ -0,0 +1,37 @@
import torch
from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from transformers.utils import is_remote_url
from transformers.modeling_utils import load_state_dict
from transformers.utils.hub import cached_file, get_checkpoint_shard_files
def state_dict_from_pretrained(model_name, device=None, dtype=None):
# If not fp32, then we don't want to load directly to the GPU
mapped_device = 'cpu' if dtype not in [torch.float32, None] else device
is_sharded = False
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
_raise_exceptions_for_missing_entries=False)
if resolved_archive_file is None:
resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
_raise_exceptions_for_missing_entries=False)
if resolved_archive_file is not None:
is_sharded = True
if resolved_archive_file is None:
raise EnvironmentError(f"Model name {model_name} was not found.")
if is_sharded:
# resolved_archive_file becomes a list of files that point to the different
# checkpoint shards in this case.
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
model_name, resolved_archive_file
)
state_dict = {}
for sharded_file in resolved_archive_file:
state_dict.update(torch.load(sharded_file, map_location=mapped_device))
else:
state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
# Convert dtype before moving to GPU to save memory
if dtype is not None:
state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
return state_dict