First commit
This commit is contained in:
8
pkgs/xformers/_flash_attn/__init__.py
Normal file
8
pkgs/xformers/_flash_attn/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
__version__ = "2.5.8"
|
||||
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
BIN
pkgs/xformers/_flash_attn/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/xformers/_flash_attn/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
132
pkgs/xformers/_flash_attn/bert_padding.py
Normal file
132
pkgs/xformers/_flash_attn/bert_padding.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
class IndexFirstAxis(torch.autograd.Function):
    """Differentiable selection of rows ``input[indices]`` along the first axis.

    All trailing dimensions are flattened into one so the selection can be done
    with ``torch.gather`` (forward) and ``scatter_`` (backward); the original
    author measured these to be slightly faster than plain indexing.
    """

    @staticmethod
    def forward(ctx, input, indices):
        """Return the rows of ``input`` listed in the 1-D tensor ``indices``."""
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim = input.shape[0]
        trailing_shape = input.shape[1:]
        row_width = trailing_shape.numel()
        flat_rows = rearrange(input, 'b ... -> b (...)')
        # One copy of each index per flattened column, as torch.gather requires.
        gather_index = repeat(indices, 'z -> z d', d=row_width)
        picked = torch.gather(flat_rows, 0, gather_index)
        return picked.reshape(-1, *trailing_shape)

    @staticmethod
    def backward(ctx, grad_output):
        """Scatter ``grad_output`` rows back into a zero tensor shaped like the input."""
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        trailing_shape = grad_output.shape[1:]
        flat_grad = rearrange(grad_output, 'b ... -> b (...)')
        row_width = flat_grad.shape[1]
        grad_input = torch.zeros(
            [ctx.first_axis_dim, row_width],
            device=flat_grad.device,
            dtype=flat_grad.dtype,
        )
        scatter_index = repeat(indices, 'z -> z d', d=row_width)
        grad_input.scatter_(0, scatter_index, flat_grad)
        # ``indices`` is not differentiable, hence the trailing None.
        return grad_input.reshape(ctx.first_axis_dim, *trailing_shape), None


index_first_axis = IndexFirstAxis.apply
|
||||
|
||||
|
||||
class IndexPutFirstAxis(torch.autograd.Function):
    """Differentiable scatter of rows into a zero tensor along the first axis.

    ``forward`` builds a zero tensor with ``first_axis_dim`` rows and writes
    ``values`` at the row positions given by ``indices``; ``backward`` reads the
    incoming gradient back at those same positions.
    """

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        """Return zeros of shape ``(first_axis_dim, *values.shape[1:])`` with ``values`` written at ``indices``."""
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        padded_shape = (first_axis_dim, *values.shape[1:])
        output = torch.zeros(padded_shape, device=values.device, dtype=values.dtype)
        output[indices] = values
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return the rows of ``grad_output`` at the saved indices as the gradient of ``values``."""
        (indices,) = ctx.saved_tensors
        # ``indices`` and ``first_axis_dim`` receive no gradient.
        return grad_output[indices], None, None


index_put_first_axis = IndexPutFirstAxis.apply
|
||||
|
||||
|
||||
class IndexFirstAxisResidual(torch.autograd.Function):
    """Row selection along the first axis that also passes the input through.

    ``forward`` returns both ``input[indices]`` and a detached alias of the full
    input. ``backward`` adds the gradient of the selected rows into the gradient
    of the pass-through output.
    """

    @staticmethod
    def forward(ctx, input, indices):
        """Return ``(input[indices], input.detach())``."""
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim = input.shape[0]
        # Plain indexing is used here (no flatten + gather): reshaping
        # ``b ... -> b (...)`` could change a channel_last memory format to
        # channel_first, i.e. the input might not be contiguous.
        output = input[indices]
        # If we don't detach, PyTorch complains about the output being a view
        # that is modified in-place.
        return output, input.detach()

    @staticmethod
    def backward(ctx, grad_output, grad_residual):
        """Accumulate ``grad_output`` into ``grad_residual`` at the saved indices."""
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        assert grad_residual.shape[1:] == other_shape
        # NOTE(review): the accumulation below writes into ``grad_residual``
        # in-place (no copy is made) — confirm callers never hand in a gradient
        # with overlapping memory (e.g. an expanded tensor).
        grad_input = grad_residual
        # Equivalent to ``grad_input[indices] += grad_output``: broadcast the
        # 1-D indices to the shape of grad_output so scatter_add_ can use them.
        indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
        indices = indices.expand_as(grad_output)
        grad_input.scatter_add_(0, indices, grad_output)
        # ``indices`` is not differentiable, hence the trailing None.
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis_residual = IndexFirstAxisResidual.apply
|
||||
|
||||
|
||||
def unpad_input(hidden_states, attention_mask):
    """Drop the masked-out tokens and pack the remaining ones along one axis.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), positions of the selected tokens in the flattened (batch * seqlen) mask.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Prepend a zero so that cu_seqlens[i]:cu_seqlens[i + 1] spans sequence i.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    packed_states = index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices)
    return packed_states, indices, cu_seqlens, max_seqlen_in_batch
|
||||
|
||||
|
||||
def pad_input(hidden_states, indices, batch, seqlen):
    """Scatter packed tokens back into a zero-filled (batch, seqlen, ...) tensor.

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), target positions in the flattened (batch * seqlen) axis.
        batch: int, size of the batch axis of the result.
        seqlen: int, size of the sequence axis of the result.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    # Rows not listed in ``indices`` stay zero.
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, '(b s) ... -> b s ...', b=batch)
|
||||
966
pkgs/xformers/_flash_attn/flash_attn_interface.py
Normal file
966
pkgs/xformers/_flash_attn/flash_attn_interface.py
Normal file
@@ -0,0 +1,966 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import flash_attn_2_cuda as flash_attn_cuda
|
||||
from einops import rearrange
|
||||
import warnings
|
||||
|
||||
# Module-level switch read by the autograd functions below: the ``*_ixdnn``
# kernel wrappers are used unless the environment variable
# ENABLE_FLASH_ATTENTION_WITH_IXDNN is set to '0'. (A ``global`` statement is
# not needed at module scope, so none is used.)
enable_ixdnn = os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') != '0'
|
||||
|
||||
def _get_block_size(device, head_dim, is_dropout, is_causal):
|
||||
# This should match the block sizes in the CUDA kernel
|
||||
assert head_dim <= 256
|
||||
major, minor = torch.cuda.get_device_capability(device)
|
||||
is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100)
|
||||
is_sm80 = major == 8 and minor == 0
|
||||
is_sm90 = major == 9 and minor == 0
|
||||
if head_dim <= 32:
|
||||
return 128, 128
|
||||
if head_dim <= 64:
|
||||
return (128, 128) if not is_dropout else (128, 64)
|
||||
elif head_dim <= 96:
|
||||
return (64, 64) if (is_sm8x and is_causal) else (128, 64)
|
||||
elif head_dim <= 128:
|
||||
if is_sm8x:
|
||||
return (64, 64) if (not is_dropout and is_causal) else (128, 32)
|
||||
else:
|
||||
return 128, (64 if not is_dropout else 32)
|
||||
elif head_dim <= 160:
|
||||
if is_sm8x:
|
||||
return (128, 64) if not is_causal else (64, 64)
|
||||
else:
|
||||
return 128, 32
|
||||
elif head_dim <= 192:
|
||||
return (128, 64) if not is_dropout else (64, 64)
|
||||
elif head_dim <= 224:
|
||||
return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
|
||||
elif head_dim <= 256:
|
||||
return (128, 64) if is_sm80 else (64, 64)
|
||||
|
||||
|
||||
def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax, use_alibi, alibi_mode):
    """Run ``flash_attn_cuda.fwd`` on q/k/v and return its seven results.

    Any of q, k, v whose last-dimension stride is not 1 is made contiguous
    before the call. Returns ``(out, q, k, v, out_padded, softmax_lse, S_dmask)``
    exactly as produced by the extension.
    """
    q, k, v = (t if t.stride(-1) == 1 else t.contiguous() for t in (q, k, v))
    kernel_args = (
        q, k, v, None, dropout_p, softmax_scale, causal, return_softmax,
        use_alibi, alibi_mode, None,
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd(*kernel_args)
    return out, q, k, v, out_padded, softmax_lse, S_dmask
|
||||
|
||||
|
||||
def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                               dropout_p, softmax_scale, causal, return_softmax, use_alibi, alibi_mode):
    """Run ``flash_attn_cuda.varlen_fwd`` and return its seven results.

    Any of q, k, v whose last-dimension stride is not 1 is made contiguous
    before the call. Returns ``(out, q, k, v, out_padded, softmax_lse, S_dmask)``
    exactly as produced by the extension.
    """
    q, k, v = (t if t.stride(-1) == 1 else t.contiguous() for t in (q, k, v))
    kernel_args = (
        q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, softmax_scale, False, causal, return_softmax,
        use_alibi, alibi_mode, None,
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd(*kernel_args)
    return out, q, k, v, out_padded, softmax_lse, S_dmask
|
||||
|
||||
|
||||
def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                         dropout_p, softmax_scale, causal, use_alibi, alibi_mode):
    """Run ``flash_attn_cuda.bwd`` and return ``(dq, dk, dv, softmax_d)``.

    ``dout``, q, k, v and ``out`` are made contiguous when their last-dimension
    stride is not 1. ``dq``, ``dk`` and ``dv`` are passed through untouched: they
    are allocated by the caller and are expected to be contiguous already.
    """
    dout, q, k, v, out = (
        t if t.stride(-1) == 1 else t.contiguous() for t in (dout, q, k, v, out)
    )
    kernel_args = (
        dout, q, k, v, out, softmax_lse, dq, dk, dv,
        dropout_p, softmax_scale, causal, use_alibi, alibi_mode, None,
    )
    dq, dk, dv, softmax_d = flash_attn_cuda.bwd(*kernel_args)
    return dq, dk, dv, softmax_d
|
||||
|
||||
|
||||
def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                                cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                dropout_p, softmax_scale, causal, use_alibi, alibi_mode):
    """Run ``flash_attn_cuda.varlen_bwd`` and return ``(dq, dk, dv, softmax_d)``.

    ``dout``, q, k, v and ``out`` are made contiguous when their last-dimension
    stride is not 1. ``dq``, ``dk`` and ``dv`` are passed through untouched: they
    are allocated by the caller and are expected to be contiguous already.
    """
    dout, q, k, v, out = (
        t if t.stride(-1) == 1 else t.contiguous() for t in (dout, q, k, v, out)
    )
    kernel_args = (
        dout, q, k, v, out, softmax_lse, dq, dk, dv,
        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, softmax_scale, False, causal, use_alibi, alibi_mode, None,
    )
    dq, dk, dv, softmax_d = flash_attn_cuda.varlen_bwd(*kernel_args)
    return dq, dk, dv, softmax_d
|
||||
|
||||
def _flash_attn_forward_ixdnn(q, k, v, cu_seqlens_q, cu_seqlens_k, dropout_p,
                              causal, return_softmax, use_alibi, alibi_mode, imp_mode, window_size):
    """Run ``flash_attn_cuda.fwd_ixdnn`` and return its seven results.

    Any of q, k, v whose last-dimension stride is not 1 is made contiguous
    before the call. ``window_size`` is indexed at positions 0 and 1 and the two
    values are passed to the kernel separately. Returns
    ``(out, q, k, v, out_padded, softmax_lse, S_dmask)`` as produced by the extension.
    """
    q, k, v = (t if t.stride(-1) == 1 else t.contiguous() for t in (q, k, v))
    kernel_args = (
        q, k, v, None, cu_seqlens_q, cu_seqlens_k, dropout_p, causal, return_softmax,
        use_alibi, alibi_mode, imp_mode, window_size[0], window_size[1], None,
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd_ixdnn(*kernel_args)
    return out, q, k, v, out_padded, softmax_lse, S_dmask
|
||||
|
||||
def _flash_attn_backward_ixdnn(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                               cu_seqlens_q, cu_seqlens_k, dropout_p, causal, use_alibi, alibi_mode, imp_mode, window_size):
    """Run ``flash_attn_cuda.bwd_ixdnn`` and return ``(dq, dk, dv, softmax_d)``.

    ``dout``, q, k, v and ``out`` are made contiguous when their last-dimension
    stride is not 1. ``dq``, ``dk`` and ``dv`` are passed through untouched: they
    are allocated by the caller and are expected to be contiguous already.
    ``window_size`` is indexed at positions 0 and 1.
    """
    dout, q, k, v, out = (
        t if t.stride(-1) == 1 else t.contiguous() for t in (dout, q, k, v, out)
    )
    kernel_args = (
        dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, dropout_p,
        causal, use_alibi, alibi_mode, imp_mode, window_size[0], window_size[1], None,
    )
    dq, dk, dv, softmax_d = flash_attn_cuda.bwd_ixdnn(*kernel_args)
    return dq, dk, dv, softmax_d
|
||||
|
||||
def _flash_attn_varlen_forward_ixdnn(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                     dropout_p, causal, use_alibi, alibi_mode, imp_mode, window_size):
    """Run ``flash_attn_cuda.varlen_fwd_ixdnn`` and return its seven results.

    Any of q, k, v whose last-dimension stride is not 1 is made contiguous
    before the call. ``window_size`` is indexed at positions 0 and 1. Returns
    ``(out, q, k, v, out_padded, softmax_lse, S_dmask)`` as produced by the extension.
    """
    q, k, v = (t if t.stride(-1) == 1 else t.contiguous() for t in (q, k, v))
    kernel_args = (
        q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
        causal, use_alibi, alibi_mode, imp_mode, window_size[0], window_size[1], None,
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd_ixdnn(*kernel_args)
    return out, q, k, v, out_padded, softmax_lse, S_dmask
|
||||
|
||||
def _flash_attn_varlen_backward_ixdnn(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                                      cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                      dropout_p, causal, use_alibi, alibi_mode, imp_mode, window_size):
    """Run ``flash_attn_cuda.varlen_bwd_ixdnn`` and return ``(dq, dk, dv, softmax_d)``.

    ``dout``, q, k, v and ``out`` are made contiguous when their last-dimension
    stride is not 1. ``dq``, ``dk`` and ``dv`` are passed through untouched: they
    are allocated by the caller and are expected to be contiguous already.
    ``window_size`` is indexed at positions 0 and 1.
    """
    dout, q, k, v, out = (
        t if t.stride(-1) == 1 else t.contiguous() for t in (dout, q, k, v, out)
    )
    kernel_args = (
        dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q, max_seqlen_k, dropout_p, causal, use_alibi, alibi_mode, imp_mode,
        window_size[0], window_size[1], None,
    )
    dq, dk, dv, softmax_d = flash_attn_cuda.varlen_bwd_ixdnn(*kernel_args)
    return dq, dk, dv, softmax_d
|
||||
|
||||
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention on a packed ``qkv`` tensor.

    q, k and v are taken as ``qkv[:, :, 0]``, ``qkv[:, :, 1]`` and
    ``qkv[:, :, 2]``. The module-level ``enable_ixdnn`` flag selects between the
    ``*_ixdnn`` kernel wrappers and the plain ones, in both forward and backward.
    """

    @staticmethod
    def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, use_alibi, alibi_mode, imp_mode):
        """Compute attention and stash what backward needs on ``ctx``.

        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``
        is truthy. ``softmax_scale`` defaults to ``qkv.shape[-1] ** -0.5`` on the
        non-ixdnn path only; the ixdnn wrapper is not given a scale at all.
        ``deterministic`` is accepted but not used here.
        """
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if alibi_slopes is not None and use_alibi == False:
            warnings.warn("Parameter 'alibi_slopes' is not supported, automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning)
            use_alibi = True
            # NOTE(review): the warning mentions 'alibi_mode=1' but it is
            # imp_mode that is set here — confirm which one is intended.
            imp_mode = 1
        if enable_ixdnn:
            # One entry per batch element, each equal to the sequence length.
            cu_seqlens = torch.full((qkv.shape[0],), qkv.shape[1], dtype=torch.int32, device=qkv.device)
            out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward_ixdnn(
                qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], cu_seqlens, cu_seqlens, dropout_p,
                causal=causal, return_softmax=return_softmax and dropout_p > 0,
                use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode,
                window_size=window_size
            )
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        else:
            if softmax_scale is None:
                softmax_scale = qkv.shape[-1] ** (-0.5)
            out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
                qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], dropout_p, softmax_scale,
                causal=causal, return_softmax=return_softmax and dropout_p > 0,
                use_alibi=use_alibi, alibi_mode=alibi_mode
            )
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.use_alibi = use_alibi
        ctx.alibi_mode = alibi_mode
        ctx.imp_mode = imp_mode
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return the gradient w.r.t. ``qkv`` followed by None for every other input."""
        if enable_ixdnn:
            q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        else:
            q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        cur_rng_state = None
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        try:
            qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
            dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
            # The kernels write dq/dk/dv directly into the slices of dqkv.
            if enable_ixdnn:
                _flash_attn_backward_ixdnn(
                    dout, q, k, v, out, softmax_lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
                    cu_seqlens, cu_seqlens, ctx.dropout_p, ctx.causal, ctx.use_alibi, ctx.alibi_mode, ctx.imp_mode,
                    ctx.window_size
                )
            else:
                _flash_attn_backward(
                    dout, q, k, v, out, softmax_lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
                    ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.use_alibi, ctx.alibi_mode
                )
        finally:
            # Restore the caller's RNG state even if the kernel raised, so a
            # failed backward does not leave the global CUDA RNG rewound.
            if cur_rng_state is not None:
                torch.cuda.set_rng_state(cur_rng_state)
        dqkv = dqkv[..., :dout.shape[-1]]  # We could have padded the head dimension
        # forward takes 11 inputs: qkv plus 10 non-differentiable arguments.
        return (dqkv,) + (None,) * 10
|
||||
|
||||
|
||||
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention on a packed ``qkv`` tensor.

    q, k and v are taken as ``qkv[:, 0]``, ``qkv[:, 1]`` and ``qkv[:, 2]``. The
    module-level ``enable_ixdnn`` flag selects between the ``*_ixdnn`` kernel
    wrappers and the plain ones, in both forward and backward.
    """

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, use_alibi, alibi_mode, imp_mode):
        """Compute attention and stash what backward needs on ``ctx``.

        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``
        is truthy. ``softmax_scale`` defaults to ``qkv.shape[-1] ** -0.5`` on the
        non-ixdnn path only. ``deterministic`` is accepted but not used here.
        """
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if alibi_slopes is not None and use_alibi == False:
            warnings.warn("Parameter 'alibi_slopes' is not supported, automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning)
            use_alibi = True
            # NOTE(review): the warning mentions 'alibi_mode=1' but it is
            # imp_mode that is set here — confirm which one is intended.
            imp_mode = 1
        if enable_ixdnn:
            # NOTE(review): the caller's cu_seqlens is replaced by a tensor of
            # the same length filled with max_seqlen — confirm this is what the
            # ixdnn kernel expects, since the per-sequence lengths are dropped.
            cu_seqlens = torch.full((cu_seqlens.numel(),), max_seqlen, dtype=torch.int32, device=cu_seqlens.device)
            out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward_ixdnn(
                qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
                dropout_p, causal=causal, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode,
                window_size=window_size
            )
        else:
            if softmax_scale is None:
                softmax_scale = qkv.shape[-1] ** (-0.5)
            out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
                qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
                dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0,
                use_alibi=use_alibi, alibi_mode=alibi_mode
            )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.use_alibi = use_alibi
        ctx.alibi_mode = alibi_mode
        ctx.imp_mode = imp_mode
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return the gradient w.r.t. ``qkv`` followed by None for every other input."""
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        cur_rng_state = None
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        try:
            qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
            dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
            # The kernels write dq/dk/dv directly into the slices of dqkv.
            if enable_ixdnn:
                _flash_attn_varlen_backward_ixdnn(
                    dout, q, k, v, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2],
                    cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
                    ctx.dropout_p, ctx.causal, ctx.use_alibi, ctx.alibi_mode, ctx.imp_mode,
                    ctx.window_size
                )
            else:
                _flash_attn_varlen_backward(
                    dout, q, k, v, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2],
                    cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
                    ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.use_alibi, ctx.alibi_mode
                )
        finally:
            # Restore the caller's RNG state even if the kernel raised, so a
            # failed backward does not leave the global CUDA RNG rewound.
            if cur_rng_state is not None:
                torch.cuda.set_rng_state(cur_rng_state)
        dqkv = dqkv[..., :dout.shape[-1]]  # We could have padded the head dimension
        # forward takes 13 inputs: qkv plus 12 non-differentiable arguments.
        return (dqkv,) + (None,) * 12
|
||||
|
||||
|
||||
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention with a separate ``q`` and a packed ``kv``.

    k and v are taken as ``kv[:, :, 0]`` and ``kv[:, :, 1]``. The module-level
    ``enable_ixdnn`` flag selects between the ``*_ixdnn`` kernel wrappers and the
    plain ones, in both forward and backward.
    """

    @staticmethod
    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, use_alibi, alibi_mode, imp_mode):
        """Compute attention and stash what backward needs on ``ctx``.

        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``
        is truthy. ``softmax_scale`` defaults to ``q.shape[-1] ** -0.5`` on the
        non-ixdnn path only. ``deterministic`` is accepted but not used here.
        """
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if alibi_slopes is not None and use_alibi == False:
            warnings.warn("Parameter 'alibi_slopes' is not supported, automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning)
            use_alibi = True
            # NOTE(review): the warning mentions 'alibi_mode=1' but it is
            # imp_mode that is set here — confirm which one is intended.
            imp_mode = 1
        if enable_ixdnn:
            # One entry per batch element, each equal to the sequence length.
            cu_seqlens_q = torch.full((q.shape[0],), q.shape[1], dtype=torch.int32, device=q.device)
            cu_seqlens_k = torch.full((kv.shape[0],), kv.shape[1], dtype=torch.int32, device=kv.device)
            out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward_ixdnn(
                q, kv[:, :, 0], kv[:, :, 1], cu_seqlens_q, cu_seqlens_k, dropout_p, causal=causal,
                return_softmax=return_softmax and dropout_p > 0, use_alibi=use_alibi,
                alibi_mode=alibi_mode, imp_mode=imp_mode, window_size=window_size
            )
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
        else:
            if softmax_scale is None:
                softmax_scale = q.shape[-1] ** (-0.5)
            out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
                q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal=causal,
                return_softmax=return_softmax and dropout_p > 0, use_alibi=use_alibi,
                alibi_mode=alibi_mode
            )
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.use_alibi = use_alibi
        ctx.alibi_mode = alibi_mode
        ctx.imp_mode = imp_mode
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return gradients w.r.t. ``q`` and ``kv`` followed by None for every other input."""
        if enable_ixdnn:
            q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        else:
            q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        cur_rng_state = None
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        try:
            dq = torch.empty_like(q)
            kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
            dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
            # The kernels write dk/dv directly into the slices of dkv.
            if enable_ixdnn:
                _flash_attn_backward_ixdnn(
                    dout, q, k, v, out, softmax_lse, dq, dkv[:, :, 0], dkv[:, :, 1],
                    cu_seqlens_q, cu_seqlens_k, ctx.dropout_p, ctx.causal, ctx.use_alibi, ctx.alibi_mode, ctx.imp_mode, ctx.window_size
                )
            else:
                _flash_attn_backward(
                    dout, q, k, v, out, softmax_lse,
                    dq, dkv[:, :, 0], dkv[:, :, 1], ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.use_alibi, ctx.alibi_mode
                )
        finally:
            # Restore the caller's RNG state even if the kernel raised, so a
            # failed backward does not leave the global CUDA RNG rewound.
            if cur_rng_state is not None:
                torch.cuda.set_rng_state(cur_rng_state)
        dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., :dout.shape[-1]]
        # forward takes 12 inputs: q, kv plus 10 non-differentiable arguments.
        return (dq, dkv) + (None,) * 10
|
||||
|
||||
|
||||
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with ``q`` and a packed ``kv``.

    k and v are taken as ``kv[:, 0]`` and ``kv[:, 1]``. The module-level
    ``enable_ixdnn`` flag selects between the ``*_ixdnn`` kernel wrappers and the
    plain ones, in both forward and backward.
    """

    @staticmethod
    def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
                softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax,
                use_alibi, alibi_mode, imp_mode):
        """Compute attention and stash what backward needs on ``ctx``.

        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``
        is truthy. ``softmax_scale`` defaults to ``q.shape[-1] ** -0.5`` on the
        non-ixdnn path only. ``deterministic`` is accepted but not used here.
        """
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if alibi_slopes is not None and use_alibi == False:
            warnings.warn("Parameter 'alibi_slopes' is not supported, automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning)
            use_alibi = True
            # NOTE(review): the warning mentions 'alibi_mode=1' but it is
            # imp_mode that is set here — confirm which one is intended.
            imp_mode = 1
        if enable_ixdnn:
            # NOTE(review): the caller's cu_seqlens_q / cu_seqlens_k are replaced
            # by same-length tensors filled with the max sequence length —
            # confirm this is what the ixdnn kernel expects.
            cu_seqlens_q = torch.full((cu_seqlens_q.numel(),), max_seqlen_q, dtype=torch.int32, device=cu_seqlens_q.device)
            cu_seqlens_k = torch.full((cu_seqlens_k.numel(),), max_seqlen_k, dtype=torch.int32, device=cu_seqlens_k.device)
            out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward_ixdnn(
                q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                dropout_p, causal=causal, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode,
                window_size=window_size
            )
        else:
            if softmax_scale is None:
                softmax_scale = q.shape[-1] ** (-0.5)
            out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
                q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0,
                use_alibi=use_alibi, alibi_mode=alibi_mode
            )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse,
                              cu_seqlens_q, cu_seqlens_k, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.use_alibi = use_alibi
        ctx.alibi_mode = alibi_mode
        ctx.imp_mode = imp_mode
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return gradients w.r.t. ``q`` and ``kv`` followed by None for every other input."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        cur_rng_state = None
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        try:
            dq = torch.empty_like(q)
            kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
            dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
            # The kernels write dk/dv directly into the slices of dkv.
            if enable_ixdnn:
                _flash_attn_varlen_backward_ixdnn(
                    dout, q, k, v, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1],
                    cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k,
                    ctx.dropout_p, ctx.causal, ctx.use_alibi, ctx.alibi_mode, ctx.imp_mode,
                    ctx.window_size
                )
            else:
                _flash_attn_varlen_backward(
                    dout, q, k, v, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1],
                    cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k,
                    ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.use_alibi, ctx.alibi_mode
                )
        finally:
            # Restore the caller's RNG state even if the kernel raised, so a
            # failed backward does not leave the global CUDA RNG rewound.
            if cur_rng_state is not None:
                torch.cuda.set_rng_state(cur_rng_state)
        dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., :dout.shape[-1]]
        # forward takes 16 inputs: q, kv plus 14 non-differentiable arguments.
        return (dq, dkv) + (None,) * 14
|
||||
|
||||
|
||||
class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for attention on separate ``q``, ``k`` and ``v`` tensors.

    The module-level ``enable_ixdnn`` flag selects between the ``*_ixdnn`` kernel
    wrappers and the plain ones, in both forward and backward.
    """

    @staticmethod
    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, use_alibi, alibi_mode, imp_mode):
        """Compute attention and stash what backward needs on ``ctx``.

        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``
        is truthy. ``softmax_scale`` defaults to ``q.shape[-1] ** -0.5`` on the
        non-ixdnn path only. ``deterministic`` is accepted but not used here.
        """
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if alibi_slopes is not None and use_alibi == False:
            warnings.warn("Parameter 'alibi_slopes' is not supported, automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning)
            use_alibi = True
            # NOTE(review): the warning mentions 'alibi_mode=1' but it is
            # imp_mode that is set here — confirm which one is intended.
            imp_mode = 1
        if enable_ixdnn:
            # One entry per batch element, each equal to the sequence length.
            cu_seqlens_q = torch.full((q.shape[0],), q.shape[1], dtype=torch.int32, device=q.device)
            cu_seqlens_k = torch.full((k.shape[0],), k.shape[1], dtype=torch.int32, device=k.device)
            out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward_ixdnn(
                q, k, v, cu_seqlens_q, cu_seqlens_k, dropout_p, causal=causal,
                return_softmax=return_softmax and dropout_p > 0, use_alibi=use_alibi,
                alibi_mode=alibi_mode, imp_mode=imp_mode, window_size=window_size
            )
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
        else:
            if softmax_scale is None:
                softmax_scale = q.shape[-1] ** (-0.5)
            out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
                q, k, v, dropout_p, softmax_scale, causal=causal,
                return_softmax=return_softmax and dropout_p > 0, use_alibi=use_alibi,
                alibi_mode=alibi_mode
            )
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.use_alibi = use_alibi
        ctx.alibi_mode = alibi_mode
        ctx.imp_mode = imp_mode
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return gradients w.r.t. ``q``, ``k``, ``v`` followed by None for every other input."""
        if enable_ixdnn:
            q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        else:
            q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        cur_rng_state = None
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        try:
            dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
            if enable_ixdnn:
                _flash_attn_backward_ixdnn(
                    dout, q, k, v, out, softmax_lse,
                    dq, dk, dv, cu_seqlens_q, cu_seqlens_k, ctx.dropout_p,
                    ctx.causal, ctx.use_alibi, ctx.alibi_mode, ctx.imp_mode,
                    ctx.window_size
                )
            else:
                _flash_attn_backward(
                    dout, q, k, v, out, softmax_lse,
                    dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.use_alibi, ctx.alibi_mode
                )
        finally:
            # Restore the caller's RNG state even if the kernel raised, so a
            # failed backward does not leave the global CUDA RNG rewound.
            if cur_rng_state is not None:
                torch.cuda.set_rng_state(cur_rng_state)
        dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., :dout.shape[-1]]
        dv = dv[..., :dout.shape[-1]]
        # forward takes 13 inputs: q, k, v plus 10 non-differentiable arguments,
        # so exactly 10 trailing None gradients are returned.
        return (dq, dk, dv) + (None,) * 10
|
||||
|
||||
|
||||
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length attention over separate q, k and v tensors.

    Both passes dispatch on the module-level ``enable_ixdnn`` flag: the ``*_ixdnn`` kernels are
    used when it is truthy, the generic ``_flash_attn_varlen_*`` kernels otherwise.
    """

    @staticmethod
    def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
                softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax,
                use_alibi, alibi_mode, imp_mode):
        """Run the forward kernel and stash what the backward pass needs on ``ctx``.

        Arguments:
            q, k, v: query, key and value tensors.
            cu_seqlens_q, cu_seqlens_k: sequence-length tensors for q and k.
            max_seqlen_q, max_seqlen_k: maximum query / key sequence length.
            dropout_p: dropout probability; the CUDA RNG state is captured when it is > 0.
            softmax_scale: scaling of QK^T. On the non-ixdnn path None is replaced by
                q.shape[-1] ** (-0.5); the ixdnn kernel is not given this value.
            causal: whether to apply the causal mask.
            window_size: forwarded to the ixdnn kernels only.
            alibi_slopes: not forwarded to any kernel. A non-None value with use_alibi
                disabled emits a UserWarning and switches to use_alibi=True, alibi_mode=1.
            deterministic: accepted for interface compatibility; not read here.
            return_softmax: if True, return (out, softmax_lse, S_dmask) instead of out.
            use_alibi, alibi_mode: ALiBi switch and bias mode passed to the kernels.
            imp_mode: backward implementation mode, used by the ixdnn kernels.
        Return:
            out, or (out, softmax_lse, S_dmask) when return_softmax is True.
        """
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if alibi_slopes is not None and not use_alibi:
            warnings.warn("Parameter 'alibi_slopes' is not supported, automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning)
            use_alibi = True
            # Apply exactly what the warning announces. The previous code assigned
            # imp_mode here, which left alibi_mode unchanged and silently altered the
            # backward implementation mode instead.
            alibi_mode = 1
        if enable_ixdnn:
            # NOTE(review): the ixdnn path replaces both cu_seqlens tensors with tensors of the
            # same length filled with max_seqlen_*; presumably the ixdnn kernel expects
            # per-entry lengths rather than cumulative offsets - confirm.
            cu_seqlens_q = torch.full((cu_seqlens_q.numel(),), max_seqlen_q, dtype=torch.int32, device=cu_seqlens_q.device)
            cu_seqlens_k = torch.full((cu_seqlens_k.numel(),), max_seqlen_k, dtype=torch.int32, device=cu_seqlens_k.device)
            out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward_ixdnn(
                q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                dropout_p, causal=causal, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode,
                window_size=window_size
            )
        else:
            if softmax_scale is None:
                softmax_scale = q.shape[-1] ** (-0.5)
            out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
                q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0,
                use_alibi=use_alibi, alibi_mode=alibi_mode
            )
        # The q, k, v returned by the kernel (not the caller's originals) are what is saved.
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse,
                              cu_seqlens_q, cu_seqlens_k, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.use_alibi = use_alibi
        ctx.alibi_mode = alibi_mode
        ctx.imp_mode = imp_mode
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients of the attention output with respect to q, k and v.

        Arguments:
            ctx: autograd context populated by ``forward``.
            dout: gradient of the attention output.
            *args: gradients of the optional extra outputs; they are not used.
        Return:
            (dq, dk, dv) followed by one None for each remaining forward argument.
        """
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        # Replay the RNG state captured in forward, remembering the current one so it can be
        # restored once the kernel has run.
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        # Output buffers that the backward kernels fill in place.
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        if enable_ixdnn:
            _flash_attn_varlen_backward_ixdnn(
                dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
                ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.causal, ctx.use_alibi,
                ctx.alibi_mode, ctx.imp_mode, ctx.window_size
            )
        else:
            _flash_attn_varlen_backward(
                dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
                ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
                ctx.use_alibi, ctx.alibi_mode
            )

        dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., :dout.shape[-1]]
        dv = dv[..., :dout.shape[-1]]
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), alibi_slopes=None,
                              deterministic=False, return_attn_probs=False, use_alibi=False, alibi_mode=1, imp_mode=0):
    """Attention over a single packed QKV tensor, dispatched to FlashAttnQKVPackedFunc.

    dropout_p should be set to 0.0 during evaluation.
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    Q, K and V share the nheads axis of the packed tensor, so they have the same number of
    heads here. For multi-query and grouped-query attention (MQA/GQA), where K, V have fewer
    heads than Q, see flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
            the attention score of query i and key j.
        deterministic: bool. Passed through to FlashAttnQKVPackedFunc.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        use_alibi: bool. Whether to apply attention with linear bias (ALiBi). If True, causal must
            be True. The scores become QK^T + bias * slopes_m, for example:
                q1*k1                                   0
                q2*k1 q2*k2                            -1  0
                q3*k1 q3*k2 q3*k3                  +   -2 -1  0        * slopes_m
                q4*k1 q4*k2 q4*k3 q4*k4                -3 -2 -1  0
                q5*k1 q5*k2 q5*k3 q5*k4 q5*k5          -4 -3 -2 -1  0
        alibi_mode: int. The bias mode of ALiBi, default to 1.
            alibi_mode=0:                alibi_mode=1:
                 0                        -√0
                -1  0                     -√1 -√0
                -2 -1  0                  -√2 -√1 -√0
                -3 -2 -1  0               -√3 -√2 -√1 -√0
                -4 -3 -2 -1  0            -√4 -√3 -√2 -√1 -√0
                -5 -4 -3 -2 -1  0         -√5 -√4 -√3 -√2 -√1 -√0
        imp_mode: int. Selects one of two backward implementations, default to 0.
            imp_mode=0 (CUDNN_FATTN_BALANCE_MODE): the dQ reduction is computed by launching a new
                kernel, which will occupy additional memory. This mode usually has better performance.
            imp_mode=1 (CUDNN_FATTN_LEAST_MEM_MODE): the dQ reduction is computed within the block,
                which will not occupy additional memory. There will be a slight performance loss
                in this mode. This mode can be considered in memory-limited scenarios.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # autograd.Function.apply only takes positional arguments, in forward()'s order.
    forwarded = (
        qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes,
        deterministic, return_attn_probs, use_alibi, alibi_mode, imp_mode,
    )
    return FlashAttnQKVPackedFunc.apply(*forwarded)
|
||||
|
||||
|
||||
def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), alibi_slopes=None,
                             deterministic=False, return_attn_probs=False, use_alibi=False, alibi_mode=1, imp_mode=0):
    """Attention over Q and a packed KV tensor, dispatched to FlashAttnKVPackedFunc.

    dropout_p should be set to 0.0 during evaluation.
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        use_alibi: bool. Whether to apply attention with linear bias (ALiBi). If True, causal must
            be True. The scores become QK^T + bias * slopes_m, for example:
                q1*k1                                   0
                q2*k1 q2*k2                            -1  0
                q3*k1 q3*k2 q3*k3                  +   -2 -1  0        * slopes_m
                q4*k1 q4*k2 q4*k3 q4*k4                -3 -2 -1  0
                q5*k1 q5*k2 q5*k3 q5*k4 q5*k5          -4 -3 -2 -1  0
        alibi_mode: int. The bias mode of ALiBi, default to 1.
            alibi_mode=0:                alibi_mode=1:
                 0                        -√0
                -1  0                     -√1 -√0
                -2 -1  0                  -√2 -√1 -√0
                -3 -2 -1  0               -√3 -√2 -√1 -√0
                -4 -3 -2 -1  0            -√4 -√3 -√2 -√1 -√0
                -5 -4 -3 -2 -1  0         -√5 -√4 -√3 -√2 -√1 -√0
        imp_mode: int. Selects one of two backward implementations, default to 0.
            imp_mode=0 (CUDNN_FATTN_BALANCE_MODE): the dQ reduction is computed by launching a new
                kernel, which will occupy additional memory. This mode usually has better performance.
            imp_mode=1 (CUDNN_FATTN_LEAST_MEM_MODE): the dQ reduction is computed within the block,
                which will not occupy additional memory. There will be a slight performance loss
                in this mode. This mode can be considered in memory-limited scenarios.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # autograd.Function.apply only takes positional arguments, in forward()'s order.
    forwarded = (
        q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes,
        deterministic, return_attn_probs, use_alibi, alibi_mode, imp_mode,
    )
    return FlashAttnKVPackedFunc.apply(*forwarded)
|
||||
|
||||
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), alibi_slopes=None,
                    deterministic=False, return_attn_probs=False, use_alibi=False, alibi_mode=1, imp_mode=0):
    """Attention over separate Q, K and V tensors, dispatched to FlashAttnFunc.

    dropout_p should be set to 0.0 during evaluation.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        use_alibi: bool. Whether to apply attention with linear bias (ALiBi). If True, causal must
            be True. The scores become QK^T + bias * slopes_m, for example:
                q1*k1                                   0
                q2*k1 q2*k2                            -1  0
                q3*k1 q3*k2 q3*k3                  +   -2 -1  0        * slopes_m
                q4*k1 q4*k2 q4*k3 q4*k4                -3 -2 -1  0
                q5*k1 q5*k2 q5*k3 q5*k4 q5*k5          -4 -3 -2 -1  0
        alibi_mode: int. The bias mode of ALiBi, default to 1.
            alibi_mode=0:                alibi_mode=1:
                 0                        -√0
                -1  0                     -√1 -√0
                -2 -1  0                  -√2 -√1 -√0
                -3 -2 -1  0               -√3 -√2 -√1 -√0
                -4 -3 -2 -1  0            -√4 -√3 -√2 -√1 -√0
                -5 -4 -3 -2 -1  0         -√5 -√4 -√3 -√2 -√1 -√0
        imp_mode: int. Selects one of two backward implementations, default to 0.
            imp_mode=0 (CUDNN_FATTN_BALANCE_MODE): the dQ reduction is computed by launching a new
                kernel, which will occupy additional memory. This mode usually has better performance.
            imp_mode=1 (CUDNN_FATTN_LEAST_MEM_MODE): the dQ reduction is computed within the block,
                which will not occupy additional memory. There will be a slight performance loss
                in this mode. This mode can be considered in memory-limited scenarios.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # autograd.Function.apply only takes positional arguments, in forward()'s order.
    forwarded = (
        q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes,
        deterministic, return_attn_probs, use_alibi, alibi_mode, imp_mode,
    )
    return FlashAttnFunc.apply(*forwarded)
|
||||
|
||||
|
||||
def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None,
                                     causal=False, window_size=(-1, -1), alibi_slopes=None, deterministic=False,
                                     return_attn_probs=False, use_alibi=False, alibi_mode=1, imp_mode=0):
    """Variable-length attention over a packed QKV tensor, dispatched to FlashAttnVarlenQKVPackedFunc.

    dropout_p should be set to 0.0 during evaluation.
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    Q, K and V share the nheads axis of the packed tensor, so they have the same number of
    heads here. For multi-query and grouped-query attention (MQA/GQA), where K, V have fewer
    heads than Q, see flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        use_alibi: bool. Whether to apply attention with linear bias (ALiBi). If True, causal must
            be True. The scores become QK^T + bias * slopes_m, for example:
                q1*k1                                   0
                q2*k1 q2*k2                            -1  0
                q3*k1 q3*k2 q3*k3                  +   -2 -1  0        * slopes_m
                q4*k1 q4*k2 q4*k3 q4*k4                -3 -2 -1  0
                q5*k1 q5*k2 q5*k3 q5*k4 q5*k5          -4 -3 -2 -1  0
        alibi_mode: int. The bias mode of ALiBi, default to 1.
            alibi_mode=0:                alibi_mode=1:
                 0                        -√0
                -1  0                     -√1 -√0
                -2 -1  0                  -√2 -√1 -√0
                -3 -2 -1  0               -√3 -√2 -√1 -√0
                -4 -3 -2 -1  0            -√4 -√3 -√2 -√1 -√0
                -5 -4 -3 -2 -1  0         -√5 -√4 -√3 -√2 -√1 -√0
        imp_mode: int. Selects one of two backward implementations, default to 0.
            imp_mode=0 (CUDNN_FATTN_BALANCE_MODE): the dQ reduction is computed by launching a new
                kernel, which will occupy additional memory. This mode usually has better performance.
            imp_mode=1 (CUDNN_FATTN_LEAST_MEM_MODE): the dQ reduction is computed within the block,
                which will not occupy additional memory. There will be a slight performance loss
                in this mode. This mode can be considered in memory-limited scenarios.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # autograd.Function.apply only takes positional arguments, in forward()'s order.
    forwarded = (
        qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, window_size,
        alibi_slopes, deterministic, return_attn_probs, use_alibi, alibi_mode, imp_mode,
    )
    return FlashAttnVarlenQKVPackedFunc.apply(*forwarded)
|
||||
|
||||
|
||||
def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                    dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1),
                                    alibi_slopes=None, deterministic=False, return_attn_probs=False,
                                    use_alibi=False, alibi_mode=1, imp_mode=0):
    """Variable-length attention over Q and a packed KV tensor, dispatched to FlashAttnVarlenKVPackedFunc.

    dropout_p should be set to 0.0 during evaluation.
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        use_alibi: bool. Whether to apply attention with linear bias (ALiBi). If True, causal must
            be True. The scores become QK^T + bias * slopes_m, for example:
                q1*k1                                   0
                q2*k1 q2*k2                            -1  0
                q3*k1 q3*k2 q3*k3                  +   -2 -1  0        * slopes_m
                q4*k1 q4*k2 q4*k3 q4*k4                -3 -2 -1  0
                q5*k1 q5*k2 q5*k3 q5*k4 q5*k5          -4 -3 -2 -1  0
        alibi_mode: int. The bias mode of ALiBi, default to 1.
            alibi_mode=0:                alibi_mode=1:
                 0                        -√0
                -1  0                     -√1 -√0
                -2 -1  0                  -√2 -√1 -√0
                -3 -2 -1  0               -√3 -√2 -√1 -√0
                -4 -3 -2 -1  0            -√4 -√3 -√2 -√1 -√0
                -5 -4 -3 -2 -1  0         -√5 -√4 -√3 -√2 -√1 -√0
        imp_mode: int. Selects one of two backward implementations, default to 0.
            imp_mode=0 (CUDNN_FATTN_BALANCE_MODE): the dQ reduction is computed by launching a new
                kernel, which will occupy additional memory. This mode usually has better performance.
            imp_mode=1 (CUDNN_FATTN_LEAST_MEM_MODE): the dQ reduction is computed within the block,
                which will not occupy additional memory. There will be a slight performance loss
                in this mode. This mode can be considered in memory-limited scenarios.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # autograd.Function.apply only takes positional arguments, in forward()'s order.
    forwarded = (
        q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, softmax_scale, causal, window_size, alibi_slopes,
        deterministic, return_attn_probs, use_alibi, alibi_mode, imp_mode,
    )
    return FlashAttnVarlenKVPackedFunc.apply(*forwarded)
|
||||
|
||||
|
||||
def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                           dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1),
                           alibi_slopes=None, deterministic=False, return_attn_probs=False,
                           use_alibi=False, alibi_mode=1, imp_mode=0):
    """Variable-length attention over separate Q, K and V tensors, dispatched to FlashAttnVarlenFunc.

    dropout_p should be set to 0.0 during evaluation.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in K, V.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
            Note: FlashAttnVarlenFunc does not pass this tensor to its kernels; a non-None
            value with use_alibi=False emits a UserWarning and turns use_alibi on.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
            Note: FlashAttnVarlenFunc accepts this argument but does not read it.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        use_alibi: bool. Whether to apply attention with linear bias (ALiBi). If True, causal must
            be True. The scores become QK^T + bias * slopes_m, for example:
                q1*k1                                   0
                q2*k1 q2*k2                            -1  0
                q3*k1 q3*k2 q3*k3                  +   -2 -1  0        * slopes_m
                q4*k1 q4*k2 q4*k3 q4*k4                -3 -2 -1  0
                q5*k1 q5*k2 q5*k3 q5*k4 q5*k5          -4 -3 -2 -1  0
        alibi_mode: int. The bias mode of ALiBi, default to 1.
            alibi_mode=0:                alibi_mode=1:
                 0                        -√0
                -1  0                     -√1 -√0
                -2 -1  0                  -√2 -√1 -√0
                -3 -2 -1  0               -√3 -√2 -√1 -√0
                -4 -3 -2 -1  0            -√4 -√3 -√2 -√1 -√0
                -5 -4 -3 -2 -1  0         -√5 -√4 -√3 -√2 -√1 -√0
        imp_mode: int. Selects one of two backward implementations, default to 0.
            imp_mode=0 (CUDNN_FATTN_BALANCE_MODE): the dQ reduction is computed by launching a new
                kernel, which will occupy additional memory. This mode usually has better performance.
            imp_mode=1 (CUDNN_FATTN_LEAST_MEM_MODE): the dQ reduction is computed within the block,
                which will not occupy additional memory. There will be a slight performance loss
                in this mode. This mode can be considered in memory-limited scenarios.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # autograd.Function.apply only takes positional arguments, in forward()'s order.
    forwarded = (
        q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic,
        return_attn_probs, use_alibi, alibi_mode, imp_mode,
    )
    return FlashAttnVarlenFunc.apply(*forwarded)
|
||||
832
pkgs/xformers/_flash_attn/flash_attn_triton.py
Normal file
832
pkgs/xformers/_flash_attn/flash_attn_triton.py
Normal file
@@ -0,0 +1,832 @@
|
||||
"""
|
||||
*Experimental* implementation of FlashAttention in Triton.
|
||||
Tested with triton==2.0.0.dev20221202.
|
||||
Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
|
||||
other than 64:
|
||||
https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
|
||||
We'll update this implementation with the new Triton backend once this is fixed.
|
||||
|
||||
We use the FlashAttention implementation from Phil Tillet as a starting point.
|
||||
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
|
||||
|
||||
Changes:
|
||||
- Implement both causal and non-causal attention.
|
||||
- Implement both self-attention and cross-attention.
|
||||
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
|
||||
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
|
||||
- Support attention bias.
|
||||
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
|
||||
- Make the backward for d=128 much faster by reducing register spilling.
|
||||
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
|
||||
small batch size * nheads.
|
||||
|
||||
Caution:
|
||||
- This is an *experimental* implementation. The forward pass should be quite robust but
|
||||
I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
|
||||
- This implementation has only been tested on A100.
|
||||
- If you plan to use headdim other than 64 and 128, you should test for race conditions
|
||||
(due to the Triton compiler), as done in tests/test_flash_attn.py
|
||||
"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
|
||||
for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
|
||||
that there are none left for other head dimensions.
|
||||
|
||||
Differences between this Triton version and the CUDA version:
|
||||
- Triton version doesn't support dropout.
|
||||
- Triton forward is generally faster than CUDA forward, while Triton backward is
|
||||
generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
|
||||
than CUDA forward + backward.
|
||||
- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
|
||||
- Triton version supports attention bias, while CUDA version doesn't.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
|
||||
# # This config has a race condition when EVEN_M == False, disabling it for now.
|
||||
# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
|
||||
# ],
|
||||
# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
|
||||
# )
|
||||
# EVEN_M / EVEN_N are true when the sequence lengths are exact multiples of the tile sizes;
# EVEN_HEADDIM is true when the head dimension equals the (power-of-two) tile width.
# They select between unmasked and masked loads/stores inside the kernel.
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_kernel(
    Q, K, V, Bias, Out,
    Lse, TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_bb, stride_bh, stride_bm,
    stride_ob, stride_oh, stride_om,
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # Forward attention kernel.
    #
    # Each program handles one tile of BLOCK_M query rows (program_id(0)) for one flattened
    # (batch, head) index (program_id(1)). It streams over key/value tiles of BLOCK_N rows,
    # keeps a running reference value (m_i) and a running log-sum-exp (lse_i) per query row,
    # and finally writes the rescaled output tile to Out and lse_i to Lse.
    #
    # Pointer arithmetic adds offs_d without a stride, so the head dimension of Q/K/V/Out is
    # addressed with unit stride. Bias is only dereferenced when BIAS_TYPE is 'vector' (one
    # value per key column, broadcast over query rows) or 'matrix' (one value per
    # query/key pair). CACHE_KEY_SEQLEN_Q / CACHE_KEY_SEQLEN_K are not read in the body.
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    # Split the flattened index into batch and head.
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # off_b = tl.program_id(1)
    # off_h = tl.program_id(2)
    # off_hb = off_b * nheads + off_h
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    # Adding parenthesis around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == 'matrix':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
    # initialize pointer to m and l
    # TMP (and Lse below) are indexed as [off_hb, row] with a row pitch of seqlen_q_rounded.
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over k, v and update accumulator
    # With causal masking, key columns past the last query row of this tile are fully masked
    # (see the offs_m >= column test below), so the loop stops at (start_m + 1) * BLOCK_M.
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        # Trying to combine the two masks seems to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            # A query row may only attend to key columns at or before its own index.
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
        if BIAS_TYPE != 'none':
            if BIAS_TYPE == 'vector':
                if EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32)
                # Broadcast the per-column bias over all query rows of the tile.
                bias = bias[None, :]
            elif BIAS_TYPE == 'matrix':
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs + start_n,
                                   mask=(offs_m[:, None] < seqlen_q)
                                        & ((start_n + offs_n)[None, :] < seqlen_k),
                                   other=0.0).to(tl.float32)
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need
            # to multiply with softmax_scale here.
            qk = qk * softmax_scale + bias
            # New reference value: row max of this tile, bounded below by the running LSE.
            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
            p = tl.exp(qk - m_ij[:, None])
        else:
            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
            p = tl.exp(qk * softmax_scale - m_ij[:, None])
        # Row sums of the exponentials taken relative to m_ij.
        l_ij = tl.sum(p, 1)

        # scale acc_o
        # Re-reference the accumulator from the previous reference value m_i to m_ij.
        acc_o_scale = tl.exp(m_i - m_ij)

        # # -- update output accumulator --
        # BUG: have to store and immediately load
        tl.store(t_ptrs, acc_o_scale)
        acc_o_scale = tl.load(t_ptrs)
        acc_o = acc_o * acc_o_scale[:, None]
        # update acc_o
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        # Cast the probabilities to the dtype of v before the second matmul.
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)

        # -- update statistics
        m_i = m_ij
        # lse_i <- log(exp(lse_i) + sum_j exp(qk_j)), evaluated relative to m_ij.
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)

    # Final rescale: the accumulator is referenced to m_i, so multiplying by
    # exp(m_i - lse_i) divides by the full softmax denominator.
    o_scale = tl.exp(m_i - lse_i)
    # BUG: have to store and immediately load
    tl.store(t_ptrs, o_scale)
    o_scale = tl.load(t_ptrs)
    acc_o = acc_o * o_scale[:, None]
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    # The LSE store is unmasked: it relies on each Lse row being padded to seqlen_q_rounded
    # (the host wrapper in this file allocates Lse and TMP with that padded length).
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    tl.store(lse_ptrs, lse_i)
    # initialize pointers to output
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(out_ptrs, acc_o,
                     mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
|
||||
|
||||
|
||||
@triton.jit
def _bwd_preprocess_do_o_dot(
    Out, DO, Delta,
    stride_ob, stride_oh, stride_om,
    stride_dob, stride_doh, stride_dom,
    nheads, seqlen_q, seqlen_q_rounded, headdim,
    BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
):
    # Backward pre-pass: for every query row, compute delta = sum over the head dimension of
    # (o * do) in fp32 and store it in Delta.
    #
    # Grid layout matches the forward kernel: program_id(0) selects a tile of BLOCK_M query
    # rows, program_id(1) is the flattened (batch, head) index.
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # load
    # Both loads are masked on rows and head-dim columns; masked elements read as 0.0, so they
    # contribute nothing to the row sum.
    o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
    do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
                 mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back
    # Unmasked store with a row pitch of seqlen_q_rounded: rows past seqlen_q receive 0 and
    # must fit in the padding (assumes seqlen_q_rounded covers whole BLOCK_M tiles).
    tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
|
||||
|
||||
|
||||
@triton.jit
def _bwd_store_dk_dv(
    dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
):
    # Store one tile of dk and dv, choosing the store mask from the EVEN_* flags:
    # rows (offs_n) are masked against seqlen_k and columns (offs_d) against headdim.
    # dv is always stored before dk.
    #
    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.store(dv_ptrs), there's a race condition
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        # Row mask is applied even when EVEN_N is true (workaround noted above).
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
|
||||
|
||||
|
||||
@triton.jit
def _bwd_kernel_one_col_block(
    start_n,
    Q, K, V, Bias,
    DO, DQ, DK, DV,
    LSE, D,
    softmax_scale,
    stride_qm, stride_kn, stride_vn, stride_bm,
    stride_dom, stride_dqm, stride_dkn, stride_dvn,
    seqlen_q, seqlen_k, headdim,
    ATOMIC_ADD: tl.constexpr,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # Backward pass for one tile of BLOCK_N key/value rows (tile index start_n).
    #
    # k and v for the tile are loaded once; the loop then walks over tiles of BLOCK_M query
    # rows, recomputes the probabilities p from the saved LSE, accumulates dv and dk in fp32,
    # and adds this tile's contribution to DQ: load/add/store when ATOMIC_ADD is False,
    # tl.atomic_add when it is True. dk and dv are written once at the end.
    #
    # No batch/head offsets are applied here: all pointer arguments are used as the base of
    # the current (batch, head) slice (the caller _bwd_kernel offsets them before calling).
    # D is read one value per query row; the host wrapper in this file passes the buffer
    # filled by _bwd_preprocess_do_o_dot.
    #
    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
    # With causal masking, query rows before the first key column of this tile are fully
    # masked, so iteration starts at the BLOCK_M-aligned row at or below start_n * BLOCK_N.
    begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
    # initialize row/col offsets
    offs_qm = begin_m + tl.arange(0, BLOCK_M)
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # initialize pointers to value-like data
    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
    do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
    dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + offs_n
    elif BIAS_TYPE == 'matrix':
        b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
    # initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # There seems to be some problem with Triton pipelining that makes results wrong for
    # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
    # may have zero step, and pipelining with the bias matrix could screw it up.
    # So we just exit early.
    if begin_m >= seqlen_q:
        # No query row attends to this key tile: store the zero-initialized dk/dv and return.
        dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
        dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
        _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
                         EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
        return
    # k and v stay in SRAM throughout
    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.load(k_ptrs), we get the wrong output!
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs)
            v = tl.load(v_ptrs)
        else:
            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
            v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
        else:
            k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                        other=0.0)
            v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over rows
    num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
    for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        offs_m_curr = start_m + offs_m
        # load q, k, v, do on-chip
        # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
        if EVEN_M & EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            if EVEN_HEADDIM:
                q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
            else:
                q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
                                         & (offs_d[None, :] < headdim), other=0.0)
        # recompute p = softmax(qk, dim=-1).T
        qk = tl.dot(q, k, trans_b=True)
        # Trying to combine the two masks seems to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
        if IS_CAUSAL:
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
        if BIAS_TYPE != 'none':
            tl.debug_barrier()  # Race condition otherwise
            if BIAS_TYPE == 'vector':
                if EVEN_N:
                    bias = tl.load(b_ptrs).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
                bias = bias[None, :]
            elif BIAS_TYPE == 'matrix':
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs,
                                   mask=(offs_m_curr[:, None] < seqlen_q)
                                        & (offs_n[None, :] < seqlen_k),
                                   other=0.0).to(tl.float32)
            # With a bias the scale is applied here; without one it is folded into the exp below.
            qk = qk * softmax_scale + bias
        # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
        # Also wrong for headdim=64.
        if not (EVEN_M & EVEN_HEADDIM):
            tl.debug_barrier()
        # Unmasked per-row load of the saved log-sum-exp.
        lse_i = tl.load(LSE + offs_m_curr)
        if BIAS_TYPE == 'none':
            p = tl.exp(qk * softmax_scale - lse_i[:, None])
        else:
            p = tl.exp(qk - lse_i[:, None])
        # compute dv
        # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
        # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
        # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
        # the output is correct.
        if EVEN_M & EVEN_HEADDIM:
            do = tl.load(do_ptrs)
        else:
            # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
            do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
                                        & (offs_d[None, :] < headdim), other=0.0)
        # if EVEN_M:
        #     if EVEN_HEADDIM:
        #         do = tl.load(do_ptrs)
        #     else:
        #         do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
        # else:
        #     if EVEN_HEADDIM:
        #         do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
        #     else:
        #         do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
        #                                    & (offs_d[None, :] < headdim), other=0.0)
        # dv += p^T @ do
        dv += tl.dot(p.to(do.dtype), do, trans_a=True)
        # compute dp = dot(v, do)
        # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
        # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
        # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
        if not (EVEN_M & EVEN_HEADDIM):
            tl.debug_barrier()
        dp = tl.dot(do, v, trans_b=True)
        # There's a race condition for headdim=48
        if not EVEN_HEADDIM:
            tl.debug_barrier()
        # compute ds = p * (dp - delta[:, None])
        # Putting the subtraction after the dp matmul (instead of before) is slightly faster
        Di = tl.load(D + offs_m_curr)
        # Converting ds to q.dtype here reduces register pressure and makes it much faster
        # for BLOCK_HEADDIM=128
        ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
        # compute dk = dot(ds.T, q)
        dk += tl.dot(ds, q, trans_a=True)
        # compute dq
        if not (EVEN_M & EVEN_HEADDIM):  # Otherwise there's a race condition when BIAS_TYPE='matrix'
            tl.debug_barrier()
        if not ATOMIC_ADD:
            # Non-atomic read-modify-write of the DQ tile (dq = dq + ds @ k).
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                dq = tl.load(dq_ptrs, eviction_policy="evict_last")
                dq += tl.dot(ds, k)
                tl.store(dq_ptrs, dq, eviction_policy="evict_last")
            else:
                if EVEN_HEADDIM:
                    dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
                                 eviction_policy="evict_last")
                    dq += tl.dot(ds, k)
                    tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q,
                             eviction_policy="evict_last")
                else:
                    dq = tl.load(dq_ptrs,
                                 mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                                 other=0.0, eviction_policy="evict_last")
                    dq += tl.dot(ds, k)
                    tl.store(dq_ptrs, dq,
                             mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                             eviction_policy="evict_last")
        else:  # If we're parallelizing across the seqlen_k dimension
            dq = tl.dot(ds, k)
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                tl.atomic_add(dq_ptrs, dq)
            else:
                if EVEN_HEADDIM:
                    tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
                else:
                    tl.atomic_add(dq_ptrs, dq,
                                  mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
        # increment pointers
        # Advance every per-query-row pointer by one tile; a 'vector' bias has no row dimension.
        dq_ptrs += BLOCK_M * stride_dqm
        q_ptrs += BLOCK_M * stride_qm
        do_ptrs += BLOCK_M * stride_dom
        if BIAS_TYPE == 'matrix':
            b_ptrs += BLOCK_M * stride_bm
    # write-back
    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
    _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
                     EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
|
||||
|
||||
|
||||
def init_to_zero(name):
    """Build a hook that zeroes, in place, the kernel argument called ``name``.

    The returned callable takes the mapping of kernel arguments, looks up
    ``name`` in it and returns the result of calling ``.zero_()`` on that entry.
    """
    def _zero_named_arg(nargs):
        target = nargs[name]
        return target.zero_()

    return _zero_named_arg
|
||||
|
||||
|
||||
# Both active configs use 128x128 tiles and differ only in SEQUENCE_PARALLEL. Each config's
# pre_hook zeroes the DQ argument in place before the kernel runs, because the kernel
# accumulates into DQ.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
        # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
    ],
    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
)
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_kernel(
    Q, K, V, Bias,
    DO, DQ, DK, DV,
    LSE, D,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_bb, stride_bh, stride_bm,
    stride_dob, stride_doh, stride_dom,
    stride_dqb, stride_dqh, stride_dqm,
    stride_dkb, stride_dkh, stride_dkn,
    stride_dvb, stride_dvh, stride_dvn,
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    SEQUENCE_PARALLEL: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # Backward kernel entry point.
    #
    # program_id(1) is the flattened (batch, head) index. All tensor pointers are advanced to
    # that slice, then _bwd_kernel_one_col_block does the work per key/value tile:
    #   * SEQUENCE_PARALLEL False: this program loops over every key tile itself and updates
    #     DQ with plain load/add/store (ATOMIC_ADD=False).
    #   * SEQUENCE_PARALLEL True: program_id(0) selects a single key tile, and DQ is updated
    #     with atomic adds (ATOMIC_ADD=True) because several programs touch the same rows.
    # CACHE_KEY_SEQLEN_Q / CACHE_KEY_SEQLEN_K are not read in the body; they appear only in
    # the autotune key above.
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # offset pointers for batch/head
    Q += off_b * stride_qb + off_h * stride_qh
    K += off_b * stride_kb + off_h * stride_kh
    V += off_b * stride_vb + off_h * stride_vh
    DO += off_b * stride_dob + off_h * stride_doh
    DQ += off_b * stride_dqb + off_h * stride_dqh
    DK += off_b * stride_dkb + off_h * stride_dkh
    DV += off_b * stride_dvb + off_h * stride_dvh
    if BIAS_TYPE != 'none':
        Bias += off_b * stride_bb + off_h * stride_bh
    # pointer to row-wise quantities in value-like data
    # D and LSE are laid out as [off_hb, row] with a row pitch of seqlen_q_rounded.
    D += off_hb * seqlen_q_rounded
    LSE += off_hb * seqlen_q_rounded
    if not SEQUENCE_PARALLEL:
        num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
        for start_n in range(0, num_block_n):
            _bwd_kernel_one_col_block(
                start_n,
                Q, K, V, Bias,
                DO, DQ, DK, DV,
                LSE, D,
                softmax_scale,
                stride_qm, stride_kn, stride_vn, stride_bm,
                stride_dom, stride_dqm, stride_dkn, stride_dvn,
                seqlen_q, seqlen_k, headdim,
                ATOMIC_ADD=False,
                BIAS_TYPE=BIAS_TYPE,
                IS_CAUSAL=IS_CAUSAL,
                BLOCK_HEADDIM=BLOCK_HEADDIM,
                EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
                BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
            )
    else:
        start_n = tl.program_id(0)
        _bwd_kernel_one_col_block(
            start_n,
            Q, K, V, Bias,
            DO, DQ, DK, DV,
            LSE, D,
            softmax_scale,
            stride_qm, stride_kn, stride_vn, stride_bm,
            stride_dom, stride_dqm, stride_dkn, stride_dvn,
            seqlen_q, seqlen_k, headdim,
            ATOMIC_ADD=True,
            BIAS_TYPE=BIAS_TYPE,
            IS_CAUSAL=IS_CAUSAL,
            BLOCK_HEADDIM=BLOCK_HEADDIM,
            EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
        )
|
||||
|
||||
|
||||
def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
    """Validate the inputs and launch the Triton forward attention kernel.

    Args:
        q: CUDA tensor of shape (batch, seqlen_q, nheads, headdim), fp16 or bf16.
        k, v: CUDA tensors of shape (batch, seqlen_k, nheads, headdim), same dtype as ``q``.
        bias: optional 4-D CUDA tensor whose last two dimensions are
            (1, seqlen_k) or (seqlen_q, seqlen_k); it is expanded to
            (batch, nheads, seqlen_q, seqlen_k).
        causal: forwarded to the kernel as IS_CAUSAL.
        softmax_scale: scale applied to the attention scores; a falsy value selects
            1 / sqrt(headdim).

    Returns:
        Tuple ``(o, lse, softmax_scale)``: the output (allocated like ``q``), the
        float32 log-sum-exp buffer of shape (batch, nheads, seqlen_q rounded up to a
        multiple of 128), and the scale that was actually used.

    Raises:
        AssertionError: if a shape, dtype or device constraint is violated.
        RuntimeError: if the last two dimensions of ``bias`` are not supported.
    """
    # shape constraints
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    kv_shape = (batch, seqlen_k, nheads, d)
    assert k.shape == kv_shape
    assert v.shape == kv_shape
    assert d <= 128, 'FlashAttention only support head dimensions up to 128'
    assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
    assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
    assert q.is_cuda and k.is_cuda and v.is_cuda
    if not softmax_scale:
        softmax_scale = 1.0 / math.sqrt(d)

    # Classify the bias and collect the strides handed to the kernel.
    bias_type = 'none'
    bias_strides = (0, 0, 0)
    if bias is not None:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda
        assert bias.dim() == 4
        if bias.stride(-1) != 1:
            bias = bias.contiguous()
        trailing_dims = bias.shape[2:]
        if trailing_dims == (1, seqlen_k):
            bias_type = 'vector'
        elif trailing_dims == (seqlen_q, seqlen_k):
            bias_type = 'matrix'
        else:
            raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
                               ' or (seqlen_q, seqlen_k)')
        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
        bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2))

    # Per-row statistics are padded so that whole 128-row tiles can be written unmasked.
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    row_stats_shape = (batch, nheads, seqlen_q_rounded)
    lse = torch.empty(row_stats_shape, device=q.device, dtype=torch.float32)
    tmp = torch.empty(row_stats_shape, device=q.device, dtype=torch.float32)
    o = torch.empty_like(q)

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    BLOCK = 128
    num_warps = 8 if d > 64 else 4

    def grid(META):
        # One program per (query tile, batch * head).
        return (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)

    _fwd_kernel[grid](
        q, k, v, bias, o,
        lse, tmp,
        softmax_scale,
        # Strides are passed in (batch, head, seq) order.
        q.stride(0), q.stride(2), q.stride(1),
        k.stride(0), k.stride(2), k.stride(1),
        v.stride(0), v.stride(2), v.stride(1),
        *bias_strides,
        o.stride(0), o.stride(2), o.stride(1),
        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
        # Coarse sequence lengths serve as the Triton cache key to limit recompilation.
        seqlen_q // 32, seqlen_k // 32,
        # These stay positional: triton autotune expects key entries to be args, not kwargs.
        bias_type, causal, BLOCK_HEADDIM,
        BLOCK_M=BLOCK, BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    # softmax_scale may have been replaced by the default above.
    return o, lse, softmax_scale
|
||||
|
||||
|
||||
def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
    """Launch the Triton backward kernels and write the gradients into dq, dk, dv.

    Args:
        do: gradient of the output; made contiguous if its last dimension is strided.
        q: (batch, seqlen_q, nheads, headdim) tensor used in the forward pass.
        k, v: (batch, seqlen_k, nheads, headdim) tensors used in the forward pass.
        o: output of the forward pass.
        lse: log-sum-exp buffer returned by the forward pass, of shape
            (batch, nheads, seqlen_q rounded up to a multiple of 128).
        dq, dk, dv: preallocated gradient buffers, filled in place.
        bias: optional 4-D CUDA tensor whose last two dimensions are
            (1, seqlen_k) or (seqlen_q, seqlen_k).
        causal: forwarded to the kernel as IS_CAUSAL.
        softmax_scale: scale applied to the attention scores; a falsy value selects
            1 / sqrt(headdim).

    Returns:
        None. Results are written into ``dq``, ``dk`` and ``dv``.

    Raises:
        AssertionError: if a shape, stride, dtype or device constraint is violated.
        RuntimeError: if the last two dimensions of ``bias`` are not supported.
    """
    # Make sure that the last dimension is contiguous
    if do.stride(-1) != 1:
        do = do.contiguous()
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    # assert d in {16, 32, 64, 128}
    assert d <= 128
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    assert lse.shape == (batch, nheads, seqlen_q_rounded)
    # The kernels address the head dimension with unit stride.
    assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
    assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
    # fp32 accumulator for dq. It is left uninitialized here because every autotune config
    # of _bwd_kernel zeroes its DQ argument through a pre_hook before the kernel runs.
    dq_accum = torch.empty_like(q, dtype=torch.float32)
    delta = torch.empty_like(lse)

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    # Pre-pass: delta[row] = sum_d(o * do), consumed by the main kernel as D.
    _bwd_preprocess_do_o_dot[grid](
        o, do, delta,
        o.stride(0), o.stride(2), o.stride(1),
        do.stride(0), do.stride(2), do.stride(1),
        nheads, seqlen_q, seqlen_q_rounded, d,
        BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
    )

    has_bias = bias is not None
    bias_type = 'none'
    if has_bias:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda
        assert bias.dim() == 4
        # The forward pass accepts a bias whose last dimension is strided by copying it, but
        # the autograd functions save the caller's original tensor. Apply the same
        # normalization here instead of asserting, so inputs the forward accepted do not
        # fail in the backward pass.
        if bias.stride(-1) != 1:
            bias = bias.contiguous()
        if bias.shape[2:] == (1, seqlen_k):
            bias_type = 'vector'
        elif bias.shape[2:] == (seqlen_q, seqlen_k):
            bias_type = 'matrix'
        else:
            raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
                               ' or (seqlen_q, seqlen_k)')
        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)

    # One program per (batch, head); with SEQUENCE_PARALLEL additionally one per key tile.
    grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
                         batch * nheads)
    _bwd_kernel[grid](
        q, k, v, bias,
        do, dq_accum, dk, dv,
        lse, delta,
        softmax_scale,
        # Strides are passed in (batch, head, seq) order.
        q.stride(0), q.stride(2), q.stride(1),
        k.stride(0), k.stride(2), k.stride(1),
        v.stride(0), v.stride(2), v.stride(1),
        *bias_strides,
        do.stride(0), do.stride(2), do.stride(1),
        dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),
        dk.stride(0), dk.stride(2), dk.stride(1),
        dv.stride(0), dv.stride(2), dv.stride(1),
        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
        seqlen_q // 32, seqlen_k // 32,  # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
        bias_type, causal, BLOCK_HEADDIM,
        # Tile sizes, SEQUENCE_PARALLEL, num_warps and num_stages come from the autotune configs.
    )
    # Cast the fp32 accumulator into the caller's dq buffer.
    dq.copy_(dq_accum)
|
||||
|
||||
|
||||
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper around the Triton attention kernels for a packed QKV tensor."""

    @staticmethod
    def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
        """
            qkv: (batch, seqlen, 3, nheads, headdim)
            bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
                For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
                ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
            causal: whether to apply causal masking.
            softmax_scale: score scale; a falsy value selects 1 / sqrt(headdim).

            Returns the attention output of shape (batch, seqlen, nheads, headdim).
        """
        # Make sure that the last dimension is contiguous
        if qkv.stride(-1) != 1:
            qkv = qkv.contiguous()
        # q, k and v are views into the packed tensor along dimension 2.
        o, lse, ctx.softmax_scale = _flash_attn_forward(
            qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal,
            softmax_scale=softmax_scale
        )
        ctx.save_for_backward(qkv, o, lse, bias)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        """Return the gradient for ``qkv``; the other forward inputs receive None."""
        qkv, o, lse, bias = ctx.saved_tensors
        # bias is optional, so needs_input_grad can hold a single entry when the function
        # was applied with qkv only. Guard the index the same way
        # FlashAttnKVPackedFunc.backward does, instead of raising IndexError.
        if len(ctx.needs_input_grad) >= 2:
            assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet'
        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
        with torch.inference_mode():
            dqkv = torch.empty_like(qkv)
            # dq, dk and dv are written straight into the three slices of dqkv.
            _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
                                 dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
                                 bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
        return dqkv, None, None, None


flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
|
||||
|
||||
|
||||
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper around the Triton attention kernels for separate Q and packed KV."""

    @staticmethod
    def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
        """
            q: (batch, seqlen_q, nheads, headdim)
            kv: (batch, seqlen_k, 2, nheads, headdim)
            bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
                For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
                ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)

            Returns the attention output, allocated like q.
        """
        # The kernels need a unit stride along the last dimension.
        if q.stride(-1) != 1:
            q = q.contiguous()
        if kv.stride(-1) != 1:
            kv = kv.contiguous()
        keys, values = kv[:, :, 0], kv[:, :, 1]
        o, lse, used_scale = _flash_attn_forward(
            q, keys, values, bias=bias, causal=causal, softmax_scale=softmax_scale
        )
        ctx.softmax_scale = used_scale
        ctx.save_for_backward(q, kv, o, lse, bias)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        """Return gradients for ``q`` and ``kv``; the other forward inputs receive None."""
        q, kv, o, lse, bias = ctx.saved_tensors
        # Only inspect the bias slot when a bias argument was actually supplied.
        bias_slot_present = len(ctx.needs_input_grad) >= 3
        if bias_slot_present:
            assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet'
        # Triton's autotune bumps Tensor._version, which would make autograd copy the
        # tensors; inference_mode does not track versions, so run the kernels under it.
        with torch.inference_mode():
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)
            keys, values = kv[:, :, 0], kv[:, :, 1]
            # dk and dv are views, so the kernel fills dkv in place.
            dk, dv = dkv[:, :, 0], dkv[:, :, 1]
            _flash_attn_backward(do, q, keys, values, o, lse,
                                 dq, dk, dv,
                                 bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
        return dq, dkv, None, None, None


flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
|
||||
|
||||
|
||||
class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper around the Triton attention kernels for separate q, k, v input."""

    @staticmethod
    def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
        """Run the attention forward pass.

        q: (batch_size, seqlen_q, nheads, headdim)
        k, v: (batch_size, seqlen_k, nheads, headdim)
        bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
            ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)

        Returns the attention output ``o``.
        """
        def _last_dim_contiguous(t):
            # The kernels need the last dimension to be contiguous; copy only if it is not.
            return t if t.stride(-1) == 1 else t.contiguous()

        q = _last_dim_contiguous(q)
        k = _last_dim_contiguous(k)
        v = _last_dim_contiguous(v)
        o, lse, scale_used = _flash_attn_forward(
            q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
        )
        # Remember the scale reported by the forward helper for the backward pass.
        ctx.softmax_scale = scale_used
        ctx.save_for_backward(q, k, v, o, lse, bias)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        """Return gradients for ``(q, k, v)``; the remaining three inputs get ``None``."""
        q, k, v, o, lse, bias = ctx.saved_tensors
        bias_wants_grad = ctx.needs_input_grad[3]
        assert not bias_wants_grad, 'FlashAttention does not support bias gradient yet'
        # Triton's autotune bumps Tensor._version, which would make autograd do a
        # defensive memcpy. inference_mode does not track versions, so run there.
        with torch.inference_mode():
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)
            # The kernel fills dq, dk and dv in place.
            _flash_attn_backward(
                do, q, k, v, o, lse, dq, dk, dv,
                bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale,
            )
        return dq, dk, dv, None, None, None


# Functional alias for the autograd Function above.
flash_attn_func = FlashAttnFunc.apply
|
||||
276
pkgs/xformers/_flash_attn/flash_attn_triton_og.py
Normal file
276
pkgs/xformers/_flash_attn/flash_attn_triton_og.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
|
||||
# for benchmarking.
|
||||
# We fixed a few dtype cast to make it work for bf16
|
||||
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
This is a Triton implementation of the Flash Attention algorithm
|
||||
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# Forward kernel of causal fused attention.
#
# Each program handles one block of BLOCK_M query rows (program_id(0)) for one
# flattened batch*head index (program_id(1)). It walks the key/value blocks up to
# and including its own query block, keeping a running row maximum (m_i), a running
# softmax normaliser (l_i) and an already-normalised output accumulator (acc).
# On exit it stores acc to Out and the per-row l_i / m_i to L / M.
#
# NOTE(review): the K and V offsets below are built with stride_qh (and, for V,
# stride_qm / stride_qk), while the per-block advance uses stride_kn / stride_vk.
# stride_qz, stride_kz, stride_kh, stride_vz, stride_vh, stride_vn, stride_oz, Z and H
# are accepted but not used. This presumably assumes q, k, v share one contiguous
# layout -- confirm against callers.
# NOTE(review): none of the loads/stores are masked, so this looks like it expects
# N_CTX to be a multiple of BLOCK_M / BLOCK_N -- confirm.
@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    TMP, L, M,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointer to m and l
    # (t_ptrs addresses this program's BLOCK_M slots in the TMP scratchpad)
    t_ptrs = TMP + off_hz * N_CTX + offs_m
    # running row maximum starts at -inf, running normaliser and accumulator at 0
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs)
    # loop over k, v and update accumulator
    # (the upper bound stops at the end of this program's own query block: causal)
    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs + start_n * stride_kn)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        qk *= sm_scale
        # causal mask: a query row may only see key columns at or before its own index
        qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        tl.store(t_ptrs, acc_scale)
        acc_scale = tl.load(t_ptrs)  # BUG: have to store and immediately load
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(v_ptrs + start_n * stride_vk)
        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    l_ptrs = L + off_hz * N_CTX + offs_m
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, l_i)
    tl.store(m_ptrs, m_i)
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_DMODEL)
    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)
|
||||
|
||||
|
||||
# Backward pre-pass: for one block of BLOCK_M rows per program, divide dO by the
# per-row value loaded from L and compute delta = sum(o * scaled_dO) along the
# head dimension. Results go to NewDO and Delta.
#
# Out, DO and NewDO are addressed as row * D_HEAD + column, i.e. as dense
# row-major (rows, D_HEAD) buffers; L and Delta are addressed by row only.
# NOTE(review): loads/stores are unmasked, so the launch grid times BLOCK_M
# presumably has to equal the total row count -- confirm.
@triton.jit
def _bwd_preprocess(
    Out, DO, L,
    NewDO, Delta,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, D_HEAD)
    # load
    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    denom = tl.load(L + off_m).to(tl.float32)
    # compute
    do = do / denom[:, None]
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
    tl.store(Delta + off_m, delta)
|
||||
|
||||
|
||||
# Backward kernel of causal fused attention.
#
# One program per flattened batch*head index (program_id(0)). The outer loop walks
# the key/value blocks; for each, the inner loop walks the query blocks starting at
# that same block, recomputes the softmax probabilities from Q, K and the saved row
# maxima M, and accumulates dV and dK in registers. dQ is accumulated through
# global memory (load, add, store), so DQ must be zero-initialised by the caller.
#
# DO is expected to be already divided by the softmax normaliser (see the NOTE in
# the inner loop); D holds the per-row delta term subtracted from dp.
#
# NOTE(review): the base pointers of K, V, DO, DQ, DK and DV, and the row pointers
# of V and DV, are all computed with Q's strides. stride_kz, stride_kh and all
# stride_v* parameters are accepted but unused. This presumably assumes every
# operand shares Q's layout -- confirm against callers.
# NOTE(review): the tile shapes mix BLOCK_M and BLOCK_N (e.g. offs_n spans BLOCK_M,
# offs_m spans BLOCK_N), which looks like it relies on BLOCK_M == BLOCK_N -- confirm.
@triton.jit
def _bwd_kernel(
    Q, K, V, sm_scale, Out, DO,
    DQ, DK, DV,
    L, M,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    num_block,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    off_h = off_hz % H
    # offset pointers for batch/head
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_qz + off_h * stride_qh
    V += off_z * stride_qz + off_h * stride_qh
    DO += off_z * stride_qz + off_h * stride_qh
    DQ += off_z * stride_qz + off_h * stride_qh
    DK += off_z * stride_qz + off_h * stride_qh
    DV += off_z * stride_qz + off_h * stride_qh
    for start_n in range(0, num_block):
        lo = start_n * BLOCK_M
        # initialize row/col offsets
        offs_qm = lo + tl.arange(0, BLOCK_M)
        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_m = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_DMODEL)
        # initialize pointers to value-like data
        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        # pointer to row-wise quantities in value-like data
        D_ptrs = D + off_hz * N_CTX
        m_ptrs = M + off_hz * N_CTX
        # initialize dv and dk
        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        # k and v stay in SRAM throughout
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
        # loop over rows
        # (starts at `lo`: query blocks before this key block are fully masked)
        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
            offs_m_curr = start_m + offs_m
            # load q, k, v, do on-chip
            q = tl.load(q_ptrs)
            # recompute p = softmax(qk, dim=-1).T
            # NOTE: `do` is pre-divided by `l`; no normalization here
            qk = tl.dot(q, k, trans_b=True)
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
            m = tl.load(m_ptrs + offs_m_curr)
            p = tl.exp(qk * sm_scale - m[:, None])
            # compute dv
            do = tl.load(do_ptrs)
            dv += tl.dot(p.to(do.dtype), do, trans_a=True)
            # compute dp = dot(v, do)
            # (dp starts at -delta so the subtraction below is already folded in)
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
            dp += tl.dot(do, v, trans_b=True)
            # compute ds = p * (dp - delta[:, None])
            ds = p * dp * sm_scale
            # compute dk = dot(ds.T, q)
            dk += tl.dot(ds.to(q.dtype), q, trans_a=True)
            # # compute dq
            dq = tl.load(dq_ptrs, eviction_policy="evict_last")
            dq += tl.dot(ds.to(k.dtype), k)
            tl.store(dq_ptrs, dq, eviction_policy="evict_last")
            # # increment pointers
            dq_ptrs += BLOCK_M * stride_qm
            q_ptrs += BLOCK_M * stride_qm
            do_ptrs += BLOCK_M * stride_qm
        # write-back
        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        tl.store(dv_ptrs, dv)
        tl.store(dk_ptrs, dk)
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
    """Autograd wrapper that launches the Triton causal-attention kernels.

    Inputs are indexed as (batch, heads, seqlen, head_dim): the launch grid uses
    ``q.shape[2]`` for the sequence blocks and ``q.shape[0] * q.shape[1]`` for the
    flattened batch*head axis.
    """

    @staticmethod
    def forward(ctx, q, k, v, sm_scale):
        """Compute causal attention output for ``q``, ``k``, ``v`` scaled by ``sm_scale``.

        Raises:
            AssertionError: if the head dims differ or are not one of
                {16, 32, 64, 128}, if q/k/v shapes differ, or if the sequence
                length is not a multiple of the 128-row block size.
        """
        BLOCK = 128
        # The kernels address K/V (and dO/dQ/dK/dV in backward) with Q's strides and
        # flatten (batch, head) as a single multiple of the head stride, so every
        # operand must share one contiguous layout. contiguous() is a no-op when
        # the tensor is already contiguous.
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        # Only q's batch/head/sequence sizes are passed to the kernels.
        assert k.shape == q.shape and v.shape == q.shape, \
            "q, k and v must have identical shapes"
        # Kernel loads/stores are unmasked; a partial last block would read and
        # write past the end of each (batch, head) slice.
        assert q.shape[2] % BLOCK == 0, \
            f"sequence length must be a multiple of {BLOCK}"
        o = torch.empty_like(q)
        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
        # Per-row float32 side buffers: scratchpad, softmax normaliser, row maximum.
        tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8

        _fwd_kernel[grid](
            q, k, v, sm_scale,
            tmp, L, m,
            o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=Lk, num_warps=num_warps,
            num_stages=1,
        )
        ctx.save_for_backward(q, k, v, o, L, m)
        ctx.BLOCK = BLOCK
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        return o

    @staticmethod
    def backward(ctx, do):
        """Return gradients ``(dq, dk, dv, None)``; ``sm_scale`` gets no gradient."""
        q, k, v, o, l, m = ctx.saved_tensors
        do = do.contiguous()
        # dq is accumulated in place by the kernel, so it must start at zero;
        # float32 keeps the accumulation precise before the final cast.
        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        do_scaled = torch.empty_like(do)
        delta = torch.empty_like(l)
        # Pre-pass: scale dO by the softmax normaliser and compute the delta term.
        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
            o, do, l,
            do_scaled, delta,
            BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
        )

        # NOTE: kernel currently buggy for other values of `num_warps`
        num_warps = 8
        _bwd_kernel[(ctx.grid[1],)](
            q, k, v, ctx.sm_scale,
            o, do_scaled,
            dq, dk, dv,
            l, m,
            delta,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            ctx.grid[0],
            BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
            num_stages=1,
        )
        return dq.to(q.dtype), dk, dv, None


# Functional alias for the autograd Function above.
attention = _attention.apply
|
||||
136
pkgs/xformers/_flash_attn/flash_blocksparse_attention.py
Normal file
136
pkgs/xformers/_flash_attn/flash_blocksparse_attention.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
import hydra
|
||||
|
||||
from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
|
||||
from flash_attn.flash_blocksparse_attn_interface import convert_blockmask
|
||||
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
|
||||
|
||||
|
||||
class FlashBlocksparseAttention(nn.Module):
    """Implement block-sparse scaled dot product attention with softmax.

    Arguments
    ---------
        sparsity_config: config instantiated with ``hydra.utils.instantiate``; the
                         resulting object must provide ``make_layout(seq_length)``.
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        max_seq_length: longest supported sequence; rounded up to a multiple of
                        256 before the layout is built (default: 2048)
        device, dtype: accepted for interface compatibility; not used here.
    """
    def __init__(self, sparsity_config, softmax_temp=None, attention_dropout=0.0,
                 max_seq_length=2048, device=None, dtype=None):
        super().__init__()
        self.sparsity_config = hydra.utils.instantiate(sparsity_config)
        self.softmax_temp = softmax_temp
        self.dropout_p = attention_dropout

        # initialize sparse layout and register as buffer
        max_seq_length = ((max_seq_length + 256 - 1) // 256) * 256
        layout = self.sparsity_config.make_layout(max_seq_length)
        self.register_buffer("layout", layout)
        blockmask_converted = convert_blockmask(self.layout, causal=False)
        self.register_buffer("blockmask_converted", blockmask_converted)
        # logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}')

    @staticmethod
    def _select_blockmask(layout, seqlen):
        """Return the sub-block of ``layout`` that covers ``seqlen`` tokens.

        ``seqlen`` is rounded up to a multiple of 256; the result has
        ``seqlen_rounded // 16`` rows and ``seqlen_rounded // 256`` columns.

        Raises:
            AssertionError: if the layout is too small in either dimension.
        """
        seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
        nrow, ncol = seqlen_rounded // 16, seqlen_rounded // 256
        # Both bounds must hold. Written as `assert a, b` the second comparison
        # would only be the assertion message and would never be checked.
        assert nrow <= layout.shape[0] and ncol <= layout.shape[1], \
            "sequence length exceeds the precomputed sparsity layout"
        return layout[:nrow, :ncol]

    def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False, convert_mask=True):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            key_padding_mask: An implementation of BaseMask that encodes how
                         many query each sequence in the batch consists of
            causal: forwarded to the attention function.
            cu_seqlens: if given, qkv is passed to the attention function as-is
                        together with these cumulative sequence lengths; max_s
                        must then be given too.
            max_s: maximum sequence length (required when cu_seqlens is given).
            need_weights: must be False; attention weights are never returned.
            convert_mask: only consulted when cu_seqlens is given. If False, the
                          pre-converted full-size block mask is used instead of
                          converting a slice of the layout on every call.

        Returns
        -------
            (output, None)
        """
        assert not need_weights
        assert attn_mask is None
        assert qkv.dtype == torch.float16
        assert qkv.is_cuda

        # Dropout is only active in training mode.
        dropout_p = self.dropout_p if self.training else 0.0

        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            # Convert mask to take a subset
            blockmask = self._select_blockmask(self.layout, seqlen)
            if key_padding_mask is None:
                # No padding: flatten the batch and describe it with evenly spaced
                # cumulative sequence lengths.
                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                max_s = seqlen
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                          device=qkv.device)
                output = flash_blocksparse_attn_func(
                    qkv, cu_seqlens, blockmask, dropout_p,
                    max_s, softmax_scale=self.softmax_temp, causal=causal
                )
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            else:
                # Padding present: drop padded tokens, attend, then scatter back.
                key_padding_mask_bool = key_padding_mask.bool_matrix
                nheads = qkv.shape[-2]
                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                output_unpad = flash_blocksparse_attn_func(
                    x_unpad, cu_seqlens, blockmask, dropout_p,
                    max_s, softmax_scale=self.softmax_temp, causal=causal
                )
                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                             indices, batch_size, seqlen),
                                   'b s (h d) -> b s h d', h=nheads)
        else:
            assert max_s is not None
            seqlen = max_s
            # Convert mask to take a subset
            blockmask = self._select_blockmask(self.layout, seqlen)
            if convert_mask:
                output = flash_blocksparse_attn_func(
                    qkv, cu_seqlens, blockmask, dropout_p,
                    max_s, softmax_scale=self.softmax_temp, causal=causal
                )
            else:
                output = flash_blocksparse_attn_func(
                    qkv, cu_seqlens, self.blockmask_converted, dropout_p,
                    max_s, softmax_scale=self.softmax_temp, causal=causal,
                    convert_mask=False,
                )

        return output, None
|
||||
|
||||
|
||||
class FlashBlocksparseMHA(nn.Module):
    """Multi-head attention layer built on FlashBlocksparseAttention.

    Projects the input to packed q/k/v with one linear layer, runs block-sparse
    attention, and applies an output projection.
    """

    def __init__(self, embed_dim, num_heads, sparsity_config, bias=True, batch_first=True,
                 attention_dropout=0.0, causal=False, max_seq_length=2048,
                 device=None, dtype=None, **kwargs) -> None:
        """Build the projections and the inner attention module.

        Only ``batch_first=True`` is supported. ``embed_dim`` must be divisible by
        ``num_heads`` and the resulting head size must be 16, 32 or 64. Extra
        keyword arguments are accepted and ignored.
        """
        assert batch_first
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal

        self.num_heads = num_heads
        remainder = self.embed_dim % num_heads
        assert remainder == 0, "self.kdim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        supported_head_dims = (16, 32, 64)
        assert self.head_dim in supported_head_dims, "Only support head_dim == 16, 32, or 64"

        # Single projection producing q, k and v concatenated on the last dim.
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        self.inner_attn = FlashBlocksparseAttention(
            sparsity_config,
            attention_dropout=attention_dropout,
            max_seq_length=max_seq_length,
            **factory_kwargs,
        )
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
                need_weights=False):
        """Apply self-attention to ``x``.

        The second and third positional arguments and ``attn_mask`` are not used.
        Returns ``(projected_output, attn_weights)`` where ``attn_weights`` is
        whatever the inner attention module returns as its second value.
        """
        packed = self.Wqkv(x)
        # Split the fused projection into (three, heads, head_dim).
        packed = rearrange(packed, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
        context, attn_weights = self.inner_attn(
            packed,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            causal=self.causal,
        )
        merged_heads = rearrange(context, 'b s h d -> b s (h d)')
        return self.out_proj(merged_heads), attn_weights
|
||||
142
pkgs/xformers/_flash_attn/flash_blocksparse_attn_interface.py
Normal file
142
pkgs/xformers/_flash_attn/flash_blocksparse_attn_interface.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import flash_attn_cuda
|
||||
|
||||
|
||||
def convert_blockmask(blockmask, causal):
    """Convert a 0-1 block mask into the index format used by the CUDA code.

    In the input, 0 means the block is skipped and nonzero means it is kept.

    Argument:
        blockmask: (row, col): a 0-1 tensor
        causal: must be falsy; the causal case is not supported.
    Return:
        blockmask_converted: (col, row), dtype torch.int32: for each column, the row
        indices of the nonzero blocks, padded with -1 to reach length @row.
        Each index is multiplied by 4; the lowest bit marks the first nonzero block
        in its row and the second-lowest bit marks the last nonzero block in its row.
    """
    assert not causal
    # The indexing and sorting below is delicate; keep the order of operations.
    nrow, _ncol = blockmask.shape
    # Sort does not support bool on CUDA
    mask_u8 = blockmask.to(dtype=torch.uint8)
    # Per column: stable descending sort puts the kept rows first, in row order.
    by_column = torch.sort(mask_u8, dim=0, stable=True, descending=True)
    sorted_vals = by_column.values
    sorted_rows = by_column.indices
    # Inverse permutation: where each original row ended up inside its column.
    position_of_row = torch.argsort(sorted_rows, dim=0)
    row_ids = torch.arange(nrow, device=mask_u8.device)

    # Column of the last / first kept block in every row (stable sorts along rows).
    last_col = torch.sort(mask_u8, dim=-1, stable=True).indices[:, -1]
    last_pos = position_of_row[row_ids, last_col]
    first_col = torch.sort(mask_u8, dim=-1, stable=True, descending=True).indices[:, 0]
    first_pos = position_of_row[row_ids, first_col]

    # Row index in the upper bits, the two flags in the low bits.
    encoded = sorted_rows * 4
    encoded[last_pos, last_col] += 2
    encoded[first_pos, first_col] += 1
    # Slots that correspond to skipped blocks become padding.
    encoded[sorted_vals == 0] = -1
    return encoded.T.contiguous().to(dtype=torch.int32)
|
||||
|
||||
|
||||
def _flash_blocksparse_attn_forward(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale,
                                    causal, return_softmax):
    """Call the CUDA block-sparse forward kernel.

    Returns:
        ``(context, softmax_lse, S_dmask)`` where ``S_dmask`` is the third kernel
        output when ``return_softmax`` is truthy and ``None`` otherwise.
    """
    kernel_out = flash_attn_cuda.fwd_block(
        qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal,
        return_softmax, None,
    )
    context, softmax_lse, *extras = kernel_out
    if return_softmax:
        S_dmask = extras[0]
    else:
        S_dmask = None
    return context, softmax_lse, S_dmask
|
||||
|
||||
|
||||
def _flash_blocksparse_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask,
                                     dropout_p, max_s, softmax_scale, causal):
    """Call the CUDA block-sparse backward kernel and return only ``dqkv``.

    The kernel returns three values; the second and third are discarded.
    """
    # Note the kernel's argument order: softmax_scale comes before max_s here.
    kernel_out = flash_attn_cuda.bwd_block(
        dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask,
        dropout_p, softmax_scale, max_s, causal, None,
    )
    dqkv, _dp, _softmax_d = kernel_out
    return dqkv
|
||||
|
||||
|
||||
class FlashBlocksparseAttnFun(torch.autograd.Function):
    """Autograd Function for block-sparse attention that returns only the output."""

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
        """Run the forward kernel and stash what backward needs.

        ``softmax_scale`` defaults to ``qkv.shape[-1] ** -0.5`` when ``None``.
        Returns the attention output ``context``.
        """
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
            qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
            return_softmax=False
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context

    @staticmethod
    def backward(ctx, dout):
        """Return the gradient for ``qkv`` and ``None`` for the other six inputs."""
        qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
        if rng_state is not None:
            # Replay the forward RNG state so the kernel regenerates the same dropout mask.
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        try:
            # S_dmask is None, temporarily use another tensor just to get it running
            dqkv = _flash_blocksparse_attn_backward(
                dout, qkv, context, context, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
                ctx.max_s, ctx.softmax_scale, ctx.causal
            )
        finally:
            # Restore the caller's RNG state even if the kernel raised.
            if rng_state is not None:
                torch.cuda.set_rng_state(cur_rng_state)
        # One gradient slot per forward input (7 in total).
        return dqkv, None, None, None, None, None, None
|
||||
|
||||
|
||||
# We duplicate code to return both the output and the softmax for testing
|
||||
# Returning both makes backward a bit slower, so we want to keep using the other version for speed.
|
||||
class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
    """Autograd Function for block-sparse attention that also returns the softmax.

    Same computation as FlashBlocksparseAttnFun, but forward asks the kernel for
    the softmax/dropout-mask tensor and returns it together with the log-sum-exp.
    """

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
        """Run the forward kernel with ``return_softmax=True``.

        ``softmax_scale`` defaults to ``qkv.shape[-1] ** -0.5`` when ``None``.
        Returns ``(context, S_dmask, softmax_lse)``.
        """
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
            qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
            return_softmax=True
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context, S_dmask, softmax_lse

    @staticmethod
    def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):
        """Return the gradient for ``qkv``; incoming grads for the extra outputs are ignored."""
        qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
        if rng_state is not None:
            # Replay the forward RNG state so the kernel regenerates the same dropout mask.
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        try:
            dqkv = _flash_blocksparse_attn_backward(
                dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
                ctx.max_s, ctx.softmax_scale, ctx.causal
            )
        finally:
            # Restore the caller's RNG state even if the kernel raised.
            if rng_state is not None:
                torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None
|
||||
|
||||
|
||||
def flash_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale=None,
                                causal=False, return_attn_probs=False, convert_mask=True):
    """Functional entry point for block-sparse attention.

    dropout_p should be set to 0.0 during evaluation.

    When ``convert_mask`` is truthy, ``blockmask`` is first passed through
    ``convert_blockmask``; otherwise it is handed to the kernel unchanged.
    ``return_attn_probs`` selects the variant that also returns the softmax
    tensor and the log-sum-exp.
    """
    if return_attn_probs:
        autograd_fn = FlashBlocksparseAttnFunWithS
    else:
        autograd_fn = FlashBlocksparseAttnFun
    if convert_mask:
        blockmask = convert_blockmask(blockmask, causal=causal)
    return autograd_fn.apply(
        qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal
    )
|
||||
205
pkgs/xformers/_flash_attn/fused_softmax.py
Normal file
205
pkgs/xformers/_flash_attn/fused_softmax.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# [2022-10-23] Copied from https://github.com/NVIDIA/apex/blob/master/apex/transformer/functional/fused_softmax.py
|
||||
# for benchmarking.
|
||||
# We added support for seqlen=2k and seqlen=4k
|
||||
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from apex._autocast_utils import _cast_if_autocast_enabled
|
||||
from apex.transformer.enums import AttnMaskType
|
||||
|
||||
from fused_softmax_lib import scaled_masked_softmax_forward, scaled_masked_softmax_backward
|
||||
from fused_softmax_lib import scaled_masked_softmax_get_batch_per_block
|
||||
from fused_softmax_lib import scaled_upper_triang_masked_softmax_forward, scaled_upper_triang_masked_softmax_backward
|
||||
|
||||
|
||||
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence
    1. Scale the tensor.
    2. Apply upper triangular mask (typically used in gpt models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        """Run the fused forward kernel and save its result and the scale for backward."""
        # The scale is wrapped in a tensor so it can go through save_for_backward.
        scale_tensor = torch.tensor([scale])
        probs = scaled_upper_triang_masked_softmax_forward(inputs, scale_tensor[0])
        ctx.save_for_backward(probs, scale_tensor)
        return probs

    @staticmethod
    def backward(ctx, output_grads):
        """Return the gradient for ``inputs``; ``scale`` gets ``None``."""
        probs, scale_tensor = ctx.saved_tensors
        grad_inputs = scaled_upper_triang_masked_softmax_backward(
            output_grads, probs, scale_tensor[0]
        )
        return grad_inputs, None
|
||||
|
||||
|
||||
def scaled_upper_triang_masked_softmax(inputs, _, scale):
    """Apply the fused scale + causal-mask + softmax to a 4D score tensor.

    ``inputs`` is 4D with equal last two dimensions; the second argument is
    ignored. The result has the same 4D shape as ``inputs``.
    """
    batch, heads, q_len, k_len = inputs.size()
    assert q_len == k_len, "causal mask is only for self attention"
    # The fused kernel takes a 3D tensor (attn_batches, sq, sk).
    flat_inputs = inputs.view(-1, q_len, k_len)
    cast_args = _cast_if_autocast_enabled(flat_inputs, scale)
    # Arguments were cast by hand above, so autocast is disabled for the call itself.
    with torch.cuda.amp.autocast(enabled=False):
        probs = ScaledUpperTriangMaskedSoftmax.apply(*cast_args)
    return probs.view(batch, heads, q_len, k_len)
|
||||
|
||||
|
||||
# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`.
|
||||
# Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context.
|
||||
# So I needed to manually write two `torch.autograd.Function` inheritances.
|
||||
# Fused operation which performs following three operations in sequence
|
||||
# 1. Scale the tensor.
|
||||
# 2. Apply the mask.
|
||||
# 3. Perform softmax.
|
||||
class ScaledMaskedSoftmax(torch.autograd.Function):
    """Fused scale + mask + softmax, backed by the fused_softmax_lib kernels."""

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        """Run the fused forward kernel and save its result and the scale for backward."""
        # The scale is wrapped in a tensor so it can go through save_for_backward.
        scale_tensor = torch.tensor([scale])
        probs = scaled_masked_softmax_forward(inputs, mask, scale_tensor[0])
        ctx.save_for_backward(probs, scale_tensor)
        return probs

    @staticmethod
    def backward(ctx, output_grads):
        """Return the gradient for ``inputs``; ``mask`` and ``scale`` get ``None``."""
        probs, scale_tensor = ctx.saved_tensors
        grad_inputs = scaled_masked_softmax_backward(
            output_grads, probs, scale_tensor[0]
        )
        return grad_inputs, None, None
|
||||
|
||||
|
||||
def scaled_masked_softmax(inputs, mask, scale):
    """Apply the fused scale + mask + softmax to a 4D tensor (b, np, sq, sk)."""
    cast_args = _cast_if_autocast_enabled(inputs, mask, scale)
    # Arguments were cast by hand above, so autocast is disabled for the call itself.
    with torch.cuda.amp.autocast(enabled=False):
        probs = ScaledMaskedSoftmax.apply(*cast_args)
    return probs
|
||||
|
||||
|
||||
class FusedScaleMaskSoftmax(torch.nn.Module):
    """
    fused operation: scaling + mask + softmax

    Dispatches to a fused kernel when ``is_kernel_available`` says the input
    qualifies, and otherwise falls back to a plain PyTorch implementation.

    Arguments:
        input_in_fp16: flag to indicate if input in fp16 data format.
        input_in_bf16: flag to indicate if input in bf16 data format.
        attn_mask_type: attention mask type (pad or causal)
        scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax in performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.

    Raises:
        RuntimeError: if both half-precision flags are set, or if ``scale`` is
            given while ``softmax_in_fp32`` is false.
        ValueError: if fusion is requested with an unknown ``attn_mask_type``.
    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super().__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        if input_in_fp16 and input_in_bf16:
            raise RuntimeError(
                "both fp16 and bf16 flags cannot be active at the same time."
            )
        self.input_in_float16 = input_in_fp16 or input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        # Scaling is only supported together with an fp32 softmax.
        if scale is not None and not softmax_in_fp32:
            raise RuntimeError("softmax should be in fp32 when scaled")

        # The fused callable is only resolved when fusion was requested.
        if scaled_masked_softmax_fusion:
            if attn_mask_type == AttnMaskType.causal:
                self.fused_softmax_func = scaled_upper_triang_masked_softmax
            elif attn_mask_type == AttnMaskType.padding:
                self.fused_softmax_func = scaled_masked_softmax
            else:
                raise ValueError("Invalid attn_mask_type.")

    def forward(self, input, mask):
        """Return softmax probabilities for a 4D ``input`` of shape [b, np, sq, sk]."""
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            softmax_impl = self.forward_fused_softmax
        else:
            softmax_impl = self.forward_torch_softmax
        return softmax_impl(input, mask)

    def is_kernel_available(self, mask, b, np, sq, sk):
        """Return True when the fused kernel can handle this configuration."""
        attn_batches = b * np

        # Fusion must be requested and the input declared as fp16/bf16.
        if not (self.scaled_masked_softmax_fusion and self.input_in_float16):
            return False

        # Causal masking needs no mask tensor; padding masking requires one.
        is_causal = self.attn_mask_type == AttnMaskType.causal
        if not is_causal and not (
            self.attn_mask_type == AttnMaskType.padding and mask is not None
        ):
            return False

        # sk must lie in (16, 8192]; sq, sk and b * np must be multiples of 4.
        if not (
            16 < sk <= 8192
            and sq % 4 == 0
            and sk % 4 == 0
            and attn_batches % 4 == 0
        ):
            return False

        batch_per_block = self.get_batch_per_block(sq, sk, b, np)
        remainder = attn_batches % batch_per_block if is_causal else sq % batch_per_block
        return remainder == 0

    def forward_fused_softmax(self, input, mask):
        """Run the fused kernel on ``input`` ([b, np, sq, sk]); a missing scale means 1.0."""
        scale = 1.0 if self.scale is None else self.scale
        return self.fused_softmax_func(input, mask, scale)

    def forward_torch_softmax(self, input, mask):
        """Plain PyTorch fallback: optional fp32 upcast, scale, mask, softmax."""
        upcast = self.input_in_float16 and self.softmax_in_fp32
        if upcast:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale
        if mask is not None:
            input = self.mask_func(input, mask)
        probs = torch.softmax(input, dim=-1)

        # Cast back to the half-precision format the input was declared in.
        if upcast:
            probs = probs.half() if self.input_in_fp16 else probs.bfloat16()

        return probs

    @staticmethod
    def get_batch_per_block(sq, sk, b, np):
        """Forward to ``scaled_masked_softmax_get_batch_per_block``."""
        return scaled_masked_softmax_get_batch_per_block(sq, sk, b, np)
|
||||
0
pkgs/xformers/_flash_attn/layers/__init__.py
Normal file
0
pkgs/xformers/_flash_attn/layers/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
56
pkgs/xformers/_flash_attn/layers/patch_embed.py
Normal file
56
pkgs/xformers/_flash_attn/layers/patch_embed.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py
|
||||
# But we use nn.Linear instead of Conv2d and it's about 8x faster.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import _assert
|
||||
from torch.nn.modules.utils import _pair
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
try:
|
||||
from flash_attn.ops.fused_dense import FusedDense
|
||||
except ImportError:
|
||||
FusedDense = None
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
    """2D image to patch embedding.

    The image is cut into non-overlapping patches, each patch is flattened to
    ``in_chans * patch_h * patch_w`` values and projected to ``embed_dim`` with
    a linear layer (``FusedDense`` when ``fused_bias_fc`` and ``bias`` are both
    set, ``nn.Linear`` otherwise).
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
        fused_bias_fc=False,
    ):
        super().__init__()
        # Scalars are expanded to (height, width) pairs.
        img_size = _pair(img_size)
        patch_size = _pair(patch_size)
        rows = img_size[0] // patch_size[0]
        cols = img_size[1] // patch_size[1]

        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (rows, cols)
        self.num_patches = rows * cols
        self.flatten = flatten
        if fused_bias_fc and FusedDense is None:
            raise ImportError('fused_dense is not installed')

        # The fused layer is only chosen when a bias is requested as well.
        linear_cls = FusedDense if (fused_bias_fc and bias) else nn.Linear
        patch_dim = in_chans * patch_size[0] * patch_size[1]
        self.proj = linear_cls(patch_dim, embed_dim, bias=bias)
        self.norm = nn.Identity() if not norm_layer else norm_layer(embed_dim)

    def forward(self, x):
        """Embed a (b, c, H, W) batch; H and W must equal the configured size."""
        _, _, H, W = x.shape
        _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
        _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
        patch_h, patch_w = self.patch_size
        patches = rearrange(x, 'b c (h p1) (w p2) -> b h w (c p1 p2)',
                            p1=patch_h, p2=patch_w)
        x = self.proj(patches)
        if self.flatten:
            # Merge the patch grid into a single sequence axis.
            x = rearrange(x, 'b h w c -> b (h w) c')
        return self.norm(x)
|
||||
336
pkgs/xformers/_flash_attn/layers/rotary.py
Normal file
336
pkgs/xformers/_flash_attn/layers/rotary.py
Normal file
@@ -0,0 +1,336 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
from typing import Tuple, Optional
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
import rotary_emb
|
||||
|
||||
|
||||
def rotate_half(x, interleaved=False):
    """Rotate each (x1, x2) pair of the last dimension to (-x2, x1).

    interleaved: if False, the pairs are (first half, second half) of the last
        dimension; if True, the pairs are adjacent (even, odd) lanes.
    """
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    x1, x2 = x[..., ::2], x[..., 1::2]
    # Stack to (..., d, 2) and merge the last two axes so the rotated values
    # land back in their original even/odd lanes.
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2)

    Pure PyTorch rotary embedding applied to the first rotary_dim features of
    x; the remaining features are passed through unchanged.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    # cos/sin hold one angle per rotated pair and must be expanded to one value
    # per lane, following the same pairing that rotate_half uses:
    #   half-split:  lanes j and j + d share angle j  -> tile   [c0..cd, c0..cd]
    #   interleaved: lanes 2j and 2j + 1 share angle j -> repeat [c0, c0, c1, c1, ..]
    # Tiling was previously used for both layouts, which gave the interleaved
    # case the wrong angle per lane.
    if interleaved:
        cos = cos.repeat_interleave(2, dim=-1)
        sin = sin.repeat_interleave(2, dim=-1)
    else:
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)
    # Insert a singleton head axis: (seqlen, rotary_dim) -> (seqlen, 1, rotary_dim).
    cos = cos.unsqueeze(-2)
    sin = sin.unsqueeze(-2)
    x_ro = x[..., :ro_dim]
    return torch.cat([x_ro * cos + rotate_half(x_ro, interleaved) * sin,
                      x[..., ro_dim:]], dim=-1)
|
||||
|
||||
|
||||
class ApplyRotaryEmb(torch.autograd.Function):
    """Autograd function applying rotary embedding through ``rotary_emb.apply_rotary``."""

    @staticmethod
    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
        """
            x: (batch_size, seqlen, nheads, headdim)
            cos, sin: (seqlen, rotary_dim / 2)
            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
                of 1st half and 2nd half (GPT-NeoX style).
            inplace: if True, the result is written into x and x is returned.
        rotary_dim must be <= headdim
        Apply rotary embedding to the first rotary_dim of x.
        """
        _, seqlen, _, headdim = x.shape
        rotary_seqlen, half_dim = cos.shape
        rotary_dim = 2 * half_dim
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        assert sin.shape == (rotary_seqlen, rotary_dim // 2)

        # Views of the two halves of each rotated pair in the input.
        x_ro = x[..., :rotary_dim]
        if interleaved:
            x1, x2 = x_ro[..., ::2], x_ro[..., 1::2]
        else:
            x1, x2 = x_ro.chunk(2, dim=-1)

        if inplace:
            out = x
            o1, o2 = x1, x2
        else:
            out = torch.empty_like(x)
            out_ro = out[..., :rotary_dim]
            if interleaved:
                o1, o2 = out_ro[..., ::2], out_ro[..., 1::2]
            else:
                o1, o2 = out_ro.chunk(2, dim=-1)

        cos_b = rearrange(cos[:seqlen], 's d -> s 1 d')
        sin_b = rearrange(sin[:seqlen], 's d -> s 1 d')
        rotary_emb.apply_rotary(x1, x2, cos_b, sin_b, o1, o2, False)

        # A freshly allocated output still needs the non-rotated features.
        if not inplace and rotary_dim < headdim:
            out[..., rotary_dim:].copy_(x[..., rotary_dim:])

        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        return x if inplace else out

    @staticmethod
    def backward(ctx, do):
        """Push ``do`` through the kernel with its last flag set to True (the
        sibling ``ApplyRotaryEmbKV_`` labels that flag ``conj``); only ``x``
        receives a gradient."""
        cos, sin = ctx.saved_tensors
        _, seqlen, _, headdim = do.shape
        rotary_dim = 2 * cos.shape[-1]
        inplace = ctx.inplace

        do_ro = do[..., :rotary_dim]
        if ctx.interleaved:
            do1, do2 = do_ro[..., ::2], do_ro[..., 1::2]
        else:
            do1, do2 = do_ro.chunk(2, dim=-1)

        if inplace:
            dx = do
            dx1, dx2 = do1, do2
        else:
            dx = torch.empty_like(do)
            dx_ro = dx[..., :rotary_dim]
            if ctx.interleaved:
                dx1, dx2 = dx_ro[..., ::2], dx_ro[..., 1::2]
            else:
                dx1, dx2 = dx_ro.chunk(2, dim=-1)

        cos_b = rearrange(cos[:seqlen], 's d -> s 1 d')
        sin_b = rearrange(sin[:seqlen], 's d -> s 1 d')
        rotary_emb.apply_rotary(do1, do2, cos_b, sin_b, dx1, dx2, True)

        if not inplace and rotary_dim < headdim:
            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
        return dx, None, None, None, None


apply_rotary_emb_func = ApplyRotaryEmb.apply
|
||||
|
||||
|
||||
class ApplyRotaryEmbQKV_(torch.autograd.Function):
    """In-place rotary embedding for the q and k slots of a packed qkv tensor."""

    @staticmethod
    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
        """
            qkv: (batch_size, seqlen, 3, nheads, headdim)
            cos, sin: (seqlen, rotary_dim / 2)
            cos_k, sin_k: (seqlen, rotary_dim / 2), optional
            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
                1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
        """
        _, seqlen, three, _, headdim = qkv.shape
        assert three == 3
        rotary_seqlen, half_dim = cos.shape
        rotary_dim = 2 * half_dim
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        # k shares the q tables unless dedicated ones are supplied.
        if cos_k is None:
            cos_k = cos
        if sin_k is None:
            sin_k = sin
        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)

        # Slot 0 is q and slot 1 is k; slot 2 (v) is left untouched.
        for slot, table_cos, table_sin in ((0, cos, sin), (1, cos_k, sin_k)):
            ro = qkv[:, :, slot, :, :rotary_dim]
            if interleaved:
                first, second = ro[..., ::2], ro[..., 1::2]
            else:
                first, second = ro.chunk(2, dim=-1)
            rotary_emb.apply_rotary(first, second,
                                    rearrange(table_cos[:seqlen], 's d -> s 1 d'),
                                    rearrange(table_sin[:seqlen], 's d -> s 1 d'),
                                    first, second, False)

        ctx.save_for_backward(cos, sin, cos_k, sin_k)
        ctx.interleaved = interleaved
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        """Update the q and k slots of ``dqkv`` in place, calling the kernel with
        its last flag set to True, and return it as the gradient for ``qkv``."""
        cos, sin, cos_k, sin_k = ctx.saved_tensors
        _, seqlen, _, _, headdim = dqkv.shape
        rotary_dim = 2 * cos.shape[-1]

        for slot, table_cos, table_sin in ((0, cos, sin), (1, cos_k, sin_k)):
            d_ro = dqkv[:, :, slot, :, :rotary_dim]
            if ctx.interleaved:
                first, second = d_ro[..., ::2], d_ro[..., 1::2]
            else:
                first, second = d_ro.chunk(2, dim=-1)
            rotary_emb.apply_rotary(first, second,
                                    rearrange(table_cos[:seqlen], 's d -> s 1 d'),
                                    rearrange(table_sin[:seqlen], 's d -> s 1 d'),
                                    first, second, True)
        return dqkv, None, None, None, None, None


apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
|
||||
|
||||
|
||||
class ApplyRotaryEmbKV_(torch.autograd.Function):
    """In-place rotary embedding for the k slot of a packed kv tensor."""

    @staticmethod
    def forward(ctx, kv, cos, sin, interleaved=False):
        """
            kv: (batch_size, seqlen, 2, nheads, headdim)
            cos, sin: (seqlen, rotary_dim / 2)
            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
                1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of k.
        """
        _, seqlen, two, _, headdim = kv.shape
        assert two == 2
        rotary_seqlen, half_dim = cos.shape
        rotary_dim = 2 * half_dim
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen

        # Slot 0 is k; slot 1 (v) is left untouched.
        k_ro = kv[:, :, 0, :, :rotary_dim]
        if interleaved:
            k1, k2 = k_ro[..., ::2], k_ro[..., 1::2]
        else:
            k1, k2 = k_ro.chunk(2, dim=-1)
        cos_b = rearrange(cos[:seqlen], 's d -> s 1 d')
        sin_b = rearrange(sin[:seqlen], 's d -> s 1 d')
        # conj=False since this is the forward pass
        rotary_emb.apply_rotary(k1, k2, cos_b, sin_b, k1, k2, False)

        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        return kv

    @staticmethod
    def backward(ctx, dkv):
        """Update the k slot of ``dkv`` in place and return it as the gradient for ``kv``."""
        cos, sin = ctx.saved_tensors
        _, seqlen, _, _, headdim = dkv.shape
        rotary_dim = 2 * cos.shape[-1]

        dk_ro = dkv[:, :, 0, :, :rotary_dim]
        if ctx.interleaved:
            dk1, dk2 = dk_ro[..., ::2], dk_ro[..., 1::2]
        else:
            dk1, dk2 = dk_ro.chunk(2, dim=-1)
        cos_b = rearrange(cos[:seqlen], 's d -> s 1 d')
        sin_b = rearrange(sin[:seqlen], 's d -> s 1 d')
        # conj=True since this is the backward pass
        rotary_emb.apply_rotary(dk1, dk2, cos_b, sin_b, dk1, dk2, True)
        return dkv, None, None, None


apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None,
                 pos_idx_in_fp32=True, device=None):
        """
            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
                of 1st half and 2nd half (GPT-NeoX style).
            pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
                otherwise they might be in lower precision.
                This option was added because previously (before 2023-07-02), when we construct
                the position indices, we use the dtype of self.inv_freq. In most cases this would
                be fp32, but if the model is trained in pure bf16 (not mixed precision), then
                self.inv_freq would be bf16, and the position indices are also in bf16.
                Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
                embeddings for some positions will coincide.
                To maintain compatibility with models previously trained in pure bf16,
                we add this option.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        # Per-frequency XPos scale; only present when scale_base is given.
        scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
                 / (1.4 * dim) if scale_base is not None else None)
        self.register_buffer("scale", scale, persistent=False)

        # Lazily built cos/sin tables; see _update_cos_sin_cache.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        """Return the fp32 inverse frequencies ``1 / base ** (2i / dim)``."""
        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device,
                                                 dtype=torch.float32) / self.dim))

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """Rebuild the cached cos/sin tables when the current ones cannot be reused."""
        # Reset the tables if the sequence length has changed,
        # if nothing has been cached yet,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        # The explicit None check is needed because a first call with seqlen == 0
        # does not satisfy `seqlen > self._seq_len_cached` and would otherwise
        # dereference the still-empty cache.
        if (seqlen > self._seq_len_cached or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                          - seqlen // 2) / self.scale_base)
                scale = self.scale.to(device=power.device) ** rearrange(power, 's -> s 1')
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None,
                seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
             else it's just q of shape (batch, seqlen, nheads, headdim)
        kv: (batch, seqlen, 2, nheads, headdim)
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
        token in the batch.

        Returns the rotated qkv tensor when kv is None, otherwise the pair (q, kv).
        The rotation is applied in place by the underlying autograd functions.
        """
        seqlen = qkv.shape[1]
        self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                    None, None, self.interleaved
                )
            else:
                # XPos: keys use the inversely scaled tables.
                return apply_rotary_emb_qkv_(
                    qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                    self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
                    self.interleaved
                )
        else:
            q = qkv
            q = apply_rotary_emb_func(
                q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                self.interleaved, True
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                    self.interleaved
                )
            else:
                kv = apply_rotary_emb_kv_(
                    kv, self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
                    self.interleaved
                )
            return q, kv
|
||||
0
pkgs/xformers/_flash_attn/losses/__init__.py
Normal file
0
pkgs/xformers/_flash_attn/losses/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
129
pkgs/xformers/_flash_attn/losses/cross_entropy.py
Normal file
129
pkgs/xformers/_flash_attn/losses/cross_entropy.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
|
||||
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
|
||||
# the losses we can get the global loss. There's no need to do it step by step
|
||||
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
|
||||
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import xentropy_cuda_lib
|
||||
|
||||
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
|
||||
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
|
||||
# version of PyTorch. The following 2 lines are for backward compatibility with
|
||||
# older PyTorch.
|
||||
if not hasattr(torch.distributed, "all_gather_into_tensor"):
    # Older PyTorch only exposes the private name; alias it to the public one.
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
|
||||
|
||||
|
||||
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
    """Per-sample softmax cross entropy backed by ``xentropy_cuda_lib``.

    Supports label smoothing, an ignored label value and, when a process group
    is given, a vocabulary that is sharded across the ranks of that group.
    """

    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False,
                process_group=None):
        """
        logits: (batch, vocab_size)
        labels: (batch,)
        If process_group is not None, we're doing Tensor Parallel: each process is responsible for
        one part of the vocab. The loss needs to be aggregated across processes.

        Returns the per-sample losses of shape (batch,); entries whose label equals
        ignored_index are 0.
        """
        batch, vocab_size = logits.shape
        assert labels.shape == (batch,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        ctx.total_classes = world_size * vocab_size

        if world_size == 1:
            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
            losses.masked_fill_(labels==ignored_index, 0)
            labels_local = labels
        else:
            rank = torch.distributed.get_rank(process_group)
            # First global vocab id owned by this rank.
            vocab_start_index = rank * vocab_size

            # Shift labels into this partition's local index space; ignored labels keep
            # their value so they can still be recognised in backward.
            ignored_mask = labels == ignored_index
            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)

            # For tensor parallel cross entropy with smoothing, we want to pass in the total number
            # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
            # last dimension of the input tensor.
            losses, lse_local = xentropy_cuda_lib.forward(logits, labels_local, smoothing,
                                                          world_size * vocab_size)
            assert lse_local.shape == (batch,)
            assert losses.shape == (batch,)
            losses.masked_fill_(ignored_mask, 0)
            # For labels == ignored_index, the loss is always 0.
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
            # lse_local - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
            # For labels not in the vocab of this partition, losses contains
            # 0.1 * (lse_local - sum logit / total_classes).

            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
                                        device=lse_local.device)
            torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
                                                     group=process_group)
            # Start the loss reduction asynchronously; it is waited on below, after the
            # LSE bookkeeping that does not depend on it.
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
            )
            lse = torch.logsumexp(lse_allgather, dim=0)
            # If there's no smoothing, the total losses are lse_local - predicted_logit,
            # we just have to subtract the lse_local and add the lse (global).
            # If there's smoothing=0.1, the total losses are
            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
            # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
            rank_per_sample = torch.div(labels, vocab_size, rounding_mode='floor')
            lse_local = lse_allgather[rank_per_sample,
                                      torch.arange(batch, device=lse_allgather.device)]

            handle_losses.wait()
            if smoothing == 0.0:
                losses += lse - lse_local
            else:
                losses += ((1 - smoothing) * (lse - lse_local)
                           + smoothing * (lse - lse_allgather.sum(dim=0)))
            losses.masked_fill_(ignored_mask, 0)

        ctx.save_for_backward(logits, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        """Return the gradient for ``logits``; the other five inputs get ``None``."""
        logits, lse, labels = ctx.saved_tensors
        # Zero the gradient of ignored samples out of place: ``contiguous()`` may hand
        # back the incoming tensor itself, which must not be modified here.
        grad_loss = grad_loss.contiguous().masked_fill(labels == ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels,
                                                 ctx.smoothing, ctx.inplace_backward,
                                                 ctx.total_classes)
        # One entry per forward input: logits, labels, smoothing, ignored_index,
        # inplace_backward, process_group.
        return grad_logits, None, None, None, None, None
|
||||
|
||||
|
||||
class CrossEntropyLoss(nn.Module):
    """Module wrapper around ``SoftmaxCrossEntropyLossFn``.

    Only ``reduction='mean'`` and ``reduction='none'`` are supported; any other
    value raises ``NotImplementedError`` at construction time.
    """

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False, process_group=None):
        super().__init__()
        supported = ('mean', 'none')
        if reduction not in supported:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward
        self.process_group = process_group

    def forward(self, input, target):
        """Return the loss for CUDA tensors ``input`` (logits) and ``target`` (labels).

        With ``reduction='mean'`` the summed loss is divided by the number of
        targets that differ from ``ignore_index``; otherwise the per-sample
        losses are returned as-is.
        """
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        per_sample = SoftmaxCrossEntropyLossFn.apply(
            input, target, self.label_smoothing, self.ignore_index, self.inplace_backward,
            self.process_group
        )
        if self.reduction != 'mean':
            return per_sample
        num_valid = (target != self.ignore_index).sum()
        return per_sample.sum() / num_valid
|
||||
0
pkgs/xformers/_flash_attn/models/__init__.py
Normal file
0
pkgs/xformers/_flash_attn/models/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
pkgs/xformers/_flash_attn/models/__pycache__/gpt.cpython-310.pyc
Normal file
BIN
pkgs/xformers/_flash_attn/models/__pycache__/gpt.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
pkgs/xformers/_flash_attn/models/__pycache__/opt.cpython-310.pyc
Normal file
BIN
pkgs/xformers/_flash_attn/models/__pycache__/opt.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/_flash_attn/models/__pycache__/vit.cpython-310.pyc
Normal file
BIN
pkgs/xformers/_flash_attn/models/__pycache__/vit.cpython-310.pyc
Normal file
Binary file not shown.
531
pkgs/xformers/_flash_attn/models/bert.py
Normal file
531
pkgs/xformers/_flash_attn/models/bert.py
Normal file
@@ -0,0 +1,531 @@
|
||||
# Copyright (c) 2022, Tri Dao.
|
||||
# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
|
||||
# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
|
||||
# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
|
||||
|
||||
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
|
||||
|
||||
import re
|
||||
import logging
|
||||
from functools import partial
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import BertConfig
|
||||
from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions
|
||||
from transformers.models.bert.modeling_bert import BertForPreTrainingOutput
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.modules.mha import MHA
|
||||
from flash_attn.modules.mlp import Mlp, FusedMLP
|
||||
from flash_attn.modules.block import Block
|
||||
from flash_attn.modules.embedding import BertEmbeddings
|
||||
from flash_attn.bert_padding import unpad_input, pad_input
|
||||
from flash_attn.bert_padding import index_first_axis, index_first_axis_residual
|
||||
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
||||
|
||||
try:
|
||||
from flash_attn.ops.fused_dense import FusedDense
|
||||
except ImportError:
|
||||
FusedDense = None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
|
||||
except ImportError:
|
||||
dropout_add_layer_norm, layer_norm = None, None
|
||||
|
||||
try:
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||
except ImportError:
|
||||
CrossEntropyLoss = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_mixer_cls(config, cross_attn=False, return_residual=False):
    """Return a ``partial`` of ``MHA`` pre-configured from ``config``.

    Rotary keyword arguments are only added when
    ``config.position_embedding_type == "rotary"``; the optional config fields
    fall back to the defaults given below.
    """
    use_flash_attn = getattr(config, 'use_flash_attn', False)
    fused_bias_fc = getattr(config, 'fused_bias_fc', False)

    rotary_kwargs = {}
    if config.position_embedding_type == "rotary":
        rotary_kwargs = {
            "rotary_emb_dim": getattr(config, "rotary_emb_dim", config.hidden_size),
            "rotary_emb_base": getattr(config, "rotary_emb_base", 10000.0),
            "rotary_emb_scale_base": getattr(config, "rotary_emb_scale_base", None),
            "rotary_emb_interleaved": getattr(config, "rotary_emb_interleaved", False),
        }

    mha_kwargs = dict(
        num_heads=config.num_attention_heads,
        cross_attn=cross_attn,
        dropout=config.attention_probs_dropout_prob,
        causal=False,
        fused_bias_fc=fused_bias_fc,
        use_flash_attn=use_flash_attn,
        return_residual=return_residual,
    )
    return partial(MHA, **mha_kwargs, **rotary_kwargs)
|
||||
|
||||
|
||||
def create_mlp_cls(config, layer_idx=None, return_residual=False):
    """Return a ``partial`` of ``Mlp`` or ``FusedMLP`` configured from ``config``.

    ``config.fused_mlp`` (default False) selects ``FusedMLP``, which requires
    ``config.hidden_act`` to be 'gelu_new' or 'gelu_fast'. For the fused case,
    ``config.mlp_checkpoint_lvl`` may be a sequence holding one level per
    layer, in which case ``layer_idx`` must be given.
    """
    approximate_gelus = ['gelu_new', 'gelu_fast']
    inner_dim = config.intermediate_size

    if not getattr(config, 'fused_mlp', False):
        approximate = 'tanh' if config.hidden_act in approximate_gelus else 'none'
        activation = partial(F.gelu, approximate=approximate)
        return partial(Mlp, hidden_features=inner_dim, activation=activation,
                       return_residual=return_residual)

    assert config.hidden_act in approximate_gelus, 'fused_mlp only supports approximate gelu'
    if FusedMLP is None:
        raise ImportError('fused_dense is not installed')
    checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
    # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
    if isinstance(checkpoint_lvl, Sequence):
        assert layer_idx is not None
        checkpoint_lvl = checkpoint_lvl[layer_idx]
    return partial(FusedMLP, hidden_features=inner_dim,
                   checkpoint_lvl=checkpoint_lvl, return_residual=return_residual)
|
||||
|
||||
|
||||
def create_block(config, layer_idx=None):
    """Assemble one post-norm BERT ``Block`` from the config.

    Args:
        config: model config. Reads ``hidden_size``, ``hidden_dropout_prob``,
            ``layer_norm_eps`` and the optional ``last_layer_subset`` /
            ``fused_dropout_add_ln`` attributes (``num_hidden_layers`` is read
            only when ``last_layer_subset`` is truthy).
        layer_idx: index of the layer being built.

    Returns:
        The constructed ``Block``.
    """
    last_layer_subset = getattr(config, 'last_layer_subset', False)
    # Only the final layer switches to cross attention, and only in subset mode.
    cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
    # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
    # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
    # one layer) so we just choose not to return residual in this case.
    return_residual = not cross_attn

    mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
    mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
    block_options = dict(
        norm_cls=partial(nn.LayerNorm, eps=config.layer_norm_eps),
        prenorm=False,
        resid_dropout1=config.hidden_dropout_prob,
        resid_dropout2=config.hidden_dropout_prob,
        fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
        return_residual=return_residual,
    )
    return Block(config.hidden_size, mixer_cls, mlp_cls, **block_options)
|
||||
|
||||
|
||||
# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
|
||||
def _init_weights(module, initializer_range=0.02):
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, std=initializer_range)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
nn.init.normal_(module.weight, std=initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
nn.init.zeros_(module.weight[module.padding_idx])
|
||||
|
||||
|
||||
class BertEncoder(nn.Module):
    """Stack of BERT blocks with an optional unpadded FlashAttention path."""

    def __init__(self, config: BertConfig):
        super().__init__()
        self.use_flash_attn = getattr(config, 'use_flash_attn', False)
        self.layers = nn.ModuleList([create_block(config, layer_idx=i)
                                     for i in range(config.num_hidden_layers)])

    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
        """Run all layers over ``hidden_states``.

        If subset_mask is not None, we only want output for the subset of the sequence.
        This means that we only compute the last layer output for these tokens.
        subset_mask: (batch, seqlen), dtype=torch.bool

        Without a padding mask, or when flash attention is disabled, the input
        stays padded and ``subset_mask`` (if given) is applied by boolean
        indexing after the last layer. Otherwise padding is stripped with
        ``unpad_input`` before the layers run; the result is re-padded with
        ``pad_input`` only when no ``subset_mask`` is given.
        """
        if key_padding_mask is None or not self.use_flash_attn:
            mixer_kwargs = ({'key_padding_mask': key_padding_mask}
                            if key_padding_mask is not None else None)
            for layer in self.layers:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
            if subset_mask is not None:
                hidden_states = hidden_states[subset_mask]
            return hidden_states

        # Unpadded path: key_padding_mask is guaranteed to be non-None from here on.
        batch, seqlen = hidden_states.shape[:2]
        hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
            hidden_states, key_padding_mask
        )
        mixer_kwargs = {'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen_in_batch}
        if subset_mask is None:
            for layer in self.layers:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
            return pad_input(hidden_states, indices, batch, seqlen)

        # All layers but the last see every unpadded token.
        for layer in self.layers[:-1]:
            hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
        # Positions of the subset tokens within the unpadded token list, plus the
        # cumulative per-sequence counts of subset tokens (with a leading 0).
        subset_idx = torch.nonzero(subset_mask[key_padding_mask], as_tuple=False).flatten()
        subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
        subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
                                  (1, 0))
        hidden_states_subset, hidden_states = index_first_axis_residual(
            hidden_states, subset_idx
        )
        # It's ok to set max_seqlen_q to be much larger
        mixer_kwargs = {'x_kv': hidden_states,
                        'cu_seqlens': subset_cu_seqlens, 'max_seqlen': max_seqlen_in_batch,
                        'cu_seqlens_k': cu_seqlens, 'max_seqlen_k': max_seqlen_in_batch}
        # The last layer uses the subset as queries and all tokens as keys/values.
        return self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
|
||||
|
||||
|
||||
class BertPooler(nn.Module):
    """Dense layer followed by tanh, applied to the first token by default."""

    def __init__(self, config):
        super().__init__()
        if getattr(config, 'fused_bias_fc', False):
            # The fused linear layer is optional; fail loudly if it is missing.
            if FusedDense is None:
                raise ImportError('fused_dense is not installed')
            dense_cls = FusedDense
        else:
            dense_cls = nn.Linear
        self.dense = dense_cls(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, pool=True):
        """Return ``tanh(dense(x))``.

        With ``pool=True`` ``x`` is ``hidden_states[:, 0]`` (the first token of
        every sequence); with ``pool=False`` the input is used as given.
        """
        if pool:
            hidden_states = hidden_states[:, 0]
        return self.activation(self.dense(hidden_states))
|
||||
|
||||
|
||||
class BertPredictionHeadTransform(nn.Module):
    """Dense -> GELU -> LayerNorm transform used in front of the LM decoder."""

    def __init__(self, config):
        super().__init__()
        use_fused_dense = getattr(config, 'fused_bias_fc', False)
        if use_fused_dense and FusedDense is None:
            raise ImportError('fused_dense is not installed')
        self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
        if self.fused_dropout_add_ln and layer_norm is None:
            raise ImportError('dropout_add_layer_norm is not installed')

        dense_cls = FusedDense if use_fused_dense else nn.Linear
        self.dense = dense_cls(config.hidden_size, config.hidden_size)
        # 'gelu_new' / 'gelu_fast' select the tanh approximation; anything else is exact GELU.
        gelu_mode = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
        self.transform_act_fn = nn.GELU(approximate=gelu_mode)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply dense, activation and layer norm (fused kernel when enabled)."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        if self.fused_dropout_add_ln:
            # Same parameters as self.layer_norm, routed through the fused function.
            ln = self.layer_norm
            return layer_norm(activated, ln.weight, ln.bias, ln.eps)
        return self.layer_norm(activated)
|
||||
|
||||
|
||||
class BertLMPredictionHead(nn.Module):
    """Prediction-head transform followed by a biased projection to the vocabulary."""

    def __init__(self, config):
        super().__init__()
        if getattr(config, 'fused_bias_fc', False):
            if FusedDense is None:
                raise ImportError('fused_dense is not installed')
            decoder_cls = FusedDense
        else:
            decoder_cls = nn.Linear

        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = decoder_cls(config.hidden_size, config.vocab_size, bias=True)

    def forward(self, hidden_states):
        """Return ``decoder(transform(hidden_states))``."""
        return self.decoder(self.transform(hidden_states))
|
||||
|
||||
|
||||
class BertPreTrainingHeads(nn.Module):
    """Masked-LM head plus a two-way linear classifier on the pooled output."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        """Return ``(prediction_scores, seq_relationship_score)``."""
        return (
            self.predictions(sequence_output),
            self.seq_relationship(pooled_output),
        )
|
||||
|
||||
|
||||
class BertPreTrainedModel(nn.Module):
    """Base class that stores the config and offers a simple interface for
    loading pretrained weights.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if isinstance(config, BertConfig):
            self.config = config
            return
        cls_name = self.__class__.__name__
        raise ValueError(
            "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
            "To create a model from a Google pretrained model use "
            "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                cls_name, cls_name
            ))

    @classmethod
    def from_pretrained(cls, model_name, config, *inputs, **kwargs):
        """Instantiate the model and load pretrained weights into it.

        The checkpoint is obtained with ``state_dict_from_pretrained(model_name)``,
        its keys are converted by ``remap_state_dict`` and the result is loaded
        with ``strict=False``. The value returned by ``load_state_dict`` is logged
        at INFO level.

        Params:
            model_name: identifier passed to ``state_dict_from_pretrained``.
            config: ``BertConfig`` used both to build the model and to remap the
                checkpoint keys.
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_labels for BertForSequenceClassification)

        Returns:
            The instantiated model.
        """
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        pretrained_state = remap_state_dict(state_dict_from_pretrained(model_name), config)
        load_return = model.load_state_dict(pretrained_state, strict=False)
        logger.info(load_return)
        return model
|
||||
|
||||
|
||||
class BertModel(BertPreTrainedModel):
    """BERT backbone: embeddings, embedding LayerNorm + dropout, encoder and optional pooler."""

    def __init__(self, config: BertConfig, add_pooling_layer=True):
        super().__init__(config)
        self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
        # Round config.vocab_size up (in place) to the next multiple when needed.
        remainder = config.vocab_size % self.pad_vocab_size_multiple
        if remainder != 0:
            config.vocab_size += self.pad_vocab_size_multiple - remainder
        self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
        if self.fused_dropout_add_ln and layer_norm is None:
            raise ImportError('dropout_add_layer_norm is not installed')
        assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']

        self.embeddings = BertEmbeddings(config.hidden_size, config.vocab_size,
                                         config.max_position_embeddings, config.type_vocab_size,
                                         padding_idx=config.pad_token_id)
        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
        self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.apply(partial(_init_weights, initializer_range=config.initializer_range))

    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
                masked_tokens_mask=None):
        """Embed the inputs, run the encoder and pool.

        If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
        we only want the output for the masked tokens. This means that we only compute the last
        layer output for these tokens.
        masked_tokens_mask: (batch, seqlen), dtype=torch.bool
        """
        hidden_states = self.embeddings(input_ids, position_ids=position_ids,
                                        token_type_ids=token_type_ids)
        # TD [2022-12:18]: Don't need to force residual in fp32
        # BERT puts embedding LayerNorm before embedding dropout.
        if self.fused_dropout_add_ln:
            hidden_states = layer_norm(hidden_states, self.emb_ln.weight, self.emb_ln.bias,
                                       self.emb_ln.eps)
        else:
            hidden_states = self.emb_ln(hidden_states)
        hidden_states = self.emb_drop(hidden_states)

        subset_mask = None
        if masked_tokens_mask is not None:
            batch_size, seqlen = input_ids.shape[:2]
            # We also need the first column for the CLS token
            first_col_mask = torch.zeros(batch_size, seqlen, dtype=torch.bool,
                                         device=input_ids.device)
            first_col_mask[:, 0] = True
            subset_mask = masked_tokens_mask | first_col_mask

        sequence_output = self.encoder(hidden_states, key_padding_mask=attention_mask,
                                       subset_mask=subset_mask)

        if masked_tokens_mask is None:
            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        else:
            # TD [2022-03-01]: the indexing here is very tricky.
            # Build selectors over the encoder's subset output: one picking the
            # first-column (CLS) entries, one picking the masked-token entries.
            if attention_mask is not None:
                kept = subset_mask[attention_mask]
                cls_selector = first_col_mask[attention_mask][kept]
                token_selector = masked_tokens_mask[attention_mask][kept]
            else:
                cls_selector = first_col_mask[subset_mask]
                token_selector = masked_tokens_mask[subset_mask]
            pool_input = sequence_output[cls_selector]
            sequence_output = sequence_output[token_selector]
            pooled_output = (self.pooler(pool_input, pool=False)
                             if self.pooler is not None else None)

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )
|
||||
|
||||
|
||||
class BertForPreTraining(BertPreTrainedModel):
    """BERT with masked-LM and next-sentence heads and their combined loss."""

    def __init__(self, config: BertConfig):
        super().__init__(config)
        # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
        # (around 15%) to the classifier heads.
        self.dense_seq_output = getattr(config, 'dense_seq_output', False)
        # If last_layer_subset, we only need the compute the last layer for a subset of tokens
        # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
        self.last_layer_subset = getattr(config, 'last_layer_subset', False)
        if self.last_layer_subset:
            assert self.dense_seq_output, 'last_layer_subset requires dense_seq_output'

        if getattr(config, 'use_xentropy', False):
            if CrossEntropyLoss is None:
                raise ImportError('xentropy_cuda is not installed')
            loss_cls = partial(CrossEntropyLoss, inplace_backward=True)
        else:
            loss_cls = nn.CrossEntropyLoss

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)
        self.mlm_loss = loss_cls(ignore_index=0)
        self.nsp_loss = loss_cls(ignore_index=-1)

        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        """Make the LM decoder share its weight with the word embeddings."""
        self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight

    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
                labels=None, next_sentence_label=None):
        """
        If labels are provided, they must be 0 for masked out tokens (as specified in the attention
        mask).

        Returns a ``BertForPreTrainingOutput`` with fields:
            loss: masked-LM loss plus next-sentence loss when both ``labels`` and
                ``next_sentence_label`` are given, otherwise ``None``.
            prediction_logits: masked language modeling logits. When
                ``dense_seq_output`` is set and ``labels`` is given, they cover
                only the positions with ``labels > 0``.
            seq_relationship_logits: next sentence classification logits.
        """
        have_labels = labels is not None
        restrict_last_layer = self.last_layer_subset and have_labels
        masked_tokens_mask = labels > 0 if restrict_last_layer else None
        outputs = self.bert(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
            attention_mask=attention_mask.bool() if attention_mask is not None else None,
            masked_tokens_mask=masked_tokens_mask
        )
        sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output

        dense_output = self.dense_seq_output and have_labels
        if dense_output:
            masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
            # With last_layer_subset the encoder already returned only these tokens.
            if not self.last_layer_subset:
                flat_output = rearrange(sequence_output, 'b s d -> (b s) d')
                sequence_output = index_first_axis(flat_output, masked_token_idx)
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if have_labels and next_sentence_label is not None:
            if dense_output:  # prediction_scores are already flattened
                mlm_logits = prediction_scores
                mlm_targets = labels.flatten()[masked_token_idx]
            else:
                mlm_logits = rearrange(prediction_scores, '... v -> (...) v')
                mlm_targets = rearrange(labels, '... -> (...)')
            masked_lm_loss = self.mlm_loss(mlm_logits, mlm_targets)
            next_sentence_loss = self.nsp_loss(
                rearrange(seq_relationship_score, '... t -> (...) t'),
                rearrange(next_sentence_label, '... -> (...)'),
            )
            total_loss = masked_lm_loss.float() + next_sentence_loss.float()

        return BertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
        )
|
||||
|
||||
|
||||
def remap_state_dict(state_dict, config):
    """Rename and regroup the keys of a checkpoint to match this implementation.

    The input uses names such as ``bert.encoder.layer.N...``,
    ``...LayerNorm.gamma`` and ``cls.predictions.bias``. Key renames happen in
    successive passes; per-layer query/key/value projections are concatenated
    into ``mixer.Wqkv`` (or ``mixer.Wq`` + ``mixer.Wkv`` for the last layer when
    ``config.last_layer_subset`` is set). When ``config.pad_vocab_size_multiple``
    is greater than 1, the word embeddings, decoder weight and decoder bias are
    padded up to ``config.vocab_size`` rows.

    Args:
        state_dict: mapping from parameter name to tensor. The attention loop
            pops entries from the intermediate dict, not from this argument.
        config: reads ``num_hidden_layers``, optional ``last_layer_subset`` and
            ``pad_vocab_size_multiple``, and ``vocab_size`` when padding.

    Returns:
        A new ``OrderedDict`` with the remapped keys.
    """
    def rename_keys(mapping, rules):
        # One pass: apply every (pattern, replacement) rule, in order, to each key.
        def rename(key):
            for pattern, replacement in rules:
                key = re.sub(pattern, replacement, key)
            return key
        return OrderedDict((rename(k), v) for k, v in mapping.items())

    passes_before_attention = [
        # LayerNorm
        [(r'LayerNorm.gamma$', 'LayerNorm.weight'),
         (r'LayerNorm.beta$', 'LayerNorm.bias')],
        # Layers
        [(r'^bert.encoder.layer.', 'bert.encoder.layers.')],
        # LayerNorm
        [(r'^bert.embeddings.LayerNorm.', 'bert.emb_ln.'),
         (r'^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)',
          r'bert.encoder.layers.\1.norm1.\2'),
         (r'^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)',
          r'bert.encoder.layers.\1.norm2.\2'),
         (r'^cls.predictions.transform.LayerNorm.(weight|bias)',
          r'cls.predictions.transform.layer_norm.\1')],
        # MLP
        [(r'^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)',
          r'bert.encoder.layers.\1.mlp.fc1.\2'),
         (r'^bert.encoder.layers.(\d+).output.dense.(weight|bias)',
          r'bert.encoder.layers.\1.mlp.fc2.\2')],
    ]
    for rules in passes_before_attention:
        state_dict = rename_keys(state_dict, rules)

    # Attention
    last_layer_subset = getattr(config, 'last_layer_subset', False)
    projections = ('query', 'key', 'value')
    for d in range(config.num_hidden_layers):
        source = f'bert.encoder.layers.{d}.attention.self'
        Wq, Wk, Wv = (state_dict.pop(f'{source}.{name}.weight') for name in projections)
        bq, bk, bv = (state_dict.pop(f'{source}.{name}.bias') for name in projections)
        is_cross_attn_layer = last_layer_subset and d == config.num_hidden_layers - 1
        if is_cross_attn_layer:
            # The last layer keeps the query projection separate from key/value.
            state_dict[f'bert.encoder.layers.{d}.mixer.Wq.weight'] = Wq
            state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.weight'] = torch.cat(
                [Wk, Wv], dim=0
            )
            state_dict[f'bert.encoder.layers.{d}.mixer.Wq.bias'] = bq
            state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.bias'] = torch.cat([bk, bv], dim=0)
        else:
            state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.weight'] = torch.cat(
                [Wq, Wk, Wv], dim=0
            )
            state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)

    state_dict = rename_keys(state_dict, [
        (r'^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)',
         r'bert.encoder.layers.\1.mixer.out_proj.\2'),
    ])
    state_dict = rename_keys(state_dict, [
        (r'^cls.predictions.bias', 'cls.predictions.decoder.bias'),
    ])

    # Word embedding
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    if pad_vocab_size_multiple > 1:
        for name in ('bert.embeddings.word_embeddings.weight',
                     'cls.predictions.decoder.weight'):
            weight = state_dict[name]
            state_dict[name] = F.pad(
                weight, (0, 0, 0, config.vocab_size - weight.shape[0])
            )
        # If the vocab was padded, we want to set the decoder bias for those padded indices to be
        # strongly negative (i.e. the decoder shouldn't predict those indices).
        # TD [2022-05-09]: I don't think it affects the MLPerf training.
        decoder_bias = state_dict['cls.predictions.decoder.bias']
        state_dict['cls.predictions.decoder.bias'] = F.pad(
            decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
        )

    return state_dict
|
||||
122
pkgs/xformers/_flash_attn/models/falcon.py
Normal file
122
pkgs/xformers/_flash_attn/models/falcon.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from transformers import GPT2Config, FalconConfig
|
||||
|
||||
|
||||
def remap_state_dict_hf_falcon(state_dict, config):
    """Convert a Falcon checkpoint's keys and QKV layout to this implementation's.

    Steps, in order: rename ``transformer.h.*`` to ``transformer.layers.*``;
    move and pad the word embeddings (and the LM head) to the padded vocab
    size; rename layer-norm, MLP and attention keys; finally regroup each fused
    QKV weight so that all query rows come first, then keys, then values.

    Args:
        state_dict: mapping from parameter name to tensor.
        config: reads ``vocab_size``, ``tie_word_embeddings``, ``n_head``,
            ``hidden_size``, ``n_layer`` and the optional
            ``pad_vocab_size_multiple`` (default 1) and ``n_head_kv`` (default 1).

    Returns:
        A new ``OrderedDict`` with remapped keys and tensors.

    Raises:
        AttributeError: if ``config`` has no ``tie_word_embeddings`` attribute.
        KeyError: if an expected weight is missing from ``state_dict``.
    """
    def key_mapping_layers(key):
        return re.sub(r'^transformer.h.', 'transformer.layers.', key)
    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r'^transformer.word_embeddings.', 'transformer.embeddings.word_embeddings.', key)
    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if config.tie_word_embeddings:
        state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
    else:
        output_embeddings = state_dict.pop('lm_head.weight')
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict['lm_head.weight'] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )
        # The LM head may have no bias (falcon_config_to_gpt2_config sets
        # lm_head_bias=False), so only pad the bias when the checkpoint has one
        # instead of raising KeyError.
        if 'lm_head.bias' in state_dict:
            output_embeddings_bias = state_dict.pop('lm_head.bias')
            state_dict['lm_head.bias'] = F.pad(
                output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
            )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r'^transformer.layers.(\d+).input_layernorm.',
                     r'transformer.layers.\1.norm1.', key)
        key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.',
                     r'transformer.layers.\1.norm2.', key)
        key = re.sub(r'^transformer.layers.(\d+).ln_attn.', r'transformer.layers.\1.norm1.', key)
        key = re.sub(r'^transformer.layers.(\d+).ln_mlp.', r'transformer.layers.\1.norm2.', key)
        return key
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.',
                     r'transformer.layers.\1.mlp.fc1.', key)
        key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.',
                     r'transformer.layers.\1.mlp.fc2.', key)
        return key
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    def key_mapping_attn(key):
        key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.',
                     r'transformer.layers.\1.mixer.Wqkv.', key)
        key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.',
                     r'transformer.layers.\1.mixer.out_proj.', key)
        return key
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", 1)
    headdim = config.hidden_size // n_head
    flatten_pattern = "group ratio headdim ... -> (group ratio headdim) ..."
    for layer_idx in range(config.n_layer):
        wqkv_key = f'transformer.layers.{layer_idx}.mixer.Wqkv.weight'
        # The weights are stored in a different layout compared to our implementation:
        # each group holds its query heads followed by one key head and one value head.
        Wqkv = rearrange(state_dict.pop(wqkv_key),
                         "(group ratio headdim) ... -> group ratio headdim ...",
                         ratio=n_head // n_head_kv + 2, headdim=headdim)
        Wq = rearrange(Wqkv[:, :-2], flatten_pattern)
        Wk = rearrange(Wqkv[:, [-2]], flatten_pattern)
        Wv = rearrange(Wqkv[:, [-1]], flatten_pattern)
        state_dict[wqkv_key] = torch.cat([Wq, Wk, Wv], dim=0)

    return state_dict
|
||||
|
||||
|
||||
def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config:
    """Translate a ``FalconConfig`` into the ``GPT2Config`` used by this implementation.

    Besides the standard GPT-2 fields, the returned config carries extra
    arguments (``parallel_block``, ``n_head_kv``, rotary settings and bias
    switches) that are not part of the original ``GPT2Config``.
    """
    # The 40b config uses "n_head_kv" instead of "num_kv_heads"
    if getattr(falcon_config, "multi_query", False):
        fallback_kv_heads = 1
    else:
        fallback_kv_heads = falcon_config.n_head
    n_head_kv = getattr(falcon_config, "n_head_kv", fallback_kv_heads)
    # HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config.
    # So we have to infer it from the number of heads in the key/value block
    parallel_block_tied_norm = n_head_kv == 1

    standard_fields = dict(
        vocab_size=falcon_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=falcon_config.hidden_size,
        n_layer=falcon_config.n_layer,
        n_head=falcon_config.n_head,
        n_inner=falcon_config.hidden_size * 4,
        activation_function="gelu",
        resid_pdrop=falcon_config.hidden_dropout,
        embd_pdrop=0.0,  # There doesn't seem to be any embedding dropout
        attn_pdrop=falcon_config.attention_dropout,
        layer_norm_epsilon=falcon_config.layer_norm_epsilon,
        initializer_range=falcon_config.initializer_range,
        bos_token_id=falcon_config.bos_token_id,
        eos_token_id=falcon_config.eos_token_id,
    )
    # These are new arguments not in the original GPT2Config
    extra_fields = dict(
        parallel_block=falcon_config.parallel_attn,
        n_head_kv=n_head_kv,
        parallel_block_tied_norm=parallel_block_tied_norm,
        rotary_emb_fraction=1.0,
        rotary_emb_interleaved=False,
        tie_word_embeddings=True,
        qkv_proj_bias=falcon_config.bias,
        out_proj_bias=falcon_config.bias,
        mlp_fc1_bias=falcon_config.bias,
        mlp_fc2_bias=falcon_config.bias,
        lm_head_bias=False,
    )
    return GPT2Config(**standard_fields, **extra_fields)
|
||||
740
pkgs/xformers/_flash_attn/models/gpt.py
Normal file
740
pkgs/xformers/_flash_attn/models/gpt.py
Normal file
@@ -0,0 +1,740 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from functools import partial
|
||||
|
||||
from collections import namedtuple, OrderedDict
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import GPT2Config
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.ops.activations import sqrelu_fwd
|
||||
from flash_attn.modules.mha import MHA, ParallelMHA
|
||||
from flash_attn.modules.mlp import Mlp, GatedMlp, ParallelMLP, FusedMLP, ParallelFusedMLP
|
||||
from flash_attn.modules.block import Block, ParallelBlock
|
||||
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
|
||||
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
|
||||
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
||||
from flash_attn.utils.generation import GenerationMixin
|
||||
from flash_attn.models.opt import remap_state_dict_hf_opt
|
||||
from flash_attn.models.gptj import remap_state_dict_hf_gptj
|
||||
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
|
||||
from flash_attn.models.falcon import remap_state_dict_hf_falcon
|
||||
|
||||
try:
|
||||
from flash_attn.ops.fused_dense import ColumnParallelLinear
|
||||
except ImportError:
|
||||
ColumnParallelLinear = None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.layer_norm import dropout_add_layer_norm
|
||||
except ImportError:
|
||||
dropout_add_layer_norm = None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
|
||||
except ImportError:
|
||||
dropout_add_layer_norm_parallel_residual = None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
|
||||
except ImportError:
|
||||
RMSNorm, dropout_add_rms_norm = None, None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
|
||||
except ImportError:
|
||||
dropout_add_rms_norm_parallel_residual = None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
|
||||
except ImportError:
|
||||
FusedDenseSqreluDense = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Return a partially-applied causal attention constructor for one GPT block.

    Args:
        config: model config. Reads ``hidden_size``, ``num_attention_heads``,
            ``scale_attn_weights``, ``scale_attn_by_inverse_layer_idx`` and
            ``attn_pdrop``, plus a number of optional attributes with defaults
            (rotary settings, bias switches, ``use_flash_attn``,
            ``fused_bias_fc``, ``attn_dwconv``, ``n_head_kv``, ``head_dim``).
        layer_idx: index of the layer. Required when
            ``scale_attn_by_inverse_layer_idx`` is set.
        process_group: when not ``None``, ``ParallelMHA`` is used instead of ``MHA``.
        device, dtype: forwarded to the attention constructor.

    Returns:
        ``functools.partial`` around ``MHA`` or ``ParallelMHA``.
    """
    tensor_parallel = process_group is not None

    head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
    softmax_scale = head_dim ** (-0.5) if config.scale_attn_weights else 1.0
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)

    dwconv = getattr(config, 'attn_dwconv', False)
    if dwconv:
        assert process_group is None, 'TensorParallel MHA does not support dwconv yet'
    qkv_proj_bias = getattr(config, 'qkv_proj_bias', True)
    out_proj_bias = getattr(config, 'out_proj_bias', True)
    # The rotary dimension is a fraction of the head dimension (0.0 disables it).
    rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
    rotary_emb_base = getattr(config, 'rotary_emb_base', 10000.0)
    rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', None)
    rotary_emb_interleaved = getattr(config, 'rotary_emb_interleaved', False)
    use_flash_attn = getattr(config, 'use_flash_attn', False)
    fused_bias_fc = getattr(config, 'fused_bias_fc', False)
    if not fused_bias_fc:
        assert process_group is None, 'TensorParallel MHA requires fused_bias_fc'

    # Serial and tensor-parallel attention take different extra keyword arguments.
    if tensor_parallel:
        mha_cls = ParallelMHA
        mode_kwargs = {'process_group': process_group,
                       'sequence_parallel': getattr(config, 'sequence_parallel', True)}
    else:
        mha_cls = MHA
        mode_kwargs = {'fused_bias_fc': fused_bias_fc, 'dwconv': dwconv}

    num_heads_kv = getattr(config, "n_head_kv", None)
    return partial(
        mha_cls,
        num_heads=config.num_attention_heads,
        num_heads_kv=num_heads_kv,
        qkv_proj_bias=qkv_proj_bias,
        out_proj_bias=out_proj_bias,
        dropout=config.attn_pdrop,
        softmax_scale=softmax_scale,
        causal=True,
        layer_idx=layer_idx,
        rotary_emb_dim=rotary_emb_dim,
        rotary_emb_base=rotary_emb_base,
        rotary_emb_scale_base=rotary_emb_scale_base,
        rotary_emb_interleaved=rotary_emb_interleaved,
        use_flash_attn=use_flash_attn,
        **mode_kwargs,
        device=device,
        dtype=dtype,
    )
|
||||
|
||||
|
||||
def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Return a ``functools.partial`` that constructs the MLP selected by ``config``.

    The choice depends on ``config.fused_mlp``, ``config.fused_dense_sqrelu_dense`` (both read
    with a default of False) and ``config.activation_function``:

    * no fused flag, activation in ('glu', 'swiglu', 'geglu'): ``GatedMlp``
    * no fused flag, any other accepted activation: ``Mlp``, or ``ParallelMLP`` when
      ``process_group`` is given
    * ``fused_mlp``: ``FusedMLP``, or ``ParallelFusedMLP`` when ``process_group`` is given
    * ``fused_dense_sqrelu_dense``: ``FusedDenseSqreluDense``

    ``layer_idx`` is only required when ``config.mlp_checkpoint_lvl`` is a sequence (one
    checkpoint level per layer). ``device`` and ``dtype`` are forwarded to the constructor.

    Raises:
        AssertionError: if the activation / fused-flag combination is not accepted.
        ImportError: if ``fused_mlp`` is requested but ``FusedMLP`` is None.
    """
    factory_kwargs = {'device': device, 'dtype': dtype}
    bias1 = getattr(config, 'mlp_fc1_bias', True)
    bias2 = getattr(config, 'mlp_fc2_bias', True)
    tanh_gelu_names = ('gelu_new', 'gelu_fast', 'gelu_approx')
    gated_names = ('glu', 'swiglu', 'geglu')

    def tensor_parallel_kwargs():
        # Tensor-parallel arguments are only passed when a process group was supplied.
        if process_group is None:
            return {}
        return {'process_group': process_group,
                'sequence_parallel': getattr(config, 'sequence_parallel', True)}

    # Validate the requested combination before building anything.
    fused_mlp = getattr(config, 'fused_mlp', False)
    if fused_mlp:
        assert config.activation_function in tanh_gelu_names + ('relu', 'sqrelu')
    fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
    if fused_dense_sqrelu_dense:
        assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
                                                        'supports approximate activation_function sqrelu')
    assert not (fused_dense_sqrelu_dense and fused_mlp)

    if fused_mlp or fused_dense_sqrelu_dense:
        checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(checkpoint_lvl, Sequence):
            assert layer_idx is not None
            checkpoint_lvl = checkpoint_lvl[layer_idx]
        if fused_mlp:
            if FusedMLP is None:
                raise ImportError('fused_dense is not installed')
            # All tanh-approximated GELU spellings are passed on as 'gelu_approx'.
            if config.activation_function in tanh_gelu_names:
                activation = 'gelu_approx'
            else:
                activation = config.activation_function
            base_cls = FusedMLP if process_group is None else ParallelFusedMLP
            return partial(base_cls, hidden_features=config.n_inner, activation=activation,
                           checkpoint_lvl=checkpoint_lvl,
                           bias1=bias1, bias2=bias2,
                           **tensor_parallel_kwargs(), **factory_kwargs)
        if fused_dense_sqrelu_dense:
            assert FusedDenseSqreluDense is not None
            return partial(FusedDenseSqreluDense, hidden_features=config.n_inner,
                           checkpoint_lvl=checkpoint_lvl, **factory_kwargs)
        raise RuntimeError('MLP type not supported')

    # Non-fused path.
    assert config.activation_function in (('gelu',) + tanh_gelu_names
                                          + ('relu', 'sqrelu') + gated_names)
    if config.activation_function in gated_names:
        if config.activation_function == 'glu':
            activation = F.sigmoid
        elif config.activation_function == 'swiglu':
            activation = F.silu
        else:
            activation = F.gelu
        return partial(GatedMlp, hidden_features=config.n_inner, activation=activation,
                       bias1=bias1, bias2=bias2, **factory_kwargs)

    if config.activation_function == 'relu':
        activation = partial(F.relu, inplace=True)
    elif config.activation_function == 'sqrelu':
        activation = sqrelu_fwd
    else:
        if config.activation_function in tanh_gelu_names:
            approximate = 'tanh'
        else:
            approximate = 'none'
        activation = partial(F.gelu, approximate=approximate)
    base_cls = Mlp if process_group is None else ParallelMLP
    return partial(base_cls, hidden_features=config.n_inner, activation=activation,
                   bias1=bias1, bias2=bias2,
                   **tensor_parallel_kwargs(), **factory_kwargs)
|
||||
|
||||
|
||||
def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Build one transformer block (``Block`` or ``ParallelBlock``) from ``config``.

    The mixer and MLP constructors come from ``create_mixer_cls`` / ``create_mlp_cls``; the norm
    is ``RMSNorm`` when ``config.rms_norm`` is truthy and ``nn.LayerNorm`` otherwise.
    ``ParallelBlock`` is used when ``config.parallel_block`` is truthy and requires ``prenorm``.
    The returned block has its ``layer_idx`` attribute set to ``layer_idx``.
    """
    factory_kwargs = {'device': device, 'dtype': dtype}
    sequence_parallel = getattr(config, 'sequence_parallel', True)
    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)

    if getattr(config, 'rms_norm', False):
        norm_base = RMSNorm
    else:
        norm_base = nn.LayerNorm
    norm_cls = partial(norm_base, eps=config.layer_norm_epsilon, **factory_kwargs)

    # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
    residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
    # Layer 0 uses the embedding dropout probability for its first dropout; every other layer
    # (and the layer_idx=None case) uses the residual dropout probability.
    if layer_idx is None or layer_idx > 0:
        resid_dropout1 = config.resid_pdrop
    else:
        resid_dropout1 = config.embd_pdrop
    prenorm = getattr(config, 'prenorm', True)
    tensor_parallel = process_group is not None

    if getattr(config, 'parallel_block', False):
        assert prenorm
        block = ParallelBlock(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            norm_cls=norm_cls,
            resid_dropout1=resid_dropout1,
            resid_dropout2=config.resid_pdrop,
            tied_norm=getattr(config, 'parallel_block_tied_norm', False),
            fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and tensor_parallel,
            mark_shared_params=tensor_parallel,
        )
    else:
        block = Block(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            norm_cls=norm_cls,
            prenorm=prenorm,
            resid_dropout1=resid_dropout1,
            resid_dropout2=config.resid_pdrop,
            fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and tensor_parallel,
            mark_shared_params=tensor_parallel,
        )
    block.layer_idx = layer_idx
    return block
|
||||
|
||||
|
||||
class GPTPreTrainedModel(nn.Module):
    """Abstract base class that stores the config and provides a simple interface
    for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        """Store ``config``; raise ValueError unless it is a ``GPT2Config`` instance."""
        super().__init__()
        if not isinstance(config, GPT2Config):
            cls_name = self.__class__.__name__
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(cls_name, cls_name)
            )
        self.config = config

    @classmethod
    def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dtype=None,
                        world_size=1, rank=0, **kwargs):
        """
        Instantiate the model and load pretrained weights identified by ``model_name``.

        The checkpoint is fetched with ``state_dict_from_pretrained`` and remapped to this
        module's parameter names by the remap function matching the ``model_name`` prefix
        ('gpt2', 'facebook/opt', 'EleutherAI/gpt-j-', 'EleutherAI/gpt-neox-', 'tiiuae/falcon-').
        When ``world_size > 1`` the state dict is sharded for ``rank`` before loading.

        Raises NotImplementedError for any other ``model_name`` prefix.
        """
        # Instantiate model.
        model = cls(config, *args, device=device, dtype=dtype, **kwargs)
        # Load state_dict in cpu because we already initialized the model in GPU, and we don't
        # want extra stuff taking up more GPU memory
        state_dict = state_dict_from_pretrained(model_name, device='cpu', dtype=dtype)

        # Pick the remapping function from the model-name prefix (first match wins).
        if model_name.startswith('gpt2'):
            remap_fn = remap_state_dict_hf_gpt2
        elif model_name.startswith('facebook/opt'):
            remap_fn = remap_state_dict_hf_opt
        elif model_name.startswith('EleutherAI/gpt-j-'):
            remap_fn = remap_state_dict_hf_gptj
        elif model_name.startswith('EleutherAI/gpt-neox-'):
            remap_fn = remap_state_dict_hf_gpt_neox
        elif model_name.startswith('tiiuae/falcon-'):
            remap_fn = remap_state_dict_hf_falcon
        else:
            raise NotImplementedError(f'Model {model_name} not supported')
        state_dict = remap_fn(state_dict, config)

        if world_size > 1:
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
        load_return = model.load_state_dict(state_dict, strict=strict)
        logger.info(load_return)
        return model
|
||||
|
||||
|
||||
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
|
||||
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, std=initializer_range)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
nn.init.normal_(module.weight, std=initializer_range)
|
||||
|
||||
if rescale_prenorm_residual:
|
||||
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
||||
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
||||
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
||||
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
||||
#
|
||||
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
||||
for name, p in module.named_parameters():
|
||||
if name in ["out_proj.weight", "fc2.weight"]:
|
||||
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
||||
nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
|
||||
|
||||
|
||||
class GPTModel(GPTPreTrainedModel):
    """GPT backbone: embeddings, a stack of blocks from ``create_block`` and, in the prenorm
    configuration, a final dropout + norm (``drop_f`` / ``ln_f``).
    """

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        """Build the model from ``config``.

        When ``process_group`` is given, the tensor-parallel embedding is used and the final
        norm parameters are tagged for synchronisation. ``device`` / ``dtype`` are forwarded to
        every submodule constructor.
        """
        super().__init__(config)
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, 'sequence_parallel', True)
        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
                                              'relu', 'sqrelu', 'glu', 'swiglu', 'geglu']
        # Round the vocabulary size up to a multiple of pad_vocab_size_multiple.
        pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
        vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                      * pad_vocab_size_multiple)
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        self.residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
        # These 2 options are for OPT-350m
        self.prenorm = getattr(config, 'prenorm', True)
        use_rms_norm = getattr(config, 'rms_norm', False)
        word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
        # For GPT-J, GPT-NeoX
        self.parallel_block = getattr(config, 'parallel_block', False)

        if process_group is None:
            self.embeddings = GPT2Embeddings(
                config.hidden_size, vocab_size, config.max_position_embeddings,
                word_embed_proj_dim=word_embed_proj_dim, **factory_kwargs
            )
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size, vocab_size, config.max_position_embeddings,
                process_group=process_group, sequence_parallel=self.sequence_parallel,
                **factory_kwargs
            )

        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities are changed.
        # This is for performance reason: we can fuse dropout + add + layer_norm.
        self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
                                                  **factory_kwargs)
                                     for i in range(config.num_hidden_layers)])

        self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
        if self.fused_dropout_add_ln:
            # The fused kernels are optional imports; fail early if the needed one is missing.
            if ((not self.parallel_block and dropout_add_layer_norm is None)
                    or (self.parallel_block and dropout_add_layer_norm_parallel_residual is None)):
                raise ImportError('dropout_layer_norm is not installed')
        if self.prenorm:
            self.drop_f = nn.Dropout(config.resid_pdrop)
            norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
            self.ln_f = norm_cls(config.hidden_size, eps=config.layer_norm_epsilon,
                                 **factory_kwargs)
        # ln_f only exists in the prenorm configuration, so only tag its parameters there;
        # without this guard a post-norm model with a process group would fail on self.ln_f.
        if process_group is not None and self.prenorm:
            for p in self.ln_f.parameters():
                # Mark the norm parameters as "shared_params" so that we sync their values at init.
                p._shared_params = True
                # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                if self.sequence_parallel:
                    p._sequence_parallel = True

        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                           initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        """Synchronise shared parameters across ``process_group`` (no-op without one)."""
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Return a dict mapping each layer index to that layer's inference cache."""
        return {i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
                for i, layer in enumerate(self.layers)}

    def forward(self, input_ids, position_ids=None, inference_params=None):
        """Run embeddings, all blocks and (for prenorm) the final dropout/add/norm.

        ``inference_params``, when not None, is forwarded to every block through
        ``mixer_kwargs``. Returns the final hidden states.
        """
        uses_sequence_parallel = self.process_group is not None and self.sequence_parallel
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split on it easily, in case of small batch size.
        # Only the attention layers need to know the seqlen.
        embedding_kwargs = {'combine_batch_seqlen_dim': True} if uses_sequence_parallel else {}
        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
        if self.parallel_block:
            hidden_states2 = None
        residual = None
        mixer_kwargs = {'seqlen': input_ids.shape[1]} if uses_sequence_parallel else {}
        if inference_params is not None:
            mixer_kwargs['inference_params'] = inference_params
        for layer in self.layers:
            if self.prenorm:
                if not self.parallel_block:
                    hidden_states, residual = layer(hidden_states, residual,
                                                    mixer_kwargs=mixer_kwargs)
                else:
                    hidden_states, hidden_states2, residual = layer(
                        hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
                    )
            else:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_f(hidden_states)
                if not self.parallel_block:
                    residual = (dropped + residual) if residual is not None else dropped
                else:
                    dropped2 = self.drop_f(hidden_states2)
                    residual = ((residual + dropped + dropped2)
                                if residual is not None else dropped + dropped2)
                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                if not self.parallel_block:
                    fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.ln_f, RMSNorm)
                                         else dropout_add_layer_norm)
                    hidden_states = fused_add_norm_fn(
                        hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
                        self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
                        residual_in_fp32=self.residual_in_fp32
                    )
                else:
                    fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
                                         if isinstance(self.ln_f, RMSNorm)
                                         else dropout_add_layer_norm_parallel_residual)
                    hidden_states, _ = fused_add_norm_fn(
                        hidden_states, hidden_states2, residual, self.ln_f.weight, self.ln_f.bias,
                        None, None, self.drop_f.p if self.training else 0.0, self.ln_f.eps,
                        prenorm=False, residual_in_fp32=self.residual_in_fp32
                    )
        return hidden_states
|
||||
|
||||
|
||||
class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
    """``GPTModel`` backbone followed by an optional output projection and a language-model
    head whose weight can be tied to the word embeddings.
    """

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        """Build the backbone and the LM head.

        With a ``process_group`` the head is a ``ColumnParallelLinear``; otherwise it is an
        ``nn.Linear``. Raises ImportError if a process group is given but
        ``ColumnParallelLinear`` is None.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(config)
        self.process_group = process_group
        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
        self.tie_word_embeddings = getattr(config, 'tie_word_embeddings', True)
        lm_head_bias = getattr(config, 'lm_head_bias', False)
        # Round the vocabulary size up to a multiple of pad_vocab_size_multiple.
        pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
        vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                      * pad_vocab_size_multiple)
        # This option is for OPT-350m
        word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
        embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
        if word_embed_proj_dim is not None:
            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
        else:
            self.project_out = None
        if process_group is None:
            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
        else:
            if ColumnParallelLinear is None:
                raise ImportError('fused_dense_lib is not installed')
            self.lm_head = ColumnParallelLinear(
                embed_dim, vocab_size, process_group, bias=lm_head_bias,
                sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
            )
        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                           initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        """Share the word-embedding weight with the LM head (if enabled) and synchronise
        shared parameters across ``process_group`` (if any).
        """
        if self.tie_word_embeddings:
            self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate to the backbone's ``allocate_inference_cache``."""
        return self.transformer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype,
                                                         **kwargs)

    def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
        """
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        last_token_only: whether to return the logit for the last token only,
            of shape (batch_size, vocab_size)

        Returns a ``CausalLMOutput`` namedtuple with a single ``logits`` field.
        """
        hidden_states = self.transformer(input_ids, position_ids=position_ids,
                                         inference_params=inference_params)
        if last_token_only:
            hidden_states = hidden_states[:, -1]
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        lm_logits = self.lm_head(hidden_states)
        # During inference, we want the full logit for sampling.
        # ColumnParallelLinear is None when the optional fused_dense import failed, and
        # isinstance() with None as its second argument raises TypeError, so check that first.
        head_is_column_parallel = (ColumnParallelLinear is not None
                                   and isinstance(self.lm_head, ColumnParallelLinear))
        if head_is_column_parallel and inference_params is not None:
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            lm_logits = rearrange(lm_logits, '(n b) ... d -> b ... (n d)', b=hidden_states.shape[0])
        CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
        return CausalLMOutput(logits=lm_logits)

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict``, first converting checkpoints saved with the older block layout.

        A checkpoint is treated as old-layout when it contains 'transformer.ln_0.weight'. In
        that case the norm entries are shifted in place (``state_dict`` is mutated) before
        delegating to the parent ``load_state_dict``.
        """
        # Remapping from our checkpoints that used a different ordering of layers in the block
        # Previous: Attn / MLP -> Dropout -> Add -> LN
        # Current: Dropout -> Add -> LN -> Attn / MLP
        if 'transformer.ln_0.weight' in state_dict:
            n_layers = len(self.transformer.layers)
            # The last block's norm2 becomes the final norm.
            ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
            ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
            state_dict['transformer.ln_f.weight'] = ln_weight
            state_dict['transformer.ln_f.bias'] = ln_bias
            # Walk backwards so each entry is moved before it gets overwritten.
            for idx in reversed(range(n_layers)):
                ln_weight = state_dict.pop(f'transformer.layers.{idx}.norm1.weight')
                ln_bias = state_dict.pop(f'transformer.layers.{idx}.norm1.bias')
                state_dict[f'transformer.layers.{idx}.norm2.weight'] = ln_weight
                state_dict[f'transformer.layers.{idx}.norm2.bias'] = ln_bias
                if idx > 0:
                    ln_weight = state_dict.pop(f'transformer.layers.{idx - 1}.norm2.weight')
                    ln_bias = state_dict.pop(f'transformer.layers.{idx - 1}.norm2.bias')
                    state_dict[f'transformer.layers.{idx}.norm1.weight'] = ln_weight
                    state_dict[f'transformer.layers.{idx}.norm1.bias'] = ln_bias
            # The old standalone ln_0 becomes the first block's norm1.
            ln_weight = state_dict.pop('transformer.ln_0.weight')
            ln_bias = state_dict.pop('transformer.ln_0.bias')
            state_dict['transformer.layers.0.norm1.weight'] = ln_weight
            state_dict['transformer.layers.0.norm1.bias'] = ln_bias
        return super().load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
||||
def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Slice a full GPT ``state_dict`` down to the shard owned by ``rank`` under tensor
    parallelism with ``world_size`` ranks.

    ``state_dict`` is modified in place and also returned. Keys that are absent are skipped.
    On ranks other than 0 the ``out_proj.bias`` and ``fc2.bias`` entries are removed.
    """
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    def rank_slice(total):
        # Contiguous chunk of a dimension of size `total` that belongs to this rank.
        per_rank = total // world_size
        return slice(rank * per_rank, (rank + 1) * per_rank)

    def shard_first_dim(state_dict, key):
        if key not in state_dict:
            return
        tensor = state_dict[key]
        state_dict[key] = tensor[rank_slice(tensor.shape[0])]

    def shard_last_dim(state_dict, key):
        if key not in state_dict:
            return
        tensor = state_dict[key]
        state_dict[key] = tensor[..., rank_slice(tensor.shape[-1])]

    def shard_qkv_headdim(state_dict, key):
        if key not in state_dict:
            return
        n_head = config.n_head
        n_head_kv = getattr(config, 'n_head_kv', n_head)
        assert n_head % world_size == 0 and n_head_kv % world_size == 0
        if n_head_kv == n_head:
            # Same number of Q and KV heads: split into the 3 parts and shard each one.
            qkv = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
            state_dict[key] = rearrange(qkv[:, rank_slice(qkv.shape[1])],
                                        'three d ... -> (three d) ...')
            return
        # Fewer KV heads than Q heads: shard the Q, K and V head groups separately.
        q_per_rank = n_head // world_size
        kv_per_rank = n_head_kv // world_size
        heads = rearrange(state_dict[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
                          nheadqkv=n_head + 2 * n_head_kv)
        q_start = rank * q_per_rank
        k_start = n_head + rank * kv_per_rank
        v_start = n_head + n_head_kv + rank * kv_per_rank
        state_dict[key] = rearrange(torch.cat([
            heads[q_start:q_start + q_per_rank],
            heads[k_start:k_start + kv_per_rank],
            heads[v_start:v_start + kv_per_rank],
        ], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...")

    # The helpers skip missing keys themselves, so no extra membership checks are needed here.
    shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
    shard_first_dim(state_dict, 'lm_head.weight')
    shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight')
    for i in range(config.num_hidden_layers):
        prefix = f'transformer.layers.{i}'
        shard_qkv_headdim(state_dict, f'{prefix}.mixer.Wqkv.weight')
        shard_qkv_headdim(state_dict, f'{prefix}.mixer.Wqkv.bias')
        shard_last_dim(state_dict, f'{prefix}.mixer.out_proj.weight')
        if rank != 0:
            state_dict.pop(f'{prefix}.mixer.out_proj.bias', None)
        shard_first_dim(state_dict, f'{prefix}.mlp.fc1.weight')
        shard_first_dim(state_dict, f'{prefix}.mlp.fc1.bias')
        shard_last_dim(state_dict, f'{prefix}.mlp.fc2.weight')
        if rank != 0:
            state_dict.pop(f'{prefix}.mlp.fc2.bias', None)
    return state_dict
|
||||
|
||||
|
||||
def combine_state_dicts_tp(state_dicts, config):
    """Merge the per-rank state_dicts of a tensor-parallel GPT model into a single state_dict.

    This is the intended inverse of ``shard_state_dict_tp``: ``state_dicts`` holds one
    state_dict per rank (in rank order) and the sharded tensors are concatenated back
    together. Entries that are not explicitly combined are taken from ``state_dicts[0]``.

    The input dicts are not modified; a new dict (a shallow copy of ``state_dicts[0]`` with the
    combined tensors replaced) is returned.
    """
    world_size = len(state_dicts)
    # Raises IndexError on an empty list, before any arithmetic on world_size == 0.
    first_shard = state_dicts[0]
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
    # A shard whose first dim equals vocab_size // world_size is treated as sharded on dim 0.
    def combine_word_embeddings(state_dicts, state_dict, key):
        dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1
        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_dim(state_dicts, state_dict, key, dim=-1):
        if key in state_dict:
            state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_qkv_headdim(state_dicts, state_dict, key):
        n_head = config.n_head
        n_head_kv = getattr(config, 'n_head_kv', n_head)
        assert n_head % world_size == 0 and n_head_kv % world_size == 0
        n_head_per_rank = n_head // world_size
        n_head_kv_per_rank = n_head_kv // world_size
        if key in state_dict:
            if n_head_kv == n_head:
                # Same number of Q and KV heads: concatenate each of the 3 parts across ranks.
                xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3)
                      for s in state_dicts]
                state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')
            else:
                # Each shard holds its Q heads, then its K heads, then its V heads; gather each
                # group across ranks before stacking Q, K, V back together.
                xs = [rearrange(s[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
                                nheadqkv=n_head + 2 * n_head_kv) for s in state_dicts]
                state_dict[key] = rearrange(torch.cat([
                    torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
                    torch.cat([x[n_head_per_rank:n_head_per_rank + n_head_kv_per_rank]
                               for x in xs], dim=0),
                    torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0),
                ], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...")

    def combine_gated_mlp(state_dicts, state_dict, key):
        if key in state_dict:
            xs = [rearrange(s[key], '(two d) ... -> two d ...', two=2) for s in state_dicts]
            state_dict[key] = rearrange(torch.cat(xs, dim=1), 'two d ... -> (two d) ...')

    state_dict = first_shard.copy()  # don't modify state_dict[0] inplace
    combine_word_embeddings(state_dicts, state_dict, 'transformer.embeddings.word_embeddings.weight')
    if 'lm_head.weight' in state_dict:
        combine_word_embeddings(state_dicts, state_dict, 'lm_head.weight')
    if 'transformer.embeddings.position_embeddings.weight' in state_dict:
        combine_dim(state_dicts, state_dict, 'transformer.embeddings.position_embeddings.weight', -1)
    mlp_combine_fn = (combine_gated_mlp if config.activation_function in ['glu', 'swiglu', 'geglu']
                      else partial(combine_dim, dim=0))
    for i in range(config.num_hidden_layers):
        combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
        combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
        combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.out_proj.weight', -1)
        mlp_combine_fn(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
        combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.bias', 0)
        combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc2.weight', -1)
    return state_dict
|
||||
|
||||
|
||||
def remap_state_dict_hf_gpt2(state_dict, config):
    """Rename (and where needed transpose) a Hugging Face GPT-2 ``state_dict`` to this
    module's parameter names.

    * ``wte`` / ``wpe`` become ``transformer.embeddings.*``; the word embedding is zero-padded
      along dim 0 up to a multiple of ``config.pad_vocab_size_multiple`` and is also stored
      under ``lm_head.weight``.
    * ``ln_f`` / ``h.N.ln_1`` / ``h.N.ln_2`` become ``transformer.ln_f`` /
      ``transformer.layers.N.norm1`` / ``norm2``.
    * The ``c_fc`` / ``c_proj`` / ``c_attn`` weights are transposed when moved to
      ``fc1`` / ``fc2`` / ``Wqkv`` / ``out_proj``.
    * The ``h.N.attn.bias`` entry is dropped when present.

    The input mapping is not modified; a new ``OrderedDict`` is returned.
    """
    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r'^wpe\.', 'transformer.embeddings.position_embeddings.', key)
    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop('wte.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r'^ln_f\.(weight|bias)', r'transformer.ln_f.\1', key)
        key = re.sub(r'^h\.(\d+)\.ln_(1|2)\.(weight|bias)', r'transformer.layers.\1.norm\2.\3', key)
        return key
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f'h.{d}.mlp.c_fc.weight')
        state_dict[f'transformer.layers.{d}.mlp.fc1.weight'] = W1.t()
        W2 = state_dict.pop(f'h.{d}.mlp.c_proj.weight')
        state_dict[f'transformer.layers.{d}.mlp.fc2.weight'] = W2.t()

    def key_mapping_mlp(key):
        key = re.sub(r'^h\.(\d+)\.mlp\.c_fc\.bias', r'transformer.layers.\1.mlp.fc1.bias', key)
        key = re.sub(r'^h\.(\d+)\.mlp\.c_proj\.bias', r'transformer.layers.\1.mlp.fc2.bias', key)
        return key
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        # We don't store this bias. Use a default so checkpoints that do not contain the
        # entry do not fail with KeyError.
        state_dict.pop(f'h.{d}.attn.bias', None)
        Wqkv = state_dict.pop(f'h.{d}.attn.c_attn.weight')
        state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = Wqkv.t()
        Wout = state_dict.pop(f'h.{d}.attn.c_proj.weight')
        state_dict[f'transformer.layers.{d}.mixer.out_proj.weight'] = Wout.t()

    def key_mapping_attn(key):
        key = re.sub(r'^h\.(\d+)\.attn\.c_attn\.bias', r'transformer.layers.\1.mixer.Wqkv.bias', key)
        key = re.sub(r'^h\.(\d+)\.attn\.c_proj\.bias',
                     r'transformer.layers.\1.mixer.out_proj.bias', key)
        return key
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
|
||||
|
||||
|
||||
def remap_state_dict_megatron(state_dict, config):
    """Rename a Megatron-style GPT state dict to this package's GPT key layout.

    Keys are rewritten with regex rules, the word-embedding matrix is padded
    along its first axis up to a multiple of ``config.pad_vocab_size_multiple``
    (1 when the attribute is absent), ``lm_head.weight`` is set to the padded
    embedding tensor, and every fused QKV weight/bias is reordered from a
    ``(nheads 3 headdim)`` row layout to ``(3 nheads headdim)``.

    Args:
        state_dict: mapping from parameter name to tensor.
        config: object providing ``hidden_size``, ``num_attention_heads``,
            ``num_hidden_layers`` and optionally ``pad_vocab_size_multiple``.

    Returns:
        An ``OrderedDict`` holding the renamed (and partly re-laid-out) tensors.
    """
    def rename_keys(tensors, rules):
        # Apply every (pattern, replacement) rule, in order, to each key while
        # keeping the entries in their original iteration order.
        renamed = OrderedDict()
        for name, tensor in tensors.items():
            for pattern, replacement in rules:
                name = re.sub(pattern, replacement, name)
            renamed[name] = tensor
        return renamed

    # Top-level prefixes first, then the position embedding name.
    state_dict = rename_keys(state_dict, [
        (r'^language_model.encoder.', 'transformer.'),
        (r'^language_model.', 'transformer.'),
        (r'^wpe.', 'transformer.embeddings.position_embeddings.'),
    ])

    # Word embedding
    embedding_key = 'transformer.embeddings.word_embeddings.weight'
    word_embeddings = state_dict.pop('transformer.embedding.word_embeddings.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    num_rows = word_embeddings.shape[0]
    padded_rows = math.ceil(num_rows / multiple) * multiple
    state_dict[embedding_key] = F.pad(word_embeddings, (0, 0, 0, padded_rows - num_rows))
    # The output projection shares the (padded) embedding tensor.
    state_dict['lm_head.weight'] = state_dict[embedding_key]

    state_dict = rename_keys(state_dict, [
        # LayerNorm
        (r'^transformer.final_layernorm.(weight|bias)', r'transformer.ln_f.\1'),
        (r'^transformer.layers.(\d+).input_layernorm.(weight|bias)',
         r'transformer.layers.\1.norm1.\2'),
        (r'^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)',
         r'transformer.layers.\1.norm2.\2'),
        # MLP
        (r'^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)',
         r'transformer.layers.\1.mlp.fc1.\2'),
        (r'^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)',
         r'transformer.layers.\1.mlp.fc2.\2'),
        # Attention
        (r'^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq',
         r'transformer.layers.\1.mixer.rotary_emb.inv_freq'),
        (r'^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)',
         r'transformer.layers.\1.mixer.Wqkv.\2'),
        (r'^transformer.layers.(\d+).self_attention.dense.(weight|bias)',
         r'transformer.layers.\1.mixer.out_proj.\2'),
    ])

    # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
    headdim = config.hidden_size // config.num_attention_heads
    for layer in range(config.num_hidden_layers):
        weight_key = f'transformer.layers.{layer}.mixer.Wqkv.weight'
        bias_key = f'transformer.layers.{layer}.mixer.Wqkv.bias'
        state_dict[weight_key] = rearrange(
            state_dict.pop(weight_key),
            '(nheads three headdim) ... -> (three nheads headdim) ...',
            three=3, headdim=headdim,
        )
        state_dict[bias_key] = rearrange(
            state_dict.pop(bias_key),
            '(nheads three headdim) -> (three nheads headdim)',
            three=3, headdim=headdim,
        )

    return state_dict
|
||||
107
pkgs/xformers/_flash_attn/models/gpt_neox.py
Normal file
107
pkgs/xformers/_flash_attn/models/gpt_neox.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from transformers import GPT2Config, GPTNeoXConfig
|
||||
|
||||
|
||||
def remap_state_dict_hf_gpt_neox(state_dict, config):
    """Rename a Hugging Face GPT-NeoX state dict to this package's GPT key layout.

    Keys are rewritten with regex rules, the word embedding (and, when the
    embeddings are not tied, the output embedding) is padded along its first
    axis up to a multiple of ``config.pad_vocab_size_multiple`` (1 when the
    attribute is absent), and every fused QKV weight/bias is reordered from a
    ``(nheads 3 headdim)`` row layout to ``(3 nheads headdim)``.

    Args:
        state_dict: mapping from parameter name to tensor.
        config: object providing ``vocab_size``, ``tie_word_embeddings``,
            ``n_layer``, ``hidden_size``, ``num_attention_heads`` and
            optionally ``pad_vocab_size_multiple``.

    Returns:
        An ``OrderedDict`` holding the renamed (and partly re-laid-out) tensors.

    Raises:
        KeyError: if an expected embedding or ``query_key_value`` entry is missing.
    """
    def rename_keys(tensors, rules):
        # Apply every (pattern, replacement) rule, in order, to each key while
        # keeping the entries in their original iteration order.
        renamed = OrderedDict()
        for name, tensor in tensors.items():
            for pattern, replacement in rules:
                name = re.sub(pattern, replacement, name)
            renamed[name] = tensor
        return renamed

    state_dict = rename_keys(state_dict, [
        (r'^gpt_neox.', 'transformer.'),
        # Word embedding
        (r'^transformer.embed_in.', 'transformer.embeddings.word_embeddings.'),
    ])

    embedding_key = 'transformer.embeddings.word_embeddings.weight'
    word_embeddings = state_dict.pop(embedding_key)
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict[embedding_key] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if config.tie_word_embeddings:
        state_dict['lm_head.weight'] = state_dict[embedding_key]
    else:
        output_embeddings = state_dict.pop('embed_out.weight')
        # The output embedding gets the same vocabulary padding as the input one.
        state_dict['lm_head.weight'] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    state_dict = rename_keys(state_dict, [
        # LayerNorm
        (r'^transformer.final_layer_norm.', r'transformer.ln_f.'),
        (r'^transformer.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.'),
        (r'^transformer.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.'),
        # MLP
        (r'^transformer.layers.(\d+).mlp.dense_h_to_4h.', r'transformer.layers.\1.mlp.fc1.'),
        (r'^transformer.layers.(\d+).mlp.dense_4h_to_h.', r'transformer.layers.\1.mlp.fc2.'),
    ])

    # Attention
    # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
    headdim = config.hidden_size // config.num_attention_heads  # same for every layer
    for l in range(config.n_layer):
        attn_prefix = f'transformer.layers.{l}.attention'
        # We don't store these biases. They are mask buffers rather than learned
        # parameters, and a checkpoint may not contain them at all, so a missing
        # entry is tolerated instead of raising KeyError.
        state_dict.pop(f'{attn_prefix}.bias', None)
        state_dict.pop(f'{attn_prefix}.masked_bias', None)
        Wqkv = state_dict.pop(f'{attn_prefix}.query_key_value.weight')
        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = rearrange(
            Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
            three=3, headdim=headdim
        )
        bqkv = state_dict.pop(f'{attn_prefix}.query_key_value.bias')
        state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = rearrange(
            bqkv, '(nheads three headdim) -> (three nheads headdim)',
            three=3, headdim=headdim
        )

    state_dict = rename_keys(state_dict, [
        (r'^transformer.layers.(\d+).attention.dense.',
         r'transformer.layers.\1.mixer.out_proj.'),
        (r'^transformer.layers.(\d+).attention.rotary_emb.',
         r'transformer.layers.\1.mixer.rotary_emb.'),
    ])

    return state_dict
|
||||
|
||||
|
||||
def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GPT2Config:
    """Build a ``GPT2Config`` carrying the settings of a ``GPTNeoXConfig``.

    Only a rotary base of 10000 is accepted; all dropout rates are set to 0.
    """
    assert gpt_neox_config.rotary_emb_base == 10000
    source = gpt_neox_config
    gpt2_kwargs = dict(
        vocab_size=source.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=source.hidden_size,
        n_layer=source.num_hidden_layers,
        n_head=source.num_attention_heads,
        n_inner=source.intermediate_size,
        activation_function=source.hidden_act,
        resid_pdrop=0.0,  # No dropout
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=source.layer_norm_eps,
        initializer_range=source.initializer_range,
        bos_token_id=source.bos_token_id,
        eos_token_id=source.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=True,
        parallel_block=source.use_parallel_residual,
        parallel_block_tied_norm=False,
        rotary_emb_fraction=source.rotary_pct,
        tie_word_embeddings=source.tie_word_embeddings,
    )
    return GPT2Config(**gpt2_kwargs)
|
||||
98
pkgs/xformers/_flash_attn/models/gptj.py
Normal file
98
pkgs/xformers/_flash_attn/models/gptj.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import GPT2Config, GPTJConfig
|
||||
|
||||
|
||||
def remap_state_dict_hf_gptj(state_dict, config):
    """Rename a Hugging Face GPT-J state dict to this package's GPT key layout.

    Keys are rewritten with regex rules, the word embedding (and, when the
    embeddings are not tied, the ``lm_head`` weight and bias) is padded along
    the vocabulary axis up to a multiple of ``config.pad_vocab_size_multiple``
    (1 when the attribute is absent), and the separate q/k/v projection weights
    of each layer are concatenated into a single ``mixer.Wqkv.weight``.

    Args:
        state_dict: mapping from parameter name to tensor.
        config: object providing ``vocab_size``, ``tie_word_embeddings``,
            ``n_layer`` and optionally ``pad_vocab_size_multiple``.

    Returns:
        An ``OrderedDict`` holding the renamed (and partly merged) tensors.

    Raises:
        KeyError: if an expected embedding, ``lm_head`` or q/k/v entry is missing.
    """
    def rename_keys(tensors, rules):
        # Apply every (pattern, replacement) rule, in order, to each key while
        # keeping the entries in their original iteration order.
        renamed = OrderedDict()
        for name, tensor in tensors.items():
            for pattern, replacement in rules:
                name = re.sub(pattern, replacement, name)
            renamed[name] = tensor
        return renamed

    state_dict = rename_keys(state_dict, [
        (r'^transformer.h.', 'transformer.layers.'),
        # Word embedding
        (r'^transformer.wte.', 'transformer.embeddings.word_embeddings.'),
    ])

    embedding_key = 'transformer.embeddings.word_embeddings.weight'
    word_embeddings = state_dict.pop(embedding_key)
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict[embedding_key] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if config.tie_word_embeddings:
        state_dict['lm_head.weight'] = state_dict[embedding_key]
    else:
        # The output head gets the same vocabulary padding as the embedding.
        output_embeddings = state_dict.pop('lm_head.weight')
        state_dict['lm_head.weight'] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )
        output_embeddings_bias = state_dict.pop('lm_head.bias')
        state_dict['lm_head.bias'] = F.pad(
            output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
        )

    state_dict = rename_keys(state_dict, [
        # LayerNorm
        (r'^transformer.layers.(\d+).ln_1.', r'transformer.layers.\1.norm1.'),
        # MLP
        (r'^transformer.layers.(\d+).mlp.fc_in.', r'transformer.layers.\1.mlp.fc1.'),
        (r'^transformer.layers.(\d+).mlp.fc_out.', r'transformer.layers.\1.mlp.fc2.'),
    ])

    # Attention
    for l in range(config.n_layer):
        attn_prefix = f'transformer.layers.{l}.attn'
        Wq = state_dict.pop(f'{attn_prefix}.q_proj.weight')
        Wk = state_dict.pop(f'{attn_prefix}.k_proj.weight')
        Wv = state_dict.pop(f'{attn_prefix}.v_proj.weight')
        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these biases. They are mask buffers rather than learned
        # parameters, and a checkpoint may not contain them at all, so a missing
        # entry is tolerated instead of raising KeyError.
        state_dict.pop(f'{attn_prefix}.bias', None)
        state_dict.pop(f'{attn_prefix}.masked_bias', None)

    state_dict = rename_keys(state_dict, [
        (r'^transformer.layers.(\d+).attn.out_proj.',
         r'transformer.layers.\1.mixer.out_proj.'),
    ])

    return state_dict
|
||||
|
||||
|
||||
def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
    """Build a ``GPT2Config`` carrying the settings of a ``GPTJConfig``.

    The rotary fraction is ``rotary_dim`` divided by the per-head dimension
    (``n_embd // n_head``).
    """
    source = gptj_config
    headdim = source.n_embd // source.n_head
    gpt2_kwargs = dict(
        vocab_size=source.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=source.n_embd,
        n_layer=source.n_layer,
        n_head=source.n_head,
        n_inner=source.n_inner,
        activation_function=source.activation_function,
        resid_pdrop=source.resid_pdrop,
        embd_pdrop=source.embd_pdrop,
        attn_pdrop=source.attn_pdrop,
        layer_norm_epsilon=source.layer_norm_epsilon,
        initializer_range=source.initializer_range,
        bos_token_id=source.bos_token_id,
        eos_token_id=source.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=True,
        parallel_block=True,
        parallel_block_tied_norm=True,
        rotary_emb_fraction=source.rotary_dim / headdim,
        rotary_emb_interleaved=True,
        tie_word_embeddings=False,
        qkv_proj_bias=False,
        out_proj_bias=False,
        lm_head_bias=True,
    )
    return GPT2Config(**gpt2_kwargs)
|
||||
124
pkgs/xformers/_flash_attn/models/llama.py
Normal file
124
pkgs/xformers/_flash_attn/models/llama.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
import math
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import GPT2Config, LlamaConfig
|
||||
|
||||
|
||||
def remap_state_dict_meta_llama(state_dict, config):
    """Rename a Meta LLaMA checkpoint state dict to this package's GPT key layout.

    Every key except those starting with ``output.`` is prefixed with
    ``transformer.``; embeddings are padded along their first axis up to a
    multiple of ``config.pad_vocab_size_multiple`` (1 when absent); ``w3`` and
    ``w1`` of each feed-forward block are concatenated into ``mlp.fc1.weight``;
    and the q/k/v projection weights are concatenated into ``mixer.Wqkv.weight``.

    Args:
        state_dict: mapping from parameter name to tensor.
        config: object providing ``n_layer``, ``tie_word_embeddings`` and
            optionally ``pad_vocab_size_multiple``.

    Returns:
        An ``OrderedDict`` holding the renamed (and partly merged) tensors.
    """
    def rename_keys(tensors, rules):
        # Apply every (pattern, replacement) rule, in order, to each key while
        # keeping the entries in their original iteration order.
        renamed = OrderedDict()
        for name, tensor in tensors.items():
            for pattern, replacement in rules:
                name = re.sub(pattern, replacement, name)
            renamed[name] = tensor
        return renamed

    def pad_rows(matrix, multiple):
        # Pad the first axis up to the next multiple of `multiple`.
        rows = matrix.shape[0]
        target_rows = math.ceil(rows / multiple) * multiple
        return F.pad(matrix, (0, 0, 0, target_rows - rows))

    state_dict = OrderedDict(
        (name if name.startswith('output.') else f'transformer.{name}', tensor)
        for name, tensor in state_dict.items()
    )

    # Word embedding
    state_dict = rename_keys(state_dict, [
        (r'^transformer.tok_embeddings.', 'transformer.embeddings.word_embeddings.'),
    ])
    embedding_key = 'transformer.embeddings.word_embeddings.weight'
    word_embeddings = state_dict.pop(embedding_key)
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    state_dict[embedding_key] = pad_rows(word_embeddings, multiple)
    if config.tie_word_embeddings:
        state_dict['lm_head.weight'] = state_dict[embedding_key]
    else:
        # The padded size is recomputed from the output embedding itself since
        # LLaMa shards the word embeddings and output embeddings differently.
        state_dict['lm_head.weight'] = pad_rows(state_dict.pop('output.weight'), multiple)

    # LayerNorm
    state_dict = rename_keys(state_dict, [
        (r'^transformer.norm.', r'transformer.ln_f.'),
        (r'^transformer.layers.(\d+).attention_norm.', r'transformer.layers.\1.norm1.'),
        (r'^transformer.layers.(\d+).ffn_norm.', r'transformer.layers.\1.norm2.'),
    ])

    # MLP
    for layer in range(config.n_layer):
        ffn_prefix = f'transformer.layers.{layer}.feed_forward'
        w1 = state_dict.pop(f'{ffn_prefix}.w1.weight')
        w3 = state_dict.pop(f'{ffn_prefix}.w3.weight')
        # Our ordering is different
        state_dict[f'transformer.layers.{layer}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0)
    state_dict = rename_keys(state_dict, [
        (r'^transformer.layers.(\d+).feed_forward.w2.',
         r'transformer.layers.\1.mlp.fc2.'),
    ])

    # Attention
    for layer in range(config.n_layer):
        attn_prefix = f'transformer.layers.{layer}.attention'
        projections = [state_dict.pop(f'{attn_prefix}.{name}.weight')
                       for name in ('wq', 'wk', 'wv')]
        state_dict[f'transformer.layers.{layer}.mixer.Wqkv.weight'] = torch.cat(projections, dim=0)
        # We don't store these
        state_dict.pop(f'{attn_prefix}.inner_attention.rope.freqs', None)
    state_dict = rename_keys(state_dict, [
        (r'^transformer.layers.(\d+).attention.wo.',
         r'transformer.layers.\1.mixer.out_proj.'),
    ])

    return state_dict
|
||||
|
||||
|
||||
def config_from_checkpoint(checkpoint_path: str, model_name: str) -> LlamaConfig:
    """Build a ``LlamaConfig`` from ``<checkpoint_path>/<model_name>/params.json``.

    The ``dim``, ``n_heads``, ``n_layers`` and ``norm_eps`` entries of the JSON
    file are used; ``intermediate_size`` is passed as ``None``.
    """
    params_file = Path(checkpoint_path) / model_name / 'params.json'
    with open(params_file) as f:
        params = json.load(f)
    return LlamaConfig(
        hidden_size=params['dim'],
        intermediate_size=None,
        num_attention_heads=params['n_heads'],
        num_hidden_layers=params['n_layers'],
        rms_norm_eps=params['norm_eps'],
    )
|
||||
|
||||
|
||||
def state_dicts_from_checkpoint(checkpoint_path: str, model_name: str) -> list:
    """Load every ``consolidated.*.pth`` shard of a checkpoint onto the CPU.

    Args:
        checkpoint_path: directory that contains the model directory.
        model_name: name of the model directory inside ``checkpoint_path``.

    Returns:
        A list with one loaded object per shard file, ordered by sorted file
        path. The list is empty when no shard file matches.
    """
    shard_dir = Path(checkpoint_path) / model_name
    # Need to sort, otherwise we mess up the ordering and the weights are wrong
    shard_paths = sorted(shard_dir.glob('consolidated.*.pth'))
    # NOTE(security): torch.load unpickles the file contents; only use this on
    # checkpoints from a trusted source.
    return [torch.load(path, map_location='cpu') for path in shard_paths]
|
||||
|
||||
|
||||
def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
    """Build a ``GPT2Config`` carrying the settings of a ``LlamaConfig``.

    The activation is fixed to ``'swiglu'``, all dropout rates are 0, and all
    projection/MLP biases are disabled.
    """
    source = llama_config
    gpt2_kwargs = dict(
        vocab_size=source.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=source.hidden_size,
        n_layer=source.num_hidden_layers,
        n_head=source.num_attention_heads,
        n_inner=source.intermediate_size,
        activation_function='swiglu',  # Hardcode since HF calls it 'silu'
        # Llama doesn't have dropout, possibly because only the inference code was released
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=source.rms_norm_eps,
        initializer_range=source.initializer_range,
        bos_token_id=source.bos_token_id,
        eos_token_id=source.eos_token_id,
        # These are new arguments not in the original GPT2Config
        pad_token_id=source.pad_token_id,  # Unclear whether this has any effect
        rms_norm=True,
        rotary_emb_fraction=1.0,
        rotary_emb_interleaved=True,
        tie_word_embeddings=False,
        qkv_proj_bias=False,
        out_proj_bias=False,
        mlp_fc1_bias=False,
        mlp_fc2_bias=False,
    )
    return GPT2Config(**gpt2_kwargs)
|
||||
102
pkgs/xformers/_flash_attn/models/opt.py
Normal file
102
pkgs/xformers/_flash_attn/models/opt.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import GPT2Config, OPTConfig
|
||||
|
||||
|
||||
def remap_state_dict_hf_opt(state_dict, config):
    """Rename a Hugging Face OPT state dict to this package's GPT key layout.

    Keys are rewritten with regex rules, the first two rows of the position
    embedding are dropped, the word embedding is padded along its first axis up
    to a multiple of ``config.pad_vocab_size_multiple`` (1 when absent) and
    reused as ``lm_head.weight``, and the q/k/v projection weights and biases of
    each layer are concatenated into ``mixer.Wqkv.weight`` / ``mixer.Wqkv.bias``.

    Args:
        state_dict: mapping from parameter name to tensor.
        config: object providing ``vocab_size``, ``n_layer`` and optionally
            ``pad_vocab_size_multiple``.

    Returns:
        An ``OrderedDict`` holding the renamed (and partly merged) tensors.
    """
    def rename_keys(tensors, rules):
        # Apply every (pattern, replacement) rule, in order, to each key while
        # keeping the entries in their original iteration order.
        renamed = OrderedDict()
        for name, tensor in tensors.items():
            for pattern, replacement in rules:
                name = re.sub(pattern, replacement, name)
            renamed[name] = tensor
        return renamed

    state_dict = rename_keys(state_dict, [
        (r'^model.decoder.', 'transformer.'),
        # The OPT-350m model uses '^decoder' instead of '^model.decoder'
        (r'^decoder.', 'transformer.'),
        # Word embedding and position embedding
        (r'^transformer.embed_tokens.', 'transformer.embeddings.word_embeddings.'),
        # The OPT-350m model has project_in and project_out
        (r'^transformer.project_in.', 'transformer.embeddings.project_in.'),
        (r'^transformer.project_out.', 'project_out.'),
        (r'^transformer.embed_positions.', 'transformer.embeddings.position_embeddings.'),
    ])

    # OPT uses the first 2 indices of pos_emb for padding tokens
    position_key = 'transformer.embeddings.position_embeddings.weight'
    state_dict[position_key] = state_dict.pop(position_key)[2:]

    embedding_key = 'transformer.embeddings.word_embeddings.weight'
    word_embeddings = state_dict.pop(embedding_key)
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    padded_vocab = math.ceil(config.vocab_size / multiple) * multiple
    state_dict[embedding_key] = F.pad(
        word_embeddings, (0, 0, 0, padded_vocab - word_embeddings.shape[0])
    )
    state_dict['lm_head.weight'] = state_dict[embedding_key]

    state_dict = rename_keys(state_dict, [
        # LayerNorm
        (r'^transformer.final_layer_norm.', r'transformer.ln_f.'),
        # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
        (r'^transformer.layer_norm.', r'transformer.ln_f.'),
        (r'^transformer.layers.(\d+).self_attn_layer_norm.',
         r'transformer.layers.\1.norm1.'),
        (r'^transformer.layers.(\d+).final_layer_norm.',
         r'transformer.layers.\1.norm2.'),
        # MLP
        (r'^transformer.layers.(\d+).fc(1|2).',
         r'transformer.layers.\1.mlp.fc\2.'),
    ])

    # Attention
    for layer in range(config.n_layer):
        attn_prefix = f'transformer.layers.{layer}.self_attn'
        weights = [state_dict.pop(f'{attn_prefix}.{name}_proj.weight') for name in 'qkv']
        biases = [state_dict.pop(f'{attn_prefix}.{name}_proj.bias') for name in 'qkv']
        state_dict[f'transformer.layers.{layer}.mixer.Wqkv.weight'] = torch.cat(weights, dim=0)
        state_dict[f'transformer.layers.{layer}.mixer.Wqkv.bias'] = torch.cat(biases, dim=0)
    state_dict = rename_keys(state_dict, [
        (r'^transformer.layers.(\d+).self_attn.out_proj.',
         r'transformer.layers.\1.mixer.out_proj.'),
    ])

    return state_dict
|
||||
|
||||
|
||||
def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
    """Build a ``GPT2Config`` carrying the settings of an ``OPTConfig``.

    Only configs with ``layerdrop == 0.0`` and elementwise-affine layer norm are
    accepted. ``word_embed_proj_dim`` is forwarded as ``None`` when it equals
    the hidden size.
    """
    assert opt_config.layerdrop == 0.0
    assert opt_config.layer_norm_elementwise_affine
    if opt_config.word_embed_proj_dim == opt_config.hidden_size:
        word_embed_proj_dim = None
    else:
        word_embed_proj_dim = opt_config.word_embed_proj_dim
    gpt2_kwargs = dict(
        vocab_size=opt_config.vocab_size,
        n_positions=opt_config.max_position_embeddings,
        n_embd=opt_config.hidden_size,
        n_layer=opt_config.num_hidden_layers,
        n_head=opt_config.num_attention_heads,
        n_inner=opt_config.ffn_dim,
        activation_function=opt_config.activation_function,
        resid_pdrop=opt_config.dropout,
        # HF's implementation of OPT doesn't seem to have embedding dropout
        embd_pdrop=opt_config.dropout,
        attn_pdrop=opt_config.attention_dropout,
        initializer_range=opt_config.init_std,
        bos_token_id=opt_config.bos_token_id,
        eos_token_id=opt_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=opt_config.do_layer_norm_before,
        word_embed_proj_dim=word_embed_proj_dim,
    )
    return GPT2Config(**gpt2_kwargs)
|
||||
304
pkgs/xformers/_flash_attn/models/vit.py
Normal file
304
pkgs/xformers/_flash_attn/models/vit.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# Copyright (c) 2022, Tri Dao.
|
||||
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
import math
|
||||
import re
|
||||
from functools import partial
|
||||
from copy import deepcopy
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.init import trunc_normal_
|
||||
|
||||
from torchvision.ops import StochasticDepth
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from timm.models.helpers import named_apply
|
||||
from flash_attn.layers.patch_embed import PatchEmbed
|
||||
|
||||
from flash_attn.modules.mha import MHA
|
||||
from flash_attn.modules.mlp import Mlp, FusedMLP
|
||||
from flash_attn.modules.block import Block
|
||||
|
||||
try:
|
||||
from flash_attn.ops.layer_norm import dropout_add_layer_norm
|
||||
except ImportError:
|
||||
dropout_add_layer_norm = None
|
||||
|
||||
|
||||
def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc,
                     cross_attn=False):
    """Return a ``partial`` of ``MHA`` pre-bound with the given attention options."""
    mha_options = dict(
        num_heads=num_heads,
        cross_attn=cross_attn,
        bias=qkv_bias,
        dropout=attn_drop,
        fused_bias_fc=fused_bias_fc,
        use_flash_attn=use_flash_attn,
    )
    return partial(MHA, **mha_options)
|
||||
|
||||
|
||||
def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
    """Return a ``partial`` of ``FusedMLP`` or ``Mlp`` with the hidden width bound.

    The hidden width is ``int(embed_dim * mlp_ratio)``. ``act_layer`` is only
    instantiated for the non-fused ``Mlp`` path.
    """
    inner_dim = int(embed_dim * mlp_ratio)
    if fused_mlp:
        return partial(FusedMLP, hidden_features=inner_dim)
    return partial(Mlp, hidden_features=inner_dim, activation=act_layer())
|
||||
|
||||
|
||||
def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
                 drop_path1, drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc,
                 fused_mlp, fused_dropout_add_ln, layer_idx=None, n_layer=None,
                 last_layer_subset=False):
    """Build one pre-norm ``Block`` from the mixer and MLP factories.

    The mixer is created with ``cross_attn`` enabled only when
    ``last_layer_subset`` is set and ``layer_idx`` is the last layer index
    (``n_layer - 1``).
    """
    use_cross_attn = last_layer_subset and layer_idx == n_layer - 1
    mixer_cls = create_mixer_cls(
        num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc,
        cross_attn=use_cross_attn,
    )
    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
    # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
    return Block(
        embed_dim,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_layer,
        prenorm=True,
        resid_dropout1=drop_rate,
        resid_dropout2=drop_rate,
        drop_path1=drop_path1,
        drop_path2=drop_path2,
        fused_dropout_add_ln=fused_dropout_add_ln,
        residual_in_fp32=True,
    )
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
""" Vision Transformer
|
||||
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
|
||||
- https://arxiv.org/abs/2010.11929
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
global_pool='token',
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
init_values=None,
|
||||
class_token=True,
|
||||
no_embed_class=False,
|
||||
pre_norm=False,
|
||||
fc_norm=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
weight_init='',
|
||||
embed_layer=PatchEmbed,
|
||||
norm_layer=None,
|
||||
act_layer=None,
|
||||
use_flash_attn=False,
|
||||
fused_bias_fc=False,
|
||||
fused_mlp=False,
|
||||
fused_dropout_add_ln=False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
img_size (int, tuple): input image size
|
||||
patch_size (int, tuple): patch size
|
||||
in_chans (int): number of input channels
|
||||
num_classes (int): number of classes for classification head
|
||||
global_pool (str): type of global pooling for final sequence (default: 'token')
|
||||
embed_dim (int): embedding dimension
|
||||
depth (int): depth of transformer
|
||||
num_heads (int): number of attention heads
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
||||
qkv_bias (bool): enable bias for qkv if True
|
||||
init_values: (float): layer-scale init values
|
||||
class_token (bool): use class token
|
||||
fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
|
||||
drop_rate (float): dropout rate
|
||||
attn_drop_rate (float): attention dropout rate
|
||||
drop_path_rate (float): stochastic depth rate
|
||||
weight_init (str): weight init scheme
|
||||
embed_layer (nn.Module): patch embedding layer
|
||||
norm_layer: (nn.Module): normalization layer
|
||||
act_layer: (nn.Module): MLP activation layer
|
||||
"""
|
||||
super().__init__()
|
||||
assert global_pool == 'token', 'Only support pooling with CLS token'
|
||||
assert class_token
|
||||
assert init_values is None, 'LayerScale is not supported yet'
|
||||
assert weight_init == ''
|
||||
assert fc_norm is None
|
||||
# pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
|
||||
assert not pre_norm
|
||||
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
|
||||
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
||||
act_layer = act_layer or nn.GELU
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.global_pool = global_pool
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
self.num_prefix_tokens = 1 if class_token else 0
|
||||
self.no_embed_class = no_embed_class
|
||||
|
||||
patch_embed_extra_kwargs = ({'fused_bias_fc': fused_bias_fc} if embed_layer is PatchEmbed
|
||||
else {})
|
||||
self.patch_embed = embed_layer(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
|
||||
**patch_embed_extra_kwargs
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
||||
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
|
||||
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
|
||||
# We change the order of dropout, residual and layer norm:
|
||||
# Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
|
||||
# Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
|
||||
# the main branch (output of MLP). The model definition is unchanged, but the mapping of the
|
||||
# nn.Dropout probabilities are changed.
|
||||
# This is for performance reason: we can fuse dropout + add + layer_norm.
|
||||
self.blocks = nn.ModuleList([create_block(
|
||||
embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
|
||||
drop_path1=dpr[i-1] if i > 0 else 0., drop_path2=dpr[i],
|
||||
norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
|
||||
fused_bias_fc=fused_bias_fc, fused_mlp=fused_mlp,
|
||||
fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth,
|
||||
last_layer_subset=(global_pool == 'token')
|
||||
) for i in range(depth)])
|
||||
|
||||
self.dropout = nn.Dropout(p=drop_rate)
|
||||
self.drop_path = StochasticDepth(p=dpr[-1], mode='row')
|
||||
self.norm = norm_layer(embed_dim)
|
||||
|
||||
self.fused_dropout_add_ln = fused_dropout_add_ln
|
||||
if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
|
||||
raise ImportError('dropout_add_layer_norm is not installed')
|
||||
|
||||
# Classifier Head
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
self.init_weights(weight_init)
|
||||
|
||||
def init_weights(self, mode=''):
|
||||
assert mode == ''
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
if self.cls_token is not None:
|
||||
nn.init.normal_(self.cls_token, std=1e-6)
|
||||
named_apply(init_weights_vit_timm, self)
|
||||
|
||||
def _init_weights(self, m):
|
||||
# this fn left here for compat with downstream users
|
||||
init_weights_vit_timm(m)
|
||||
|
||||
@torch.jit.ignore
def no_weight_decay(self):
    """Return the set of parameter names this model lists under ``no_weight_decay``."""
    exempt_names = {'pos_embed', 'cls_token'}
    return exempt_names
|
||||
|
||||
def _pos_embed(self, x):
|
||||
if self.no_embed_class:
|
||||
# deit-3, updated JAX (big vision)
|
||||
# position embedding does not overlap with class token, add then concat
|
||||
x = x + self.pos_embed
|
||||
if self.cls_token is not None:
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
else:
|
||||
# original timm, JAX, and deit vit impl
|
||||
# pos_embed has entry for class token, concat then add
|
||||
if self.cls_token is not None:
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + self.pos_embed
|
||||
return x
|
||||
|
||||
def forward_features(self, x, all_tokens=True):
    """
    Embed ``x``, run it through every block and apply the final norm.

    If all_tokens==False and self.global_pool == 'token', we only return the features for the
    cls token.
    """
    hidden_states = self._pos_embed(self.patch_embed(x))
    residual = None
    cls_only = self.global_pool == 'token' and not all_tokens

    leading_blocks = self.blocks[:-1] if cls_only else self.blocks
    for block in leading_blocks:
        hidden_states, residual = block(hidden_states, residual)
    if cls_only:
        # For the last layer, we only want the 1st token of the output. So we do cross-attention
        # where the query is the 1st token and the key/value is the whole sequence.
        hidden_states, residual = self.blocks[-1](hidden_states, residual,
                                                  mixer_subset=slice(0, 1))

    if not self.fused_dropout_add_ln:
        residual = self.drop_path(self.dropout(hidden_states)) + residual
        return self.norm(residual.to(dtype=self.norm.weight.dtype))

    # Fused path: stochastic depth is passed to the kernel as a per-row scale,
    # which is only needed while training with a non-zero drop probability.
    rowscale = None
    if self.drop_path.p != 0 and self.training:
        rowscale = self.drop_path(torch.ones(
            hidden_states.shape[:-1], device=hidden_states.device,
            dtype=hidden_states.dtype)
        )
    # Set prenorm=False here since we don't need the residual
    return dropout_add_layer_norm(
        hidden_states, residual, self.norm.weight, self.norm.bias,
        self.dropout.p if self.training else 0.0, self.norm.eps, rowscale=rowscale,
        prenorm=False, residual_in_fp32=True
    )
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
    """Pool the token features and, unless ``pre_logits`` is set, apply ``self.head``.

    With ``global_pool == 'avg'`` the tokens after the first
    ``num_prefix_tokens`` are averaged over dim 1; any other truthy value
    selects token 0; a falsy value skips pooling altogether.
    """
    pool = self.global_pool
    if pool:
        if pool == 'avg':
            x = x[:, self.num_prefix_tokens:].mean(dim=1)
        else:
            x = x[:, 0]
    if pre_logits:
        return x
    return self.head(x)
|
||||
|
||||
def forward(self, x):
    """Run ``forward_features`` with ``all_tokens=False`` and feed the result to ``forward_head``."""
    features = self.forward_features(x, all_tokens=False)
    return self.forward_head(features)
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
    """Load a checkpoint after converting its keys/weights to this model's layout.

    The conversions are applied to a renamed copy, so the mapping passed in by
    the caller is never modified:

    * ``blocks.N.attn.qkv.*`` / ``blocks.N.attn.proj.*`` keys are renamed to
      ``blocks.N.mixer.Wqkv.*`` / ``blocks.N.mixer.out_proj.*``;
    * a 4-D ``patch_embed.proj.weight`` (Conv2d layout) is flattened to 2-D
      (Linear layout);
    * when the last block's mixer uses cross attention, its packed ``Wqkv``
      weight (and bias, if the checkpoint has one) is split into ``Wq`` (the
      first ``embed_dim`` rows) and ``Wkv`` (the remaining rows).

    Keys that are absent are left for the parent ``load_state_dict`` to
    report according to ``strict``.

    Args:
        state_dict: mapping from parameter name to tensor.
        strict: forwarded unchanged to the parent implementation.

    Returns:
        Whatever the parent ``load_state_dict`` returns.
    """
    def key_mapping_attn(key):
        # Dots are escaped so that only literal '.' separators match.
        key = re.sub(r'^blocks\.(\d+)\.attn\.qkv\.', r'blocks.\1.mixer.Wqkv.', key)
        key = re.sub(r'^blocks\.(\d+)\.attn\.proj\.', r'blocks.\1.mixer.out_proj.', key)
        return key

    # Work on a copy from here on; every assignment below targets the copy.
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    patch_embed_weight = state_dict.get('patch_embed.proj.weight')
    if patch_embed_weight is not None and patch_embed_weight.dim() == 4:
        # convert from Conv2d to Linear
        state_dict['patch_embed.proj.weight'] = rearrange(patch_embed_weight,
                                                          'o c h w -> o (c h w)')

    # Convert from Wqkv to Wq and Wkv for cross attention (last layer)
    last_mixer = f'blocks.{len(self.blocks) - 1}.mixer.'
    if self.blocks[-1].mixer.cross_attn and last_mixer + 'Wqkv.weight' in state_dict:
        Wqkv = state_dict.pop(last_mixer + 'Wqkv.weight')
        state_dict[last_mixer + 'Wq.weight'] = Wqkv[:self.embed_dim]
        state_dict[last_mixer + 'Wkv.weight'] = Wqkv[self.embed_dim:]
        # The bias is optional in the checkpoint (e.g. a model built without qkv bias).
        bqkv = state_dict.pop(last_mixer + 'Wqkv.bias', None)
        if bqkv is not None:
            state_dict[last_mixer + 'Wq.bias'] = bqkv[:self.embed_dim]
            state_dict[last_mixer + 'Wkv.bias'] = bqkv[self.embed_dim:]
    return super().load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
||||
def init_weights_vit_timm(module: nn.Module, name: str = ''):
    """ ViT weight initialization, original timm impl (for reproducibility).

    ``nn.Linear`` modules get a truncated-normal weight (std 0.02) and a zero
    bias; any other module exposing ``init_weights`` has that method called.
    ``name`` is accepted but not used here.
    """
    if not isinstance(module, nn.Linear):
        if hasattr(module, 'init_weights'):
            module.init_weights()
        return
    trunc_normal_(module.weight, std=.02)
    bias = module.bias
    if bias is not None:
        nn.init.zeros_(bias)
|
||||
|
||||
|
||||
def vit_base_patch16_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).

    Builds a ``VisionTransformer`` with patch size 16, embedding dim 768,
    depth 12 and 12 heads; extra keyword arguments are forwarded to it.
    Pretrained weights cannot be requested here: ``pretrained`` must be falsy.
    """
    assert not pretrained
    base_config = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    return VisionTransformer(**base_config)
|
||||
0
pkgs/xformers/_flash_attn/modules/__init__.py
Normal file
0
pkgs/xformers/_flash_attn/modules/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
324
pkgs/xformers/_flash_attn/modules/block.py
Normal file
324
pkgs/xformers/_flash_attn/modules/block.py
Normal file
@@ -0,0 +1,324 @@
|
||||
# Copyright (c) 2022, Tri Dao.
|
||||
|
||||
from typing import Optional
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from torchvision.ops import StochasticDepth
|
||||
|
||||
from flash_attn.modules.mha import MHA
|
||||
from flash_attn.modules.mlp import Mlp
|
||||
|
||||
try:
|
||||
from flash_attn.ops.layer_norm import dropout_add_layer_norm
|
||||
except ImportError:
|
||||
dropout_add_layer_norm = None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
|
||||
except ImportError:
|
||||
dropout_add_layer_norm_parallel_residual = None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
|
||||
except ImportError:
|
||||
RMSNorm, dropout_add_rms_norm = None, None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
|
||||
except ImportError:
|
||||
dropout_add_rms_norm_parallel_residual = None
|
||||
|
||||
|
||||
class Block(nn.Module):
    """Transformer block: a mixer sub-layer followed by an optional MLP sub-layer.

    Each sub-layer is paired with dropout, stochastic depth, a residual add and
    a norm; ``prenorm`` selects how they are ordered (see ``__init__``).
    """

    def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                 dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0.,
                 drop_path1=0., drop_path2=0., fused_dropout_add_ln=False, return_residual=False,
                 residual_in_fp32=False, sequence_parallel=False, mark_shared_params=False):
        """
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
        This is for performance reason: for post-norm architecture, returning the input allows us
        to fuse the backward of nn.Linear with the residual connection.

        mixer_cls / mlp_cls are called with ``dim`` to build the sub-layers and default to
        ``MHA`` with ``dim // 64`` heads and ``Mlp`` with ``4 * dim`` hidden features.
        If the MLP is an ``nn.Identity``, no second dropout / drop-path / norm is created.
        """
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True'
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode='row')
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode='row')
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert dropout_add_layer_norm is not None, 'dropout_layer_norm is not installed'
            assert dropout_add_rms_norm is not None, 'dropout_layer_norm is not installed'
            assert (isinstance(self.norm1, (nn.LayerNorm, RMSNorm))
                    and isinstance(self.dropout1, nn.Dropout))

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, 'norm2'):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, 'norm2'):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the mixer and return its result."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def _drop_path_rowscale(self, drop_path, like):
        """Return the stochastic-depth row scale for the fused kernel, or None when inactive.

        The scale is ``drop_path`` applied to a tensor of ones shaped like ``like``
        without its last dimension; it is skipped outside training or when the
        drop probability is 0.
        """
        if drop_path.p == 0 or not self.training:
            return None
        return drop_path(torch.ones(like.shape[:-1], device=like.device, dtype=like.dtype))

    def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
                mixer_subset=None, mixer_kwargs=None):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            mixer_kwargs: optional extra keyword arguments for the mixer. The mapping is
                never modified; a copy is made before ``mixer_subset`` is added to it.

        Returns:
            ``(hidden_states, residual)`` when ``prenorm`` is set, otherwise ``hidden_states``.
        """
        fused_add_norm_fn = (dropout_add_rms_norm if RMSNorm and isinstance(self.norm1, RMSNorm)
                             else dropout_add_layer_norm)
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                rowscale1 = self._drop_path_rowscale(self.drop_path1, hidden_states)
                hidden_states, residual = fused_add_norm_fn(
                    hidden_states, residual, self.norm1.weight, self.norm1.bias,
                    self.dropout1.p if self.training else 0.0, self.norm1.eps,
                    rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
                )
            # Copy the caller's kwargs: a dict shared across blocks or calls must not
            # silently pick up 'mixer_subset' from this block.
            mixer_kwargs = {} if mixer_kwargs is None else dict(mixer_kwargs)
            if mixer_subset is not None:
                mixer_kwargs['mixer_subset'] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                # Keep the residual aligned with the subset of positions the mixer produced.
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    rowscale2 = self._drop_path_rowscale(self.drop_path2, hidden_states)
                    hidden_states, residual = fused_add_norm_fn(
                        hidden_states, residual, self.norm2.weight, self.norm2.bias,
                        self.dropout2.p if self.training else 0.0, self.norm2.eps,
                        rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out))
                                            + hidden_states).to(dtype=self.norm1.weight.dtype))
            else:
                rowscale1 = self._drop_path_rowscale(self.drop_path1, mixer_out)
                hidden_states = fused_add_norm_fn(
                    mixer_out, hidden_states, self.norm1.weight, self.norm1.bias,
                    self.dropout1.p if self.training else 0.0, self.norm1.eps,
                    rowscale=rowscale1, prenorm=False
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out))
                                                + hidden_states).to(dtype=self.norm2.weight.dtype))
                else:
                    rowscale2 = self._drop_path_rowscale(self.drop_path2, mlp_out)
                    hidden_states = fused_add_norm_fn(
                        mlp_out, hidden_states, self.norm2.weight, self.norm2.bias,
                        self.dropout2.p if self.training else 0.0, self.norm2.eps,
                        rowscale=rowscale2, prenorm=False
                    )
            return hidden_states
|
||||
|
||||
|
||||
class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                 dropout_cls=nn.Dropout, resid_dropout1=0., resid_dropout2=0.,
                 tied_norm=False, fused_dropout_add_ln=False, residual_in_fp32=False,
                 sequence_parallel=False, mark_shared_params=False):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
        the hidden_states (output1 of the MHA / MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        mixer_cls / mlp_cls are called with ``dim`` to build the sub-layers and default to
        ``MHA`` with ``dim // 64`` heads and ``Mlp`` with ``4 * dim`` hidden features.
        With tied_norm=True no second norm is created and the MLP reads the output of norm1.
        """
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.residual_in_fp32 = residual_in_fp32
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert dropout_add_layer_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
            assert dropout_add_rms_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
            assert (isinstance(self.norm1, (nn.LayerNorm, RMSNorm))
                    and isinstance(self.dropout1, nn.Dropout))

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, 'norm2'):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, 'norm2'):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the mixer and return its result."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
                residual: Optional[Tensor] = None, mixer_kwargs=None):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual: the running residual stream (None for the very first block).
            mixer_kwargs: optional extra keyword arguments forwarded to the mixer.

        Returns:
            ``(hidden_states1, hidden_states2, residual)``: the mixer output, the MLP
            output and the updated residual.
        """
        # TODO: Ideally we should only do the allgather / allreduce once for
        # the Linear to MLP & Attention
        # RMSNorm is None when the optional rms_norm import failed; guard the
        # isinstance call (which would raise TypeError on None), as Block.forward does.
        fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
                             if RMSNorm and isinstance(self.norm1, RMSNorm)
                             else dropout_add_layer_norm_parallel_residual)
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            # For the very 1st block, we only want 1 dropout, not two different dropouts
            if hidden_states2 is not None:
                dropped2 = self.dropout2(hidden_states2)
                residual = ((residual + dropped1 + dropped2)
                            if residual is not None else dropped1 + dropped2)
            else:
                residual = (residual + dropped1) if residual is not None else dropped1
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                              if not self.tied_norm else hidden_states1)
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            weight2, bias2 = ((self.norm2.weight, self.norm2.bias)
                              if not self.tied_norm else (None, None))
            hidden_states1, hidden_states2, residual = fused_add_norm_fn(
                hidden_states1, hidden_states2, residual, self.norm1.weight, self.norm1.bias,
                weight2, bias2, self.dropout1.p if self.training else 0.0, self.norm1.eps,
                prenorm=True, residual_in_fp32=self.residual_in_fp32
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual
|
||||
183
pkgs/xformers/_flash_attn/modules/embedding.py
Normal file
183
pkgs/xformers/_flash_attn/modules/embedding.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# Copyright (c) 2022, Tri Dao.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.utils.distributed import reduce_scatter, all_reduce
|
||||
|
||||
|
||||
class GPT2Embeddings(nn.Module):
    """Word embeddings with an optional up-projection and optional position embeddings."""

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
                 word_embed_proj_dim=None, device=None, dtype=None):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        If word_embed_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
            then project up to embed_dim
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        needs_projection = word_embed_proj_dim is not None
        word_dim = word_embed_proj_dim if needs_projection else embed_dim
        self.word_embeddings = nn.Embedding(vocab_size, word_dim, padding_idx=padding_idx,
                                            **factory_kwargs)
        # Bias-free linear map from the word-embedding width up to embed_dim (if any).
        self.project_in = (nn.Linear(word_embed_proj_dim, embed_dim, bias=False,
                                     **factory_kwargs)
                           if needs_projection else None)
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
                                                    **factory_kwargs)

    def forward(self, input_ids, position_ids=None):
        """
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen)
        """
        _, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.project_in is not None:
            embeddings = self.project_in(embeddings)
        if self.max_position_embeddings <= 0:
            return embeddings
        if position_ids is None:
            # Default to positions 0..seqlen-1 on the same device as the ids.
            position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
        return embeddings + self.position_embeddings(position_ids)
|
||||
|
||||
|
||||
class BertEmbeddings(nn.Module):
    """Word embeddings plus optional position and token-type embeddings, summed."""

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size,
                 padding_idx=None, device=None, dtype=None):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        If type_vocab_size <= 0, there's no token type embeddings
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
                                            **factory_kwargs)
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        has_positions = self.max_position_embeddings > 0
        has_token_types = self.type_vocab_size > 0
        if has_positions:
            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
                                                    **factory_kwargs)
        if has_token_types:
            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim,
                                                      **factory_kwargs)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        """
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen)
            token_type_ids: (batch, seqlen)
        """
        _, seqlen = input_ids.shape
        device = input_ids.device
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                # Default to positions 0..seqlen-1.
                position_ids = torch.arange(seqlen, dtype=torch.long, device=device)
            embeddings = embeddings + self.position_embeddings(position_ids)
        if self.type_vocab_size > 0:
            if token_type_ids is None:
                # Default every position to token type 0.
                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=device)
            embeddings = embeddings + self.token_type_embeddings(token_type_ids)
        return embeddings
|
||||
|
||||
|
||||
class VocabParallelEmbedding(nn.Embedding):
    """``nn.Embedding`` holding ``num_embeddings // world_size`` rows per rank of a process group."""

    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
        """Size the local table for this rank's share of the vocabulary.

        Raises:
            ValueError: if ``num_embeddings`` is not divisible by the group's world size.
            RuntimeError: if ``padding_idx`` is given while the world size exceeds 1.
        """
        self.process_group = process_group
        world_size = 1
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(f'num_embeddings ({num_embeddings}) must be divisible by '
                                 f'world_size ({world_size})')
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError('ParallelEmbedding does not support padding_idx')
        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)

    def forward(self, input: Tensor) -> Tensor:
        """Look up ``input``; ids outside this rank's range produce all-zero rows.

        Without a process group this is the plain ``nn.Embedding`` lookup.
        """
        if self.process_group is None:
            return super().forward(input)
        rank = torch.distributed.get_rank(self.process_group)
        shard_size = self.num_embeddings
        vocab_start_index = rank * shard_size
        vocab_end_index = (rank + 1) * shard_size
        # Mask of ids that are NOT in this rank's range (True means it needs to be masked).
        input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
        local_ids = input - vocab_start_index
        # Point masked ids at row 0 so the lookup stays in range; their rows are zeroed below.
        local_ids[input_ids_mask] = 0
        embeddings = super().forward(local_ids)
        embeddings[input_ids_mask] = 0.0
        return embeddings
|
||||
|
||||
|
||||
class ColumnParallelEmbedding(nn.Embedding):
    """``nn.Embedding`` holding ``embedding_dim // world_size`` columns per rank of a process group."""

    def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
        """Size the local table for this rank's share of the embedding dimension.

        Raises:
            ValueError: if ``embedding_dim`` is not divisible by the group's world size.
        """
        self.process_group = process_group
        world_size = 1
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(f'embedding_dim ({embedding_dim}) must be divisible by '
                                 f'world_size ({world_size})')
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
|
||||
|
||||
|
||||
class ParallelGPT2Embeddings(nn.Module):
    """GPT-2 style embeddings built from a ``VocabParallelEmbedding`` word table and,
    optionally, a ``ColumnParallelEmbedding`` position table.
    """

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
                 padding_idx=None, sequence_parallel=True, device=None, dtype=None):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group,
            **factory_kwargs
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen)

        With ``combine_batch_seqlen_dim`` the result is flattened to ((batch seqlen), dim).
        When the group has more than one rank the result is passed through
        ``reduce_scatter`` (if ``sequence_parallel``) or ``all_reduce``.
        """
        _, seqlen = input_ids.shape
        group = self.process_group
        world_size = torch.distributed.get_world_size(group)
        single_rank = world_size <= 1
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            if single_rank:
                embeddings = embeddings + position_embeddings
            else:
                # This rank holds only a slice of the position-embedding columns;
                # add it in place to the matching columns of the word embeddings.
                partition_dim = self.position_embeddings.embedding_dim
                start = torch.distributed.get_rank(group) * partition_dim
                embeddings[..., start:start + partition_dim] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, 'b s d -> (b s) d')
        if single_rank:
            return embeddings
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(embeddings, group)
|
||||
711
pkgs/xformers/_flash_attn/modules/mha.py
Normal file
711
pkgs/xformers/_flash_attn/modules/mha.py
Normal file
@@ -0,0 +1,711 @@
|
||||
# Copyright (c) 2022, Tri Dao.
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func
|
||||
from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
|
||||
except ImportError:
|
||||
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
|
||||
flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.fused_dense import FusedDense, ColumnParallelLinear, RowParallelLinear
|
||||
except ImportError:
|
||||
FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
|
||||
|
||||
try:
|
||||
from flash_attn.layers.rotary import RotaryEmbedding
|
||||
except ImportError:
|
||||
RotaryEmbedding = None
|
||||
|
||||
try:
|
||||
import ft_attention
|
||||
except ImportError:
|
||||
ft_attention = None
|
||||
|
||||
|
||||
class FlashSelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        # Both packed-QKV kernels must have been imported successfully.
        assert flash_attn_varlen_qkvpacked_func is not None, 'FlashAttention is not installed'
        assert flash_attn_qkvpacked_func is not None, 'FlashAttention is not installed'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in (torch.float16, torch.bfloat16)
        assert qkv.is_cuda
        if causal is None:
            causal = self.causal
        # Attention dropout is only applied in training mode.
        dropout_p = self.drop.p if self.training else 0.0
        if cu_seqlens is None:
            # Padded batch: fixed-length kernel.
            return flash_attn_qkvpacked_func(qkv, dropout_p,
                                             softmax_scale=self.softmax_scale, causal=causal)
        # Unpadded batch: variable-length kernel driven by cu_seqlens.
        assert cu_seqlens.dtype == torch.int32
        assert max_seqlen is not None
        assert isinstance(max_seqlen, int)
        return flash_attn_varlen_qkvpacked_func(
            qkv, cu_seqlens, max_seqlen, dropout_p,
            softmax_scale=self.softmax_scale, causal=causal
        )
|
||||
|
||||
|
||||
class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, 'FlashAttention is not installed'
        assert flash_attn_kvpacked_func is not None, 'FlashAttention is not installed'
        self.causal = causal
        self.softmax_scale = softmax_scale
        # Only the probability is used (read back as self.drop.p in forward).
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None,
                cu_seqlens_k=None, max_seqlen_k=None):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.

        Returns
        -------
            Whatever the selected FlashAttention kernel returns: the variable-length
            kernel when cu_seqlens is given, the padded kernel otherwise.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        # Dropout is only applied while the module is in training mode.
        dropout_p = self.drop.p if self.training else 0.0
        unpadded = cu_seqlens is not None
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            # Validate the key-side length (previously max_seqlen was checked twice
            # and max_seqlen_k never was).
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
                dropout_p,
                softmax_scale=self.softmax_scale, causal=causal
            )
        else:
            # Batch sizes and head dims of q and kv must agree in the padded layout.
            assert kv.shape[0] == q.shape[0] and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(q, kv, dropout_p,
                                            causal=causal, softmax_scale=self.softmax_scale)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
    """Reference (pure PyTorch) scaled dot-product self-attention with softmax.

    Arguments
    ---------
        causal: default causal flag; can be overridden per call.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Compute softmax attention over a packed QKV tensor.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)

        Returns
        -------
            The attention output laid out as (B, S, H, D).
        """
        n_batch, n_tokens = qkv.shape[0], qkv.shape[1]
        if causal is None:
            causal = self.causal
        query, key, value = qkv.unbind(dim=2)
        # A falsy softmax_scale (None) falls back to 1/sqrt(head_dim).
        scale = self.softmax_scale or 1.0 / math.sqrt(query.shape[-1])
        scores = torch.einsum('bthd,bshd->bhts', query, key * scale)

        if key_padding_mask is not None:
            # Additive mask: 0 where the key is kept, -10000 where it is padded out.
            additive_pad = torch.full((n_batch, n_tokens), -10000.0, dtype=scores.dtype,
                                      device=scores.device)
            additive_pad.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + additive_pad[:, None, None, :]

        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            future_mask = torch.triu(
                torch.full((n_tokens, n_tokens), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + future_mask.to(dtype=scores.dtype)

        probs = torch.softmax(scores, dim=-1, dtype=value.dtype)
        return torch.einsum('bhts,bshd->bthd', self.drop(probs), value)
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
    """Reference (pure PyTorch) scaled dot-product cross-attention with softmax.

    Arguments
    ---------
        causal: default causal flag; can be overridden per call.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Compute softmax attention of q against a packed key/value tensor.

        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)

        Returns
        -------
            The attention output laid out as (B, Sq, H, D).
        """
        n_batch, len_q = q.shape[0], q.shape[1]
        if causal is None:
            causal = self.causal
        len_k = kv.shape[1]
        assert kv.shape[0] == n_batch and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            # Repeat each kv head so the head count matches the query's.
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        key, value = kv.unbind(dim=2)
        # A falsy softmax_scale (None) falls back to 1/sqrt(head_dim).
        scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum('bthd,bshd->bhts', q, key * scale)

        if key_padding_mask is not None:
            # Additive mask: 0 where the key is kept, -10000 where it is padded out.
            additive_pad = torch.full((n_batch, len_k), -10000.0, dtype=scores.dtype,
                                      device=scores.device)
            additive_pad.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + additive_pad[:, None, None, :]

        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            future_mask = torch.triu(
                torch.full((len_q, len_k), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + future_mask.to(dtype=scores.dtype)

        probs = torch.softmax(scores, dim=-1, dtype=value.dtype)
        return torch.einsum('bhts,bshd->bthd', self.drop(probs), value)
|
||||
|
||||
|
||||
class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.
    """

    def forward(self, input: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Apply the linear layer and also hand back the untouched input.

        Returns a 2-tuple ``(linear_output, input)``; the second element is the
        very same tensor object that was passed in.
        """
        return super().forward(input), input
|
||||
|
||||
|
||||
def _update_kv_cache(kv, inference_params, layer_idx):
|
||||
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
|
||||
"""
|
||||
# Pre-allocate memory for key-values for inference.
|
||||
num_heads, head_dim = kv.shape[-2:]
|
||||
if layer_idx not in inference_params.key_value_memory_dict:
|
||||
kv_cache = torch.empty(
|
||||
inference_params.max_batch_size, inference_params.max_sequence_len, 2,
|
||||
num_heads, head_dim, dtype=kv.dtype, device=kv.device
|
||||
)
|
||||
inference_params.key_value_memory_dict[layer_idx] = kv_cache
|
||||
else:
|
||||
if not inference_params.fused_ft_kernel:
|
||||
kv_cache = inference_params.key_value_memory_dict[layer_idx]
|
||||
else:
|
||||
# For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
|
||||
# where packsize = 4 if fp32, 8 if fp16 or bf16.
|
||||
# v_cache has shape (b, h, s, headdim)
|
||||
k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
|
||||
kv_cache = None
|
||||
# Adjust key and value for inference
|
||||
batch_start = inference_params.batch_size_offset
|
||||
batch_end = batch_start + kv.shape[0]
|
||||
sequence_start = inference_params.sequence_len_offset
|
||||
sequence_end = sequence_start + kv.shape[1]
|
||||
assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
|
||||
assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
|
||||
# Copy key and values.
|
||||
if not inference_params.fused_ft_kernel:
|
||||
assert kv_cache is not None
|
||||
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
||||
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
|
||||
return kv
|
||||
else:
|
||||
assert inference_params.sequence_len_offset == 0
|
||||
# FT kernel requires different layouts for the k_cache and v_cache.
|
||||
assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
packsize = 4 if kv.dtype == torch.float32 else 8
|
||||
if kv_cache is not None:
|
||||
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
||||
k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
|
||||
packsize=packsize).contiguous()
|
||||
v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
|
||||
inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
|
||||
else:
|
||||
k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
|
||||
kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize
|
||||
)
|
||||
v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
|
||||
kv[:, :, 1], 'b s h d -> b h s d'
|
||||
)
|
||||
return kv
|
||||
|
||||
|
||||
def _apply_rotary_single_query_attention(qkv, inference_params, layer_idx, rotary_emb_dim,
                                         rotary_emb_base, kv=None, rotary_emb_interleaved=False):
    """Run one decoding step through ``ft_attention.single_query_attention``.

    qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
        q of shape (batch_size, 1, nheads, head_dim)
    kv: (batch_size, 1, 2, nheads_kv, head_dim)

    Requires ``inference_params.fused_ft_kernel`` and an available ``ft_attention``
    module. The kernel result is returned with a length-1 sequence axis re-inserted:
    (batch_size, 1, nheads, head_dim).
    """
    assert inference_params.fused_ft_kernel
    assert ft_attention is not None

    # Drop the length-1 sequence axis and separate q / k / v.
    if kv is not None:
        q = rearrange(qkv, 'b 1 h d -> b h d')
        k, v = rearrange(kv, 'b 1 two h d -> b two h d').unbind(dim=1)
    else:
        q, k, v = rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1)

    first = inference_params.batch_size_offset
    last = first + q.shape[0]
    k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
    if inference_params.lengths_per_sample is not None:
        lengths_per_sample = inference_params.lengths_per_sample[first:last]
    else:
        lengths_per_sample = None

    context = ft_attention.single_query_attention(
        q, k, v,
        k_cache[first:last],
        v_cache[first:last],
        lengths_per_sample,
        None,  # rotary_cos_
        None,  # rotary_sin_
        None,  # nnz_head_idx
        inference_params.sequence_len_offset,
        rotary_emb_dim, rotary_emb_base,
        not rotary_emb_interleaved  # neox_rotary_style
    )
    return rearrange(context, 'b h d -> b 1 h d')
|
||||
|
||||
|
||||
class MHA(nn.Module):
    """Multi-head self-attention and cross-attention

    Supports MQA / GQA through ``num_heads_kv``, optional rotary embeddings, an
    optional depthwise convolution on the projections (``dwconv``), and
    incremental decoding through ``inference_params``.
    """

    def __init__(self, embed_dim, num_heads, num_heads_kv=None, cross_attn=False,
                 qkv_proj_bias=True, out_proj_bias=True,
                 dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, dwconv=False,
                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
                 rotary_emb_interleaved=False, fused_bias_fc=False, use_flash_attn=False,
                 return_residual=False, checkpointing=False, device=None, dtype=None) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
            assert RotaryEmbedding is not None, 'rotary_emb is not installed'
            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
                                              scale_base=rotary_emb_scale_base,
                                              interleaved=rotary_emb_interleaved, device=device)

        if fused_bias_fc and FusedDense is None:
            raise ImportError('fused_dense is not installed')
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (LinearResidual if not fused_bias_fc
                            else partial(FusedDense, return_residual=True))
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            # Mirror the branch taken in forward(): only self-attention with equal head
            # counts convolves the packed qkv projection; cross-attention and MQA/GQA
            # convolve q and kv separately and therefore need dwconv_q / dwconv_kv.
            if not self.cross_attn and self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(qkv_dim, qkv_dim, kernel_size=3, padding=2,
                                            groups=qkv_dim)
            else:
                self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2,
                                          groups=embed_dim)
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2,
                                           groups=kv_dim)
        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                         attention_dropout=dropout)
        self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                                     attention_dropout=dropout)
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
        """Allocate an uninitialized key/value cache for generation.

        Returns a single (batch_size, max_seqlen, 2, num_heads_kv, head_dim) tensor when
        ``fused_ft_kernel`` is False, otherwise a ``(k_cache, v_cache)`` pair in the FT
        layout. dtype defaults to that of ``out_proj.weight``; the device always follows
        ``out_proj.weight``.
        """
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        if not fused_ft_kernel:
            return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim,
                               dtype=dtype, device=device)
        else:
            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
            # Keys are stored in packs of 4 (fp32) or 8 (fp16 / bf16) elements.
            packsize = 4 if dtype == torch.float32 else 8
            assert self.head_dim % packsize == 0
            k_cache = torch.empty(batch_size, self.num_heads_kv, self.head_dim // packsize,
                                  max_seqlen, packsize, dtype=dtype, device=device)
            v_cache = torch.empty(batch_size, self.num_heads_kv, max_seqlen, self.head_dim,
                                  dtype=dtype, device=device)
            return k_cache, v_cache

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)

        Delegates to the module-level ``_update_kv_cache`` using this layer's ``layer_idx``.
        """
        assert not self.dwconv, 'Generation does not support dwconv yet'
        assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
        """
        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
            q of shape (batch_size, 1, nheads, head_dim)
        kv: (batch_size, 1, 2, nheads_kv, head_dim)

        Delegates to the module-level helper, passing this layer's rotary settings
        (base 0 and non-interleaved when rotary embeddings are disabled).
        """
        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
        return _apply_rotary_single_query_attention(
            qkv, inference_params, self.layer_idx, self.rotary_emb_dim, rotary_emb_base, kv=kv,
            rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
        )

    def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
                mixer_subset=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470

        Returns:
            The output projection of the attention context, or the tuple ``(out, x)``
            when ``return_residual`` is set.
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        # The flash and reference inner attention modules take different masking arguments.
        kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
                  if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                # padding=2 with kernel_size=3 yields two extra trailing positions; drop them.
                qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
                                'b d s -> b s d').contiguous()
            qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
            if (inference_params is None or inference_params.sequence_len_offset == 0
                    or not inference_params.fused_ft_kernel):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv,
                                                                    **kwargs)
                else:
                    q = qkv[:, :, 0]
                    kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
                    # If we're processing the prompt, causal=None (use self.causal).
                    # If we're decoding, then causal=False.
                    causal = None if inference_params.sequence_len_offset == 0 else False
                    context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                context = self._apply_rotary_single_query_attention(qkv, inference_params)
        else:
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., :self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim:]
            if self.dwconv:
                # The depthwise convs operate on the flat (b, s, channels) projections,
                # so they must run before q / kv are split into heads.
                # padding=2 with kernel_size=3 yields two extra trailing positions; drop them.
                q = rearrange(self.dwconv_q(rearrange(q, 'b s d -> b d s'))[..., :-2],
                              'b d s -> b s d').contiguous()
                kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2],
                               'b d s -> b s d').contiguous()
            q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
            kv = rearrange(kv, '... (two hkv d) -> ... two hkv d', two=2, d=self.head_dim)
            if (inference_params is None or inference_params.sequence_len_offset == 0
                    or not inference_params.fused_ft_kernel):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv,
                                                                    **kwargs)
                else:
                    kv = self._update_kv_cache(kv, inference_params)
                    # If we're processing the prompt, causal=None (use self.causal).
                    # If we're decoding, then causal=False.
                    causal = None if inference_params.sequence_len_offset == 0 else False
                    context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
        out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
        return out if not self.return_residual else (out, x)
|
||||
|
||||
|
||||
class ParallelMHA(nn.Module):
    """Multi-head self-attention and cross-attention

    Tensor-parallel variant: the QKV projection is a ``ColumnParallelLinear`` and the
    output projection a ``RowParallelLinear`` over ``process_group``; per-rank head
    counts are the global counts divided by the group size.
    """

    def __init__(self, embed_dim, num_heads, process_group, num_heads_kv=None,
                 qkv_proj_bias=True, out_proj_bias=True,
                 dropout=0.0, softmax_scale=None, causal=False, layer_idx=None,
                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
                 rotary_emb_interleaved=False, use_flash_attn=False, checkpointing=False,
                 sequence_parallel=True, device=None, dtype=None) -> None:
        """Build the parallel projections and the inner attention modules.

        num_heads_kv: number of key/value heads (MQA / GQA). If None, use num_heads.
        process_group: group the projections are sharded over; None is treated as a
            world size of 1 when computing per-rank head counts.
        sequence_parallel: forwarded to both parallel linear layers.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing
        self.process_group = process_group
        if process_group is not None:
            self.world_size = process_group.size()
        else:
            self.world_size = 1

        self.num_heads = num_heads
        self.num_heads_kv = num_heads if num_heads_kv is None else num_heads_kv
        self.num_heads_per_rank = num_heads // self.world_size
        self.num_heads_kv_per_rank = self.num_heads_kv // self.world_size
        assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        assert self.num_heads_kv % self.world_size == 0, "num_heads_kv must be divisible by world_size"
        self.head_dim = self.embed_dim // num_heads
        # Packed projection width: all query heads plus key and value heads.
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, 'rotary_emb is not installed'
            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
                                              scale_base=rotary_emb_scale_base,
                                              interleaved=rotary_emb_interleaved, device=device)

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError('fused_dense is not installed')
        self.Wqkv = ColumnParallelLinear(embed_dim, qkv_dim, process_group,
                                         bias=qkv_proj_bias,
                                         sequence_parallel=sequence_parallel, **factory_kwargs)
        if use_flash_attn:
            self_attn_cls, cross_attn_cls = FlashSelfAttention, FlashCrossAttention
        else:
            self_attn_cls, cross_attn_cls = SelfAttention, CrossAttention
        self.inner_attn = self_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                        attention_dropout=dropout)
        self.inner_cross_attn = cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                               attention_dropout=dropout)
        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group,
                                          bias=out_proj_bias,
                                          sequence_parallel=sequence_parallel, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
        """Allocate an uninitialized per-rank key/value cache for generation.

        Returns one (batch_size, max_seqlen, 2, num_heads_kv_per_rank, head_dim) tensor
        when ``fused_ft_kernel`` is False, otherwise a ``(k_cache, v_cache)`` pair in the
        FT layout. dtype defaults to that of ``out_proj.weight``.
        """
        if dtype is None:
            dtype = self.out_proj.weight.dtype
        device = self.out_proj.weight.device
        heads = self.num_heads_kv_per_rank
        if fused_ft_kernel:
            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
            # Keys are stored in packs of 4 (fp32) or 8 (fp16 / bf16) elements.
            packsize = 4 if dtype == torch.float32 else 8
            assert self.head_dim % packsize == 0
            k_cache = torch.empty(batch_size, heads, self.head_dim // packsize,
                                  max_seqlen, packsize, dtype=dtype, device=device)
            v_cache = torch.empty(batch_size, heads, max_seqlen,
                                  self.head_dim, dtype=dtype, device=device)
            return k_cache, v_cache
        return torch.empty(batch_size, max_seqlen, 2, heads,
                           self.head_dim, dtype=dtype, device=device)

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)

        Delegates to the module-level ``_update_kv_cache`` using this layer's ``layer_idx``.
        """
        assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
        """
        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
            q of shape (batch_size, 1, nheads, head_dim)
        kv: (batch_size, 1, 2, nheads_kv, head_dim)

        Delegates to the module-level helper, passing this layer's rotary settings
        (base 0 and non-interleaved when rotary embeddings are disabled).
        """
        if self.rotary_emb_dim > 0:
            rotary_emb_base = self.rotary_emb.base
        else:
            rotary_emb_base = 0
        return _apply_rotary_single_query_attention(
            qkv, inference_params, self.layer_idx, self.rotary_emb_dim, rotary_emb_base, kv=kv,
            rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
        )

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
            inference_params: generation state; None for a plain forward pass.
            kwargs: forwarded to the inner attention module when inference_params is None.
        """
        qkv = self.Wqkv(x)
        if seqlen is not None:
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
        if self.num_heads_kv == self.num_heads:
            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, d=self.head_dim)
            if (inference_params is not None and inference_params.sequence_len_offset != 0
                    and inference_params.fused_ft_kernel):
                # Decoding with the fused FT kernel.
                context = self._apply_rotary_single_query_attention(qkv, inference_params)
            else:
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
                if inference_params is not None:
                    q = qkv[:, :, 0]
                    # NOTE: this path calls the module-level helper directly, so the
                    # layer_idx assertion in self._update_kv_cache is not applied here.
                    kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
                    # If we're processing the prompt, causal=None (use self.causal).
                    # If we're decoding, then causal=False.
                    causal = None if inference_params.sequence_len_offset == 0 else False
                    context = self.inner_cross_attn(q, kv, causal=causal)
                elif self.checkpointing:
                    context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self.inner_attn(qkv, **kwargs)
        else:
            # MQA / GQA: the first num_heads_per_rank * head_dim channels are the queries,
            # the remainder holds the packed keys and values.
            q_width = self.num_heads_per_rank * self.head_dim
            q = rearrange(qkv[..., :q_width],
                          "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(qkv[..., q_width:],
                           "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if (inference_params is not None and inference_params.sequence_len_offset != 0
                    and inference_params.fused_ft_kernel):
                # Decoding with the fused FT kernel.
                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
            else:
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
                if inference_params is not None:
                    kv = self._update_kv_cache(kv, inference_params)
                    # If we're processing the prompt, causal=None (use self.causal).
                    # If we're decoding, then causal=False.
                    causal = None if inference_params.sequence_len_offset == 0 else False
                    context = self.inner_cross_attn(q, kv, causal=causal)
                elif self.checkpointing:
                    context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv,
                                                                **kwargs)
                else:
                    context = self.inner_cross_attn(q, kv, **kwargs)
        context = rearrange(context, 'b s h d -> b s (h d)')
        if seqlen is not None:
            # Restore the flattened (batch * seqlen) layout expected by the caller.
            context = rearrange(context, 'b s d -> (b s) d')
        return self.out_proj(context)
|
||||
86
pkgs/xformers/_flash_attn/modules/mlp.py
Normal file
86
pkgs/xformers/_flash_attn/modules/mlp.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# Copyright (c) 2022, Tri Dao.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
try:
|
||||
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
|
||||
except ImportError:
|
||||
ColumnParallelLinear, RowParallelLinear = None, None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
|
||||
except ImportError:
|
||||
FusedMLP, ParallelFusedMLP = None, None
|
||||
|
||||
|
||||
class Mlp(nn.Module):
    """Two-layer feed-forward block: linear -> activation -> linear.

    ``out_features`` falls back to ``in_features`` and ``hidden_features`` to
    ``4 * in_features`` when they are left unset (or otherwise falsy).
    With ``return_residual`` the forward pass also returns its input.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
                 bias1=True, bias2=True, return_residual=False, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features * 4
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        """Return ``fc2(activation(fc1(x)))``, or ``(that, x)`` if return_residual is set."""
        hidden = self.activation(self.fc1(x))
        out = self.fc2(hidden)
        if self.return_residual:
            return out, x
        return out
|
||||
|
||||
|
||||
class ParallelMLP(nn.Module):
    """Feed-forward block built from ``ColumnParallelLinear`` and ``RowParallelLinear``.

    Both layers receive the same ``process_group`` and ``sequence_parallel`` flag.
    Construction asserts that the parallel linear classes were importable.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
                 process_group: ProcessGroup = None, sequence_parallel=True,
                 bias1=True, bias2=True, device=None, dtype=None):
        super().__init__()
        assert ColumnParallelLinear is not None, "Need to install fused_dense"
        assert RowParallelLinear is not None, "Need to install fused_dense"
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features * 4
        # Keyword arguments shared by both parallel layers.
        shared_kwargs = {'sequence_parallel': sequence_parallel, 'device': device, 'dtype': dtype}
        self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, bias=bias1,
                                        **shared_kwargs)
        self.activation = activation
        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
                                     **shared_kwargs)

    def forward(self, x):
        """Return ``fc2(activation(fc1(x)))``."""
        return self.fc2(self.activation(self.fc1(x)))
|
||||
|
||||
|
||||
class GatedMlp(nn.Module):
    """Gated feed-forward block.

    ``fc1`` projects to ``2 * hidden_features``; the result is split in two along the
    last dimension and one half gates the other before ``fc2`` projects to
    ``out_features``.

    Args:
        in_features: size of the last dimension of the input.
        hidden_features: gated hidden width; defaults to ``int(8 * in_features / 3)``.
            It is always rounded up to a multiple of ``multiple_of``.
        out_features: output width; defaults to ``in_features``.
        activation: gate activation. ``F.sigmoid`` selects the ``F.glu`` special case.
        bias1: whether ``fc1`` has a bias.
        bias2: whether ``fc2`` has a bias.
        multiple_of: rounding granularity for ``hidden_features``.
        return_residual: if true, ``forward`` returns ``(output, x)``.
        device, dtype: forwarded to both ``nn.Linear`` layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
                 bias1=True, bias2=True, multiple_of=256, return_residual=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or int(8 * in_features / 3)
        # Round up to the next multiple of `multiple_of`.
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        # fc2 must honour bias2 (it previously reused bias1, ignoring the bias2 argument).
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        """Apply the gated MLP; returns ``y`` or ``(y, x)`` when ``return_residual`` is set."""
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)
|
||||
0
pkgs/xformers/_flash_attn/ops/__init__.py
Normal file
0
pkgs/xformers/_flash_attn/ops/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
99
pkgs/xformers/_flash_attn/ops/activations.py
Normal file
99
pkgs/xformers/_flash_attn/ops/activations.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# 1/sqrt(2*pi)-> 0.3989423
|
||||
# 1/sqrt(2) -> 0.70710678
|
||||
# sqrt(2/pi) -> 0.79788456
|
||||
|
||||
# this function is tanh approximation of gelu
|
||||
# actual gelu is:
|
||||
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
|
||||
@torch.jit.script
def bias_gelu(y, bias):
    """Add ``bias`` to ``y`` and apply the tanh approximation of GELU.

    The result is cast back to the dtype of ``y``.
    """
    shifted = bias + y
    tanh_arg = 0.79788456 * shifted * (1 + 0.044715 * shifted * shifted)
    activated = shifted * 0.5 * (1.0 + torch.tanh(tanh_arg))
    return activated.to(dtype=y.dtype)
|
||||
|
||||
# gradient of tanh approximation of gelu
|
||||
# gradient of actual gelu is:
|
||||
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
|
||||
@torch.jit.script
def bias_gelu_back(g, y, bias):
    """Backward of ``bias_gelu``; assumes y has shape (B, D) and bias has shape (D).

    Returns a pair: the gradient w.r.t. ``y`` (cast to ``y.dtype``) and the gradient
    w.r.t. ``bias``, obtained by summing over dim 0 and accumulated in ``bias.dtype``.
    """
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # d/dx tanh(u) = 1 - tanh(u)^2
    one_minus_tanh_sq = 1 - tanh_out * tanh_out
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    inner_slope = 0.79788456 + 0.1070322243 * x * x
    ff = 0.5 * x * (one_minus_tanh_sq * inner_slope) + 0.5 * (1 + tanh_out)
    grad_y = ff * g
    grad_bias = grad_y.sum(dim=0, dtype=bias.dtype)
    return grad_y.to(dtype=y.dtype), grad_bias
|
||||
|
||||
|
||||
class GeLUFunction(torch.autograd.Function):
    """Autograd wrapper around the scripted ``bias_gelu`` / ``bias_gelu_back`` pair."""

    @staticmethod
    def forward(ctx, input, bias):
        """Save ``input`` and ``bias`` for the backward pass and return ``bias_gelu(input, bias)``."""
        ctx.save_for_backward(input, bias)
        return bias_gelu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        """Return one gradient tensor per forward input: ``(grad_input, grad_bias)``.

        ``bias_gelu_back`` already returns the pair ``(grad_y, grad_bias)``, so it is
        unpacked here. Returning the pair twice (``tmp, tmp``) would hand autograd two
        tuples instead of two tensors.
        """
        input, bias = ctx.saved_tensors
        grad_input, grad_bias = bias_gelu_back(grad_output, input, bias)
        return grad_input, grad_bias


bias_gelu_impl = GeLUFunction.apply
|
||||
|
||||
# this function is tanh approximation of gelu
|
||||
# actual gelu is:
|
||||
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
|
||||
@torch.jit.script
def gelu_fwd(x):
    """Tanh approximation of GELU applied to ``x``; the result keeps ``x``'s dtype."""
    tanh_arg = 0.79788456 * x * (1 + 0.044715 * x * x)
    activated = x * 0.5 * (1.0 + torch.tanh(tanh_arg))
    return activated.to(dtype=x.dtype)
|
||||
|
||||
# gradient of tanh approximation of gelu
|
||||
# gradient of actual gelu is:
|
||||
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
|
||||
@torch.jit.script
def gelu_bwd(g, x):
    """Gradient of the tanh-approximated GELU at ``x``, scaled by upstream grad ``g``.

    The result is cast to ``x.dtype``.
    """
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # d/du tanh(u) = 1 - tanh(u)^2
    one_minus_tanh_sq = 1 - tanh_out * tanh_out
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    inner_slope = 0.79788456 + 0.1070322243 * x * x
    ff = 0.5 * x * (one_minus_tanh_sq * inner_slope) + 0.5 * (1 + tanh_out)
    return (ff * g).to(dtype=x.dtype)
|
||||
|
||||
|
||||
class FastGeLUFunction(torch.autograd.Function):
    """Autograd wrapper pairing the scripted ``gelu_fwd`` with ``gelu_bwd``."""

    @staticmethod
    def forward(ctx, input):
        """Stash ``input`` for the backward pass and return ``gelu_fwd(input)``."""
        ctx.save_for_backward(input)
        return gelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. the saved input via ``gelu_bwd``."""
        (saved_input,) = ctx.saved_tensors
        return gelu_bwd(grad_output, saved_input)


fast_gelu_impl = FastGeLUFunction.apply
|
||||
|
||||
|
||||
@torch.jit.script
def relu_bwd(g, x):
    """Pass ``g`` through where ``x >= 0`` and zero it elsewhere; result has ``x``'s dtype."""
    keep = x >= 0
    masked = torch.where(keep, g, 0.0)
    return masked.to(dtype=x.dtype)
|
||||
|
||||
|
||||
@torch.jit.script
def sqrelu_fwd(x):
    """Squared ReLU: ``relu(x) ** 2``, returned in ``x``'s dtype."""
    rectified = F.relu(x)
    squared = rectified * rectified
    return squared.to(dtype=x.dtype)
|
||||
|
||||
|
||||
@torch.jit.script
def sqrelu_bwd(g, x):
    """Gradient of squared ReLU: ``2 * g * relu(x)``, returned in ``x``'s dtype."""
    scaled_grad = 2.0 * g * F.relu(x)
    return scaled_grad.to(dtype=x.dtype)
|
||||
527
pkgs/xformers/_flash_attn/ops/fused_dense.py
Normal file
527
pkgs/xformers/_flash_attn/ops/fused_dense.py
Normal file
@@ -0,0 +1,527 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
|
||||
# We make it work with pytorch amp and with bfloat16.
|
||||
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
|
||||
from typing import Optional
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
# import fused_dense_cuda # from apex
|
||||
import fused_dense_lib as fused_dense_cuda
|
||||
|
||||
from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_fwd, sqrelu_bwd
|
||||
from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
|
||||
from flash_attn.utils.distributed import reduce_scatter, all_reduce
|
||||
|
||||
|
||||
class FusedDenseFunc(torch.autograd.Function):
    """Autograd function for a linear layer, with optional tensor/sequence-parallel collectives.

    The forward pass is a plain ``F.linear``; the backward pass computes the weight and
    bias gradients through ``fused_dense_cuda.linear_bias_wgrad``.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, return_residual=False, process_group=None,
                sequence_parallel=True):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.

        Returns ``F.linear(total_x, weight, bias)``, or ``(output, x)`` when
        ``return_residual`` is true (``x`` being the contiguous, possibly autocast input).
        Raises RuntimeError when the size check below trips.
        """
        # Remember everything backward() needs to decide which gradients/collectives to run.
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.return_residual = return_residual
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            # The gathered input is needed from here on.
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        # NOTE(review): this uses min(), so it only raises when *every* dimension exceeds
        # the limit -- confirm whether max() was intended.
        if min(batch_dim, n, *weight.shape) > 65535 * 32:
            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
        output = F.linear(total_x, weight, bias)
        # Only the local (un-gathered) x is saved; backward() re-gathers it when needed.
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output if not return_residual else (output, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        """Compute gradients for ``x``, ``weight`` and ``bias``.

        ``args`` carries the gradient of the residual output when the forward pass ran
        with ``return_residual=True``; it is added into the input gradient. The three
        trailing ``None`` values correspond to the non-tensor forward arguments.
        """
        grad_output = grad_output.contiguous()
        if ctx.return_residual:
            grad_input, = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                # Start re-gathering x now; it is only waited on right before the wgrad matmul.
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            weight, = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        # Flatten all leading dims so the matmuls below are 2-D.
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_output, weight.t())
            else:
                # Fuse the residual gradient add with the dgrad matmul.
                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
                                         grad_output, weight)
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                # Launched asynchronously; waited on just before returning.
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
            )
        else:
            grad_weight = None
            # NOTE(review): grad_output is returned un-reduced here; presumably autograd
            # reduces it to the bias shape -- confirm.
            grad_bias = grad_output if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None, None
|
||||
|
||||
|
||||
def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
                     return_residual: bool = False, process_group: Optional[ProcessGroup] = None,
                     sequence_parallel: bool = True):
    """Linear layer that dispatches to ``FusedDenseFunc`` when the inputs are eligible.

    Eligible means: ``x``, ``weight`` and (if given) ``bias`` are CUDA tensors and ``x``
    is fp16/bf16, or fp32 while autocast is enabled. Otherwise this falls back to
    ``F.linear``, which requires ``process_group`` to be None.

    Returns the output, or ``(output, x)`` when ``return_residual`` is true.
    """
    is_half = x.dtype in [torch.float16, torch.bfloat16]
    dtype_eligible = is_half or (x.dtype == torch.float32 and torch.is_autocast_enabled())
    all_on_cuda = x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda)
    if all_on_cuda and dtype_eligible:
        return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group,
                                    sequence_parallel)
    # Fallback path: no collectives available, so a process group is not supported.
    assert process_group is None
    out = F.linear(x, weight, bias)
    if return_residual:
        return out, x
    return out
|
||||
|
||||
|
||||
class FusedDense(nn.Linear):
    """``nn.Linear`` whose forward pass goes through ``fused_dense_func``."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 return_residual: bool = False, device=None, dtype=None) -> None:
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        # When set, forward() returns (output, x) instead of just the output.
        self.return_residual = return_residual

    def forward(self, x, process_group=None):
        """
        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul.
        """
        result = fused_dense_func(
            x,
            self.weight,
            self.bias,
            return_residual=self.return_residual,
            process_group=process_group,
        )
        return result
|
||||
|
||||
|
||||
class ColumnParallelLinear(nn.Linear):
    """Linear layer holding ``out_features // world_size`` output features per rank.

    Raises ValueError at construction when ``out_features`` is not divisible by the
    world size of ``process_group``.
    """

    def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
                 bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % world_size:
            raise ValueError(f'out_features ({out_features}) must be divisible by '
                             f'world_size ({world_size})')
        local_out_features = out_features // world_size
        super().__init__(in_features, local_out_features, bias=bias,
                         device=device, dtype=dtype)
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        return fused_dense_func(
            x,
            self.weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )
|
||||
|
||||
|
||||
class RowParallelLinear(nn.Linear):
    """Linear layer holding ``in_features // world_size`` input features per rank.

    Raises ValueError at construction when ``in_features`` is not divisible by the
    world size of ``process_group``.
    """

    def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
                 bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % world_size:
            raise ValueError(f'in_features ({in_features}) must be divisible by '
                             f'world_size ({world_size})')
        # Only rank 0 will have bias
        has_bias = bias and rank == 0
        local_in_features = in_features // world_size
        super().__init__(local_in_features, out_features, bias=has_bias,
                         device=device, dtype=dtype)
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result (or an all_reduce when sequence_parallel is False).
        """
        partial_out = fused_dense_func(x, self.weight, self.bias)
        if self.sequence_parallel:
            return reduce_scatter(partial_out, self.process_group)
        return all_reduce(partial_out, self.process_group)
|
||||
|
||||
|
||||
class FusedMLPFunc(torch.autograd.Function):
    """Autograd function for ``linear -> activation -> linear`` with optional recomputation.

    Depending on ``heuristic`` the first linear + activation is either run as separate
    ops (``heuristic == -1``) or through ``fused_dense_cuda.linear_act_forward``.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight1, bias1, weight2, bias2, activation='gelu_approx', save_pre_act=True,
                return_residual=False, checkpoint_lvl=0, heuristic=0, process_group=None,
                sequence_parallel=True):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather of x before doing the matmul.
        If sequence_parallel=False, then the input is already gathered.

        checkpoint_lvl:
        0: no recomputation in the bwd
        1: recompute gelu_out / relu_out in the bwd
        2: recompute pre_act and gelu_out / relu_out in the bwd

        Returns ``output2``, or ``(output2, x)`` when ``return_residual`` is true.
        Raises RuntimeError when the size check below trips.
        """
        assert -1 <= heuristic <= 4
        assert activation in ['gelu_approx', 'relu', 'sqrelu']
        if activation == 'sqrelu':
            assert heuristic == -1
        if not save_pre_act:
            # Nothing is kept from the forward pass, so everything must be recomputed.
            checkpoint_lvl = 2
        assert checkpoint_lvl in [0, 1, 2]
        ctx.return_residual = return_residual
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.activation = activation
        ctx.heuristic = heuristic

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]]
            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous() if bias1 is not None else None
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous() if bias2 is not None else None
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
        if heuristic == -1:
            pre_act = F.linear(total_x, weight1, bias1)
            activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
                             else (sqrelu_fwd if activation == 'sqrelu' else F.relu))
            with torch.jit.fuser('fuser2'):
                output1 = activation_fn(pre_act)
            # This is before adding bias1
            # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
            # with torch.jit.fuser('fuser2'):
            #     output1 = bias_gelu(pre_act, bias1)
        else:
            is_gelu = activation == 'gelu_approx'
            output1, *rest = fused_dense_cuda.linear_act_forward(
                total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic
            )
            if save_pre_act:
                pre_act = rest[0]
        output2 = F.linear(output1, weight2, bias2)
        if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'):
            # For RELU the pre_act is very small (just a bit-mask) so we just save it
            ctx.save_for_backward(x, weight1, weight2, pre_act, output1)
        elif checkpoint_lvl == 1:
            ctx.save_for_backward(x, weight1, weight2, pre_act)
        elif checkpoint_lvl == 2:
            ctx.save_for_backward(x, weight1, weight2, bias1)
        output2 = output2.reshape(*batch_shape, output2.shape[-1])
        return output2 if not return_residual else (output2, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        """Compute gradients for ``x``, ``weight1``, ``bias1``, ``weight2`` and ``bias2``.

        ``args`` carries the gradient of the residual output when the forward pass ran
        with ``return_residual=True``. The seven trailing ``None`` values correspond to
        the non-tensor forward arguments.
        """
        grad_output = grad_output.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        activation = ctx.activation
        activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
                         else (sqrelu_fwd if activation == 'sqrelu' else F.relu))
        if ctx.return_residual:
            grad_input, = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        x, weight1, weight2, *rest = ctx.saved_tensors
        if process_group is None or not sequence_parallel:
            total_x = x
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        if checkpoint_lvl in [0, 1]:
            if process_group is not None and sequence_parallel:
                # Asynchronous re-gather; waited on right before the weight1 gradient.
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'):
                pre_act, output1 = rest
            elif checkpoint_lvl == 1:
                pre_act, = rest
                with torch.jit.fuser('fuser2'):
                    output1 = activation_fn(pre_act)
        elif checkpoint_lvl == 2:
            bias1, = rest
            if process_group is not None and sequence_parallel:
                # Gathered without async_op because total_x is needed immediately for
                # the recomputation below; no handle_x exists in this branch.
                total_x, _ = all_gather_raw(x, process_group)
            if ctx.heuristic == -1:
                pre_act = F.linear(total_x, weight1, bias1)
                with torch.jit.fuser('fuser2'):
                    output1 = activation_fn(pre_act)
            else:
                output1, pre_act = fused_dense_cuda.linear_act_forward(
                    total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1,
                    activation == 'gelu_approx', True, ctx.heuristic
                )

        # Flatten all leading dims so the matmuls below are 2-D.
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        output1 = output1.reshape(batch_dim, output1.shape[-1])
        pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1])
        if ctx.needs_input_grad[3]:
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
                output1, grad_output, ctx.needs_input_grad[4]
            )
        else:
            grad_weight2 = None
            grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
        if ctx.heuristic == -1:
            # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
            grad_output1 = F.linear(grad_output, weight2.t())
            activation_grad_fn = (gelu_bwd if activation == 'gelu_approx'
                                  else (sqrelu_bwd if activation == 'sqrelu' else relu_bwd))
            with torch.jit.fuser('fuser2'):
                grad_pre_act = activation_grad_fn(grad_output1, pre_act)
        else:
            # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't
            # just compute gelu/relu grad
            grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad(
                weight2, grad_output, pre_act, activation == 'gelu_approx', ctx.heuristic
            )
            if not ctx.needs_input_grad[2]:
                grad_bias1 = None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_pre_act, weight1.t())
            else:
                # Fuse the residual gradient add with the dgrad matmul.
                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
                                         grad_pre_act, weight1)
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                # Launched asynchronously; waited on just before returning.
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        # handle_x only exists when the re-gather above was asynchronous (levels 0 and 1).
        # At checkpoint_lvl == 2 it was never bound, so waiting on it would raise NameError.
        must_wait_for_x = (process_group is not None and sequence_parallel
                           and checkpoint_lvl != 2)
        if ctx.heuristic == -1:
            if ctx.needs_input_grad[1]:
                if must_wait_for_x:
                    handle_x.wait()
                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                    total_x.reshape(batch_dim, total_x.shape[-1]), grad_pre_act,
                    ctx.needs_input_grad[2]
                )
            else:
                grad_weight1 = None
                grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
        else:
            if ctx.needs_input_grad[1]:
                if must_wait_for_x:
                    handle_x.wait()
                grad_weight1 = F.linear(grad_pre_act.t(),
                                        total_x.reshape(batch_dim, total_x.shape[-1]).t())
            else:
                grad_weight1 = None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2,
                None, None, None, None, None, None, None)
|
||||
|
||||
|
||||
def fused_mlp_func(
    x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
    bias2: Optional[Tensor] = None, activation: str = 'gelu_approx',
    save_pre_act: bool = True, return_residual: bool = False,
    checkpoint_lvl: int = 0, heuristic: int = 0,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True
):
    """Compute ``linear -> activation -> linear``, using ``FusedMLPFunc`` when eligible.

    The fused path is taken when all tensors are on CUDA, ``x`` is fp16/bf16 (or fp32
    under autocast), and -- if ``save_pre_act`` -- the last dimension of ``x`` is
    divisible by 128 (relu) or 8 (other activations). Otherwise an unfused fallback
    runs, which requires ``process_group`` to be None.

    Args:
        activation: one of ``'gelu_approx'``, ``'relu'``, ``'sqrelu'``.
        return_residual: if true, return ``(output, x)`` instead of ``output``.
        The remaining arguments are forwarded unchanged to ``FusedMLPFunc``.

    Returns:
        The output of the second linear layer, optionally paired with ``x``.
    """
    assert activation in ['gelu_approx', 'relu', 'sqrelu']
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
    # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
    dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == 'relu' else 8) == 0)
    if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
            and (bias2 is None or bias2.is_cuda) and dtype_eligible and dim_eligible):
        return FusedMLPFunc.apply(
            x, weight1, bias1, weight2, bias2, activation, save_pre_act, return_residual,
            checkpoint_lvl, heuristic, process_group, sequence_parallel
        )
    else:
        assert process_group is None
        pre_act = F.linear(x, weight1, bias1)
        if activation == 'gelu_approx':
            output1 = F.gelu(pre_act, approximate='tanh')
        else:
            output1 = F.relu(pre_act, inplace=True)
            if activation == 'sqrelu':
                # Squared ReLU: the fallback used to stop at plain ReLU here, silently
                # disagreeing with the fused path for activation='sqrelu'.
                output1 = output1 * output1
        output2 = F.linear(output1, weight2, bias2)
        return output2 if not return_residual else (output2, x)
|
||||
|
||||
|
||||
class FusedMLP(nn.Module):
    """Two-layer MLP whose forward pass is delegated to ``fused_mlp_func``."""

    def __init__(self, in_features, hidden_features=None, out_features=None, bias1=True,
                 bias2=True, activation='gelu_approx', return_residual=False,
                 checkpoint_lvl=0, heuristic='auto', device=None, dtype=None):
        """
        If a process_group is passed to forward(), the output is reduce_scatter-ed
        after the fused matmul / activation / matmul.

        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute pre_act and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            'auto': picked in forward() for 'gelu_approx':
                -1 on devices with compute capability (9, 0);
                otherwise 0 for CUDA >= 11.8, and for older CUDA 1 for fp16, -1 for other dtypes.
                For other activations 'auto' resolves to 0.
            With activation='sqrelu' the stored heuristic is always -1.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ['gelu_approx', 'relu', 'sqrelu']
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features * 4
        self.activation = activation
        self.return_residual = return_residual
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = -1 if activation == 'sqrelu' else heuristic
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, device=device, dtype=dtype)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, device=device, dtype=dtype)

    def forward(self, x, process_group=None):
        """Run the MLP on ``x``; returns ``out`` or ``(out, x)`` when ``return_residual``."""
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else x.dtype
        heuristic = self.heuristic
        if heuristic == 'auto':
            if self.activation != 'gelu_approx':
                heuristic = 0
            elif torch.cuda.get_device_capability('cuda') == (9, 0):
                heuristic = -1
            else:
                cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
                if cuda_ver >= (11, 8):
                    heuristic = 0
                elif dtype == torch.float16:
                    heuristic = 1
                else:
                    heuristic = -1
        out = fused_mlp_func(
            x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
            activation=self.activation, save_pre_act=self.training,
            return_residual=self.return_residual, checkpoint_lvl=self.checkpoint_lvl,
            heuristic=heuristic, process_group=process_group
        )
        if self.return_residual:
            out, x = out
        if process_group is not None:
            out = reduce_scatter(out, process_group)
        if self.return_residual:
            return out, x
        return out
|
||||
|
||||
|
||||
class ParallelFusedMLP(nn.Module):
    """Tensor-parallel MLP: parallel linear layers driven through ``fused_mlp_func``."""

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 activation='gelu_approx', process_group: ProcessGroup = None,
                 bias1=True, bias2=True, sequence_parallel=True, checkpoint_lvl=0, heuristic='auto',
                 device=None, dtype=None):
        """
        process_group is required. The fused matmul / activation / matmul is followed by
        a reduce_scatter of the output (an all_reduce when sequence_parallel is False).

        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute pre_act and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            'auto': picked in forward() for 'gelu_approx':
                0 for CUDA >= 11.8; for older CUDA 1 for fp16 and -1 for other dtypes.
                For other activations 'auto' resolves to 0.
            With activation='sqrelu' the stored heuristic is always -1.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ['gelu_approx', 'relu', 'sqrelu']
        assert process_group is not None
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features * 4
        self.activation = activation
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = -1 if activation == 'sqrelu' else heuristic
        self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group,
                                        bias=bias1, device=device, dtype=dtype)
        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group,
                                     bias=bias2, device=device, dtype=dtype)

    def forward(self, x):
        """Run the parallel MLP on ``x`` and reduce the result across ``process_group``."""
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else x.dtype
        heuristic = self.heuristic
        if heuristic == 'auto':
            if self.activation != 'gelu_approx':
                heuristic = 0
            else:
                cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
                if cuda_ver >= (11, 8):
                    heuristic = 0
                elif dtype == torch.float16:
                    heuristic = 1
                else:
                    heuristic = -1
        out = fused_mlp_func(
            x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
            activation=self.activation, save_pre_act=self.training,
            checkpoint_lvl=self.checkpoint_lvl, heuristic=heuristic,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel
        )
        if self.sequence_parallel:
            return reduce_scatter(out, self.process_group)
        return all_reduce(out, self.process_group)
|
||||
375
pkgs/xformers/_flash_attn/ops/layer_norm.py
Normal file
375
pkgs/xformers/_flash_attn/ops/layer_norm.py
Normal file
@@ -0,0 +1,375 @@
|
||||
# Copyright (c) 2022, Tri Dao.
|
||||
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
|
||||
|
||||
import torch
|
||||
from torch.nn import init
|
||||
|
||||
import dropout_layer_norm
|
||||
|
||||
|
||||
def maybe_align(x, alignment_in_bytes=16):
    """Return ``x`` itself if its data pointer is aligned, otherwise a clone of it.

    Assume that x already has last dim divisible by alignment_in_bytes.
    """
    # TD [2023-07-04] I'm not 100% sure that clone will align the memory
    # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
    if x.data_ptr() % alignment_in_bytes != 0:
        return x.clone()
    return x
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
                                    epsilon, residual_in_fp32=False, is_rms_norm=False):
    """Flatten the inputs to 2D and call the fused ``dropout_add_ln_fwd`` kernel.

    Assume that arguments are contiguous and aligned to 16 bytes.

    Returns ``(zmat, xmat, dmask, mu, rsigma)``. If the kernel returns no ``xmat``,
    the flattened ``x0`` is returned in its place.
    """
    hidden_size = gamma.numel()
    flat = (-1, hidden_size)
    x0mat = x0.view(flat)
    residualmat = None if residual is None else residual.view(flat)
    if rowscale is not None:
        rowscale = rowscale.view(-1)
    outputs = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
        1.0, 0, None, residual_in_fp32, is_rms_norm
    )
    zmat, xmat, dmask, mu, rsigma = outputs
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    # NOTE(review): `!=` above looks like it should be `==`, since x0mat is substituted
    # for xmat below — confirm against the kernel.
    if xmat is None:
        xmat = x0mat
    return zmat, xmat, dmask, mu, rsigma
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
                                     dropout_p, has_residual, is_rms_norm=False):
    """Flatten the inputs to 2D and call the fused ``dropout_add_ln_bwd`` kernel.

    Assume that arguments are contiguous and aligned to 16 bytes.
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.

    Returns ``(dx0mat, dresidualmat, dgamma, dbeta)``, with ``dcolscale`` appended
    when ``colscale`` is given.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    flat_shape = xmat.shape
    dzmat = dz.view(flat_shape)
    dxmat = None if dx is None else dx.view(flat_shape)
    x0mat = None if x0 is None else x0.view((-1, hidden_size))
    if rowscale is not None:
        rowscale = rowscale.view(-1)
    has_colscale = colscale is not None
    if has_colscale:
        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
    grads = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
        dropout_p, 1.0, 0, has_residual, is_rms_norm
    )
    # The fifth and sixth kernel outputs are not used here.
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = grads
    # dresidualmat is None if not has_residual
    if has_colscale:
        return dx0mat, dresidualmat, dgamma, dbeta, rest[0]
    return dx0mat, dresidualmat, dgamma, dbeta
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
                                           out_subset, dropout_p, epsilon, rowscale_const,
                                           out_numrows, residual_in_fp32=False, is_rms_norm=False):
    """Subset variant of the fused forward: forwards ``x0_subset``/``out_subset`` to the kernel.

    Assume that arguments are contiguous and aligned to 16 bytes.

    Returns ``(zmat, xmat, dmask, mu, rsigma)``. If the kernel returns no ``xmat``,
    the flattened ``x0`` is returned in its place.
    """
    hidden_size = gamma.numel()
    flat = (-1, hidden_size)
    x0mat = x0.view(flat)
    residualmat = None if residual is None else residual.view(flat)
    if x0_subset is not None:
        x0_subset = x0_subset.view(-1)
    if out_subset is not None:
        out_subset = out_subset.view(-1)
    outputs = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
    )
    zmat, xmat, dmask, mu, rsigma = outputs
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    if xmat is None:
        xmat = x0mat
    return zmat, xmat, dmask, mu, rsigma
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
                                            x0_subset, out_subset, dropout_p, rowscale_const,
                                            x0_numrows, has_residual, is_rms_norm=False):
    """Subset variant of the fused backward: forwards ``x0_subset``/``out_subset`` to the kernel.

    Assume that arguments are contiguous and aligned to 16 bytes.
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.

    Returns ``(dx0mat, dresidualmat, dgamma, dbeta)``, with ``dcolscale`` appended
    when ``colscale`` is given.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    # dz is flattened independently of x (it is not viewed with xmat.shape).
    dzmat = dz.view(-1, hidden_size)
    dxmat = None if dx is None else dx.view(xmat.shape)
    x0mat = None if x0 is None else x0.view((-1, hidden_size))
    if x0_subset is not None:
        x0_subset = x0_subset.view(-1)
    if out_subset is not None:
        out_subset = out_subset.view(-1)
    has_colscale = colscale is not None
    if has_colscale:
        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
    grads = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
        dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
    )
    # The fifth and sixth kernel outputs are not used here.
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = grads
    # dresidualmat is None if not has_residual
    if has_colscale:
        return dx0mat, dresidualmat, dgamma, dbeta, rest[0]
    return dx0mat, dresidualmat, dgamma, dbeta
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_parallel_residual_forward(
    x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p,
    epsilon, residual_in_fp32=False, is_rms_norm=False
):
    """Flatten the inputs to 2D and call ``dropout_add_ln_parallel_residual_fwd``.

    Assume that arguments are contiguous and aligned to 16 bytes.

    Returns ``(z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma)``. If the kernel
    returns no ``xmat``, the flattened ``x0`` is returned in its place.
    """
    hidden_size = gamma0.numel()
    flat = (-1, hidden_size)
    x0mat = x0.view(flat)
    x1mat = None if x1 is None else x1.view(flat)
    residualmat = None if residual is None else residual.view(flat)
    outputs = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
        x0mat, x1mat, residualmat, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
        None, residual_in_fp32, is_rms_norm
    )
    z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = outputs
    # dmask0 and dmask1 are None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    if xmat is None:
        xmat = x0mat
    return z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_parallel_residual_backward(
    dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
    dropout_p, has_x1, has_residual, is_rms_norm=False
):
    """Flatten the inputs to 2D and call ``dropout_add_ln_parallel_residual_bwd``.

    Assume that arguments are contiguous and aligned to 16 bytes.
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).

    Returns ``(dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1)``.
    """
    hidden_size = gamma0.numel()
    xmat = x.view((-1, hidden_size))
    flat_shape = xmat.shape
    dz0mat = dz0.view(flat_shape)
    dz1mat = None if dz1 is None else dz1.view(flat_shape)
    dxmat = None if dx is None else dx.view(flat_shape)
    grads = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
        dz0mat, dz1mat, dxmat, xmat, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
        dropout_p, has_x1, has_residual, is_rms_norm
    )
    # Any kernel outputs beyond the first seven are discarded.
    dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, *_ = grads
    # dresidualmat is None if not has_residual
    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
|
||||
|
||||
|
||||
class DropoutAddLayerNormFn(torch.autograd.Function):
    """Autograd function wrapping the fused dropout + add + layer-norm helpers."""

    @staticmethod
    def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
        """Return ``z``; with ``prenorm`` also ``x``; with ``return_dmask`` also the mask."""

        def _aligned_or_none(t):
            # Optional tensors stay None; present ones are made contiguous and aligned.
            return None if t is None else maybe_align(t.contiguous(), 16)

        x0 = maybe_align(x0.contiguous(), 16)
        residual = _aligned_or_none(residual)
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = _aligned_or_none(beta)
        rowscale = _aligned_or_none(rowscale)
        colscale = _aligned_or_none(colscale)
        outputs = _dropout_add_layer_norm_forward(
            x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
            residual_in_fp32, is_rms_norm
        )
        zmat, xmat, dmask, mu, rsigma = outputs
        shape = x0.shape
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = None if colscale is None else x0
        ctx.save_for_backward(xmat.view(shape), x0_saved, dmask, gamma, mu, rsigma,
                              rowscale, colscale)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None

        if not return_dmask:
            if prenorm:
                return zmat.view(shape), xmat.view(shape)
            return zmat.view(shape)

        # Without dropout there is no mask from the helper, so an all-ones mask is returned.
        if dropout_p > 0.:
            dmask = dmask.view(shape)
        else:
            dmask = torch.ones(shape, dtype=torch.uint8, device=x0.device)
        ctx.mark_non_differentiable(dmask)
        if prenorm:
            return zmat.view(shape), xmat.view(shape), dmask
        return zmat.view(shape), dmask

    @staticmethod
    def backward(ctx, dz, *args):
        """Return one gradient slot per ``forward`` argument (``None`` where not differentiable)."""
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        # With prenorm, the first extra incoming gradient is the one for x.
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        grads = _dropout_add_layer_norm_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, ctx.dropout_p,
            ctx.has_residual, ctx.is_rms_norm
        )
        dx0mat, dresidualmat, dgamma, dbeta, *rest = grads
        dx0 = dx0mat.view(x.shape)
        dresidual = None if dresidualmat is None else dresidualmat.view(x.shape)
        dcolscale = None if colscale is None else rest[0]
        if not ctx.has_beta:
            dbeta = None
        return (dx0, dresidual, dgamma, dbeta, None, dcolscale, None,
                None, None, None, None, None)
|
||||
|
||||
|
||||
class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    """Autograd function wrapping the subset variants of the fused dropout + add + layer-norm helpers."""

    @staticmethod
    def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
                rowscale_const, out_numrows, residual_in_fp32=False,
                prenorm=False, is_rms_norm=False, return_dmask=False):
        """Return ``z``; with ``prenorm`` also ``x``; with ``return_dmask`` also the mask.

        ``z`` and ``x`` are viewed as ``(-1, *x0.shape[1:])`` because their row count
        is not tied to ``x0``'s (presumably they differ when subsets are used —
        confirm against the kernel).
        """
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
            rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale,
                              x0_subset, out_subset)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            # x is viewed with x_shape (not x0.shape) so this path matches the
            # return_dmask path below and the tensor saved for backward.
            return (zmat.view(z_shape) if not prenorm
                    else (zmat.view(z_shape), xmat.view(x_shape)))
        else:
            z = zmat.view(z_shape)
            # Without dropout there is no mask from the helper, so an all-ones mask is returned.
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return ((z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask))

    @staticmethod
    def backward(ctx, dz, *args):
        """Return one gradient slot per ``forward`` argument (``None`` where not differentiable)."""
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        # With prenorm, the first extra incoming gradient is the one for x.
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
            ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
        )
        # dx0 keeps its own row count; only the trailing dims are taken from x.
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None,
                None, None, None, None, None, None, None, None)
|
||||
|
||||
|
||||
class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
    """Autograd function wrapping the parallel-residual fused dropout + add + layer-norm helpers."""

    @staticmethod
    def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
        """Return ``(z0, z1)``; with ``prenorm`` also ``x``; with ``return_dmask`` also both masks."""

        def _aligned_or_none(t):
            # Optional tensors stay None; present ones are made contiguous and aligned.
            return None if t is None else maybe_align(t.contiguous(), 16)

        x0 = maybe_align(x0.contiguous(), 16)
        x1 = _aligned_or_none(x1)
        residual = _aligned_or_none(residual)
        gamma0 = maybe_align(gamma0.contiguous(), 16)
        beta0 = _aligned_or_none(beta0)
        gamma1 = _aligned_or_none(gamma1)
        beta1 = _aligned_or_none(beta1)
        outputs = _dropout_add_layer_norm_parallel_residual_forward(
            x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
            residual_in_fp32, is_rms_norm
        )
        z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = outputs
        shape = x0.shape
        ctx.save_for_backward(xmat.view(shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_x1 = x1 is not None
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta0 is not None

        z1 = None if z1mat is None else z1mat.view(shape)
        z = (z0mat.view(shape), z1)
        if not return_dmask:
            if prenorm:
                return (*z, xmat.view(shape))
            return z

        # Where the helper produced no mask, an all-ones mask is returned instead.
        if dropout_p > 0.:
            dmask0 = dmask0.view(shape)
        else:
            dmask0 = torch.ones(shape, dtype=torch.uint8, device=x0.device)
        if dropout_p > 0. and x1 is not None:
            dmask1 = dmask1.view(shape)
        else:
            dmask1 = torch.ones(shape, dtype=torch.uint8, device=x0.device)
        ctx.mark_non_differentiable(dmask0)
        ctx.mark_non_differentiable(dmask1)
        if prenorm:
            return (*z, xmat.view(shape), dmask0, dmask1)
        return (*z, dmask0, dmask1)

    @staticmethod
    def backward(ctx, dz0, dz1, *args):
        """Return one gradient slot per ``forward`` argument (``None`` where not differentiable)."""
        dz0 = maybe_align(dz0.contiguous(), 16)  # this happens!
        dz1 = None if dz1 is None else maybe_align(dz1.contiguous(), 16)
        # With prenorm, the first extra incoming gradient is the one for x.
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
        grads = _dropout_add_layer_norm_parallel_residual_backward(
            dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, ctx.dropout_p,
            ctx.has_x1, ctx.has_residual, ctx.is_rms_norm
        )
        dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 = grads
        dx0 = dx0mat.view(x.shape)
        dx1 = None if dx1mat is None else dx1mat.view(x.shape)
        dresidual = None if dresidualmat is None else dresidualmat.view(x.shape)
        # has_beta was derived from beta0 only and gates both bias gradients.
        if not ctx.has_beta:
            dbeta0 = None
            dbeta1 = None
        return (dx0, dx1, dresidual, dgamma0, dbeta0, dgamma1,
                dbeta1, None, None, None, None, None, None)
|
||||
|
||||
|
||||
def layer_norm(x, weight, bias, epsilon):
    """Apply ``DropoutAddLayerNormFn`` with no residual, no scaling and zero dropout."""
    residual = rowscale = colscale = None
    dropout_p = 0.0
    residual_in_fp32 = False
    return DropoutAddLayerNormFn.apply(
        x, residual, weight, bias, rowscale, colscale, dropout_p, epsilon, residual_in_fp32
    )
|
||||
|
||||
|
||||
def dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
                           layerscale=None, prenorm=False, residual_in_fp32=False,
                           return_dropout_mask=False):
    """Functional entry point for ``DropoutAddLayerNormFn`` in layer-norm mode.

    residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    is_rms_norm = False
    return DropoutAddLayerNormFn.apply(
        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon,
        residual_in_fp32, prenorm, is_rms_norm, return_dropout_mask
    )
|
||||
|
||||
|
||||
def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
                                  x0_subset=None, out_subset=None, rowscale_const=1.0,
                                  out_numrows=0, prenorm=False, residual_in_fp32=False,
                                  return_dropout_mask=False):
    """Functional entry point for ``DropoutAddLayerNormSubsetFn`` in layer-norm mode.

    residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    is_rms_norm = False
    return DropoutAddLayerNormSubsetFn.apply(
        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, residual_in_fp32, prenorm, is_rms_norm,
        return_dropout_mask
    )
|
||||
|
||||
|
||||
def dropout_add_layer_norm_parallel_residual(
    x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, prenorm=False,
    residual_in_fp32=False, return_dropout_mask=False
):
    """Functional entry point for ``DropoutAddLayerNormParallelResidualFn`` in layer-norm mode.

    residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    is_rms_norm = False
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon,
        residual_in_fp32, prenorm, is_rms_norm, return_dropout_mask
    )
|
||||
|
||||
|
||||
class DropoutAddLayerNorm(torch.nn.Module):
    """Module wrapper around ``dropout_add_layer_norm`` with a learnable weight and bias."""

    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
        self.reset_parameters()

    def reset_parameters(self):
        """Set the weight to ones and the bias to zeros."""
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        """Call ``dropout_add_layer_norm``; dropout is applied only in training mode."""
        dropout_p = self.p if self.training else 0.0
        return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
                                      dropout_p, self.eps,
                                      prenorm=self.prenorm,
                                      residual_in_fp32=self.residual_in_fp32)
|
||||
89
pkgs/xformers/_flash_attn/ops/rms_norm.py
Normal file
89
pkgs/xformers/_flash_attn/ops/rms_norm.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# Copyright (c) 2022, Tri Dao.
|
||||
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
|
||||
|
||||
import torch
|
||||
from torch.nn import init
|
||||
|
||||
from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn
|
||||
from flash_attn.ops.layer_norm import DropoutAddLayerNormParallelResidualFn
|
||||
|
||||
|
||||
def rms_norm(x, weight, epsilon):
    """Apply ``DropoutAddLayerNormFn`` in RMS-norm mode with no bias, residual or dropout."""
    residual = bias = rowscale = colscale = None
    dropout_p = 0.0
    residual_in_fp32 = False
    prenorm = False
    is_rms_norm = True
    return DropoutAddLayerNormFn.apply(x, residual, weight, bias, rowscale, colscale, dropout_p,
                                       epsilon, residual_in_fp32, prenorm, is_rms_norm)
|
||||
|
||||
|
||||
def dropout_add_rms_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
                         layerscale=None, prenorm=False, residual_in_fp32=False,
                         return_dropout_mask=False):
    """Functional entry point for ``DropoutAddLayerNormFn`` in RMS-norm mode.

    residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    is_rms_norm = True
    return DropoutAddLayerNormFn.apply(
        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon,
        residual_in_fp32, prenorm, is_rms_norm, return_dropout_mask
    )
|
||||
|
||||
|
||||
def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
                                x0_subset=None, out_subset=None, rowscale_const=1.0,
                                out_numrows=0, prenorm=False, residual_in_fp32=False,
                                return_dropout_mask=False):
    """Functional entry point for ``DropoutAddLayerNormSubsetFn`` in RMS-norm mode.

    residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    is_rms_norm = True
    return DropoutAddLayerNormSubsetFn.apply(
        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, residual_in_fp32, prenorm, is_rms_norm,
        return_dropout_mask
    )
|
||||
|
||||
|
||||
def dropout_add_rms_norm_parallel_residual(
    x0, x1, residual, weight0, bias0, weight1, bias1,
    dropout_p, epsilon, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
):
    """Functional entry point for ``DropoutAddLayerNormParallelResidualFn`` in RMS-norm mode.

    residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    is_rms_norm = True
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon,
        residual_in_fp32, prenorm, is_rms_norm, return_dropout_mask
    )
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
    """Module wrapper around ``rms_norm`` with a learnable weight and no bias."""

    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
        # Registered as None so that `self.bias` exists but holds no parameter.
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Set the weight to ones."""
        init.ones_(self.weight)

    def forward(self, x):
        """Apply ``rms_norm`` to ``x`` using this module's weight and eps."""
        return rms_norm(x, self.weight, self.eps)
|
||||
|
||||
|
||||
class DropoutAddRMSNorm(torch.nn.Module):
    """Module wrapper around ``dropout_add_rms_norm`` with a learnable weight and no bias."""

    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
        # Registered as None so that `self.bias` exists but holds no parameter.
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Set the weight to ones."""
        init.ones_(self.weight)

    def forward(self, x0, residual=None):
        """Call ``dropout_add_rms_norm``; dropout is applied only in training mode."""
        dropout_p = self.p if self.training else 0.0
        return dropout_add_rms_norm(x0, residual, self.weight, None,
                                    dropout_p, self.eps,
                                    prenorm=self.prenorm,
                                    residual_in_fp32=self.residual_in_fp32)
|
||||
0
pkgs/xformers/_flash_attn/utils/__init__.py
Normal file
0
pkgs/xformers/_flash_attn/utils/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
146
pkgs/xformers/_flash_attn/utils/benchmark.py
Normal file
146
pkgs/xformers/_flash_attn/utils/benchmark.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# Copyright (c) 2022, Tri Dao.
|
||||
""" Useful functions for writing test code. """
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
|
||||
def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False,
                      amp_dtype=torch.float16, **kwinputs):
    """Time the forward pass of ``fn`` with the PyTorch benchmark ``Timer``.

    ``fn`` is called ``repeats`` times as warmup and then timed for ``repeats`` runs,
    each call wrapped in CUDA autocast (enabled only when ``amp`` is true).

    Returns the ``(timer, measurement)`` pair.
    """
    if verbose:
        print(desc, '- Forward pass')

    def run_once(*inputs, **kwinputs):
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    for _ in range(repeats):  # warmup
        run_once(*inputs, **kwinputs)

    timer_globals = {'fn_amp': run_once, 'inputs': inputs, 'kwinputs': kwinputs}
    timer = benchmark.Timer(
        stmt='fn_amp(*inputs, **kwinputs)',
        globals=timer_globals,
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
|
||||
|
||||
|
||||
def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
                       amp_dtype=torch.float16, **kwinputs):
    """Time the backward pass of ``fn`` with the PyTorch benchmark ``Timer``.

    ``fn`` is run once (under CUDA autocast when ``amp`` is true); if it returns a
    tuple, the first element is used. ``grad`` defaults to random values shaped like
    the output.

    Returns the ``(timer, measurement)`` pair.

    Raises:
        RuntimeError: if ``grad`` is given and its shape differs from the output's.
    """
    if verbose:
        print(desc, '- Backward pass')
    with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    elif grad.shape != y.shape:
        raise RuntimeError('Grad shape does not match output shape')

    # retain_graph=True lets the same graph be back-propagated repeatedly.
    for _ in range(repeats):  # warmup
        y.backward(grad, retain_graph=True)

    timer = benchmark.Timer(
        stmt='y.backward(grad, retain_graph=True)',
        globals={'y': y, 'grad': grad},
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
|
||||
|
||||
|
||||
def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
                       amp_dtype=torch.float16, **kwinputs):
    """Time the forward + backward pass of ``fn`` with the PyTorch benchmark ``Timer``.

    Each timed call runs ``fn`` (under CUDA autocast when ``amp`` is true), takes the
    first element if the result is a tuple, and back-propagates ``grad`` (random values
    shaped like the output when ``grad`` is None).

    Returns the ``(timer, measurement)`` pair.

    Raises:
        RuntimeError: if ``grad`` is given and its shape differs from the output's.
    """
    if verbose:
        print(desc, '- Forward + Backward pass')

    def fwd_bwd(grad, *inputs, **kwinputs):
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            y = fn(*inputs, **kwinputs)
            if type(y) is tuple:
                y = y[0]
        if grad is None:
            grad = torch.randn_like(y)
        elif grad.shape != y.shape:
            raise RuntimeError('Grad shape does not match output shape')
        y.backward(grad, retain_graph=True)

    for _ in range(repeats):  # warmup
        fwd_bwd(grad, *inputs, **kwinputs)

    timer_globals = {'f': fwd_bwd, 'fn': fn, 'inputs': inputs, 'grad': grad,
                     'kwinputs': kwinputs}
    timer = benchmark.Timer(
        stmt='f(grad, *inputs, **kwinputs)',
        globals=timer_globals,
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
|
||||
|
||||
|
||||
def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
                  amp_dtype=torch.float16, **kwinputs):
    """Run the forward, backward and forward+backward benchmarks of ``fn`` in that order.

    Returns a 3-tuple with the results of ``benchmark_forward``, ``benchmark_backward``
    and ``benchmark_combined``.
    """
    forward_result = benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose,
                                       amp=amp, amp_dtype=amp_dtype, **kwinputs)
    backward_result = benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc,
                                         verbose=verbose, amp=amp, amp_dtype=amp_dtype,
                                         **kwinputs)
    combined_result = benchmark_combined(fn, *inputs, grad=grad, repeats=repeats, desc=desc,
                                         verbose=verbose, amp=amp, amp_dtype=amp_dtype,
                                         **kwinputs)
    return (forward_result, backward_result, combined_result)
|
||||
|
||||
|
||||
def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False,
                     amp_dtype=torch.float16, cpu=False, verbose=True, **kwinputs):
    """ Wrap benchmark functions in Pytorch profiler to see CUDA information.

    Runs 30 warm-up iterations, then profiles one more iteration (forward, plus
    backward when ``backward`` is true). CUDA activity is always recorded; CPU
    activity is added when ``cpu`` is true. When ``verbose`` is true the averaged
    table is printed, and when ``trace_filename`` is given a Chrome trace is exported.
    """
    if backward:
        # Build the gradient fed to backward() from one throwaway forward call.
        # NOTE(review): unlike benchmark_backward, a tuple-returning fn is not unwrapped
        # here, so this presumably only supports fns returning a single tensor — confirm.
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            g = torch.randn_like(fn(*inputs, **kwinputs))
    for _ in range(30):  # Warm up
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            if backward:
                # Clear gradients left on tensor inputs by the previous iteration.
                for x in inputs:
                    if isinstance(x, torch.Tensor):
                        x.grad = None
            # fn(*inputs, **kwinputs) if not backward else fn(*inputs, **kwinputs).backward(g)
            out = fn(*inputs, **kwinputs)
        # Backward should be done outside autocast
        if backward:
            out.backward(g)
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [torch.profiler.ProfilerActivity.CUDA]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        # profile_memory=True,
        with_stack=True,
    ) as prof:
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            if backward:
                # Clear gradients left on tensor inputs by the warm-up.
                for x in inputs:
                    if isinstance(x, torch.Tensor):
                        x.grad = None
            out = fn(*inputs, **kwinputs)
        if backward: out.backward(g)
    if verbose:
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)
|
||||
|
||||
|
||||
def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs):
    """Run ``fn`` once and report the peak CUDA memory allocated during the call.

    The CUDA cache is emptied and the peak-memory statistics are reset before the
    call; the device is synchronized before and after it.

    Returns the peak allocated memory in units of 2**30 bytes.
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    # max_memory_allocated() reports bytes. The previous divisor, (2 ** 20) * 1000,
    # mixed binary and decimal units and matched neither GB nor GiB.
    mem = torch.cuda.max_memory_allocated() / (2 ** 30)
    if verbose:
        print(f'{desc} max memory: {mem}GB')
    torch.cuda.empty_cache()
    return mem
|
||||
127
pkgs/xformers/_flash_attn/utils/distributed.py
Normal file
127
pkgs/xformers/_flash_attn/utils/distributed.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
# `all_gather_into_tensor` and `reduce_scatter_tensor` are the newer names for
# `_all_gather_base` and `_reduce_scatter_base`. They require a recent version
# of PyTorch. The aliases below provide backward compatibility with older
# PyTorch by pointing the new names at the old functions when they are missing.
if not hasattr(torch.distributed, "all_gather_into_tensor"):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if not hasattr(torch.distributed, "reduce_scatter_tensor"):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
|
||||
|
||||
|
||||
# Raw operation, does not support autograd, but does support async
|
||||
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    """All-gather ``input_`` along dim 0 across ``process_group``.

    The output buffer has ``world_size * input_.shape[0]`` rows and the same trailing
    dims, dtype and device as ``input_``.

    Returns ``(output, handle)``, where ``handle`` is whatever
    ``torch.distributed.all_gather_into_tensor`` returned.
    """
    world_size = torch.distributed.get_world_size(process_group)
    rows, *trailing = input_.shape
    output = torch.empty(world_size * rows, *trailing,
                         dtype=input_.dtype, device=input_.device)
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle
|
||||
|
||||
|
||||
# Raw operation, does not support autograd, but does support async
|
||||
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    """Reduce-scatter ``input_`` along dim 0 across ``process_group``.

    ``input_.shape[0]`` must be divisible by the world size; the output buffer has
    ``input_.shape[0] // world_size`` rows and the same trailing dims, dtype and device.

    Returns ``(output, handle)``, where ``handle`` is whatever
    ``torch.distributed.reduce_scatter_tensor`` returned.
    """
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    rows, *trailing = input_.shape
    output = torch.empty(rows // world_size, *trailing,
                         dtype=input_.dtype, device=input_.device)
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle
|
||||
|
||||
|
||||
# Raw operation, does not support autograd, but does support async
|
||||
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    """All-reduce a contiguous version of ``input_`` across ``process_group``.

    Returns ``(tensor, handle)``, where ``tensor`` is the contiguous tensor that was
    passed to ``torch.distributed.all_reduce`` and ``handle`` is that call's result.
    """
    buffer = input_.contiguous()
    handle = torch.distributed.all_reduce(buffer, group=process_group, async_op=async_op)
    return buffer, handle
|
||||
|
||||
|
||||
class AllGatherFunc(torch.autograd.Function):
    """All-gather in the forward pass; reduce-scatter of the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        gathered, _ = all_gather_raw(input_, process_group)
        return gathered

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        # No gradient for the process_group argument.
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply
|
||||
|
||||
|
||||
class ReduceScatterFunc(torch.autograd.Function):
    """Autograd-aware reduce-scatter: reduce-scatter in forward, all-gather the gradient in backward."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        # Keep the group around so backward communicates over the same ranks.
        ctx.process_group = process_group
        scattered, _handle = reduce_scatter_raw(input_, process_group)
        return scattered

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _handle = all_gather_raw(grad_output, ctx.process_group)
        # Second slot is the (non-differentiable) process_group argument.
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply
|
||||
|
||||
|
||||
class AllReduceFunc(torch.autograd.Function):
    """All-reduce the input across the process group.

    Forward performs the all-reduce via ``all_reduce_raw``; backward passes the incoming
    gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        # Gradient flows through as-is; the process_group argument gets no gradient.
        return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply
|
||||
|
||||
|
||||
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    """Broadcast every parameter flagged ``_shared_params=True`` from group rank 0.

    Parameters are visited in sorted-name order so that all ranks issue the broadcasts in the
    same order, as different ranks might have different number of parameters
    (e.g., only rank 0 has bias).
    """
    shared_by_name = {
        name: param
        for name, param in model.named_parameters()
        if getattr(param, '_shared_params', False)
    }
    for name in sorted(shared_by_name):
        param = shared_by_name[name]
        with torch.no_grad():
            # Broadcast needs src to be global rank, not group rank
            src_rank = torch.distributed.get_global_rank(process_group, 0)
            torch.distributed.broadcast(param, src=src_rank, group=process_group)
|
||||
|
||||
|
||||
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
|
||||
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    """All-reduce the gradients of parameters flagged ``_sequence_parallel=True``.

    Gradients are collected in sorted-name order (so every rank flattens them identically, as
    different ranks might have different number of parameters), flattened into one buffer,
    all-reduced over ``process_group`` and copied back into the original ``.grad`` tensors.
    Does nothing when no parameter carries the flag.
    """
    tagged = {
        name: param
        for name, param in model.named_parameters()
        if getattr(param, '_sequence_parallel', False)
    }
    grads = [tagged[name].grad for name in sorted(tagged)]
    if not grads:
        return
    with torch.no_grad():
        flat = torch._utils._flatten_dense_tensors(grads)
        torch.distributed.all_reduce(flat, group=process_group)
        reduced_views = torch._utils._unflatten_dense_tensors(flat, grads)
        for grad, reduced in zip(grads, reduced_views):
            grad.copy_(reduced)
|
||||
302
pkgs/xformers/_flash_attn/utils/generation.py
Normal file
302
pkgs/xformers/_flash_attn/utils/generation.py
Normal file
@@ -0,0 +1,302 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
|
||||
from typing import Optional, Union, Sequence, Callable
|
||||
import gc
|
||||
import time
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.profiler import profile, record_function, ProfilerActivity
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
|
||||
|
||||
|
||||
@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    # Required sizing information.
    max_sequence_len: int
    max_batch_size: int
    # Running offsets; both start at zero.
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    # Fresh dict per instance (default_factory avoids a shared mutable default).
    key_value_memory_dict: dict = field(default_factory=dict)
    fused_ft_kernel: bool = False
    lengths_per_sample: Optional[Tensor] = None
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
||||
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
|
||||
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for none top-p values to -inf. Done in-place.

    Arguments:
        logits: 2-D tensor (rows are filtered independently along the last dimension);
            modified in place.
        top_p: cumulative-probability mass to keep. Values <= 0 or >= 1 leave ``logits``
            untouched.
    Returns:
        None. The caller's ``logits`` tensor is mutated.
    """
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    # Must be the in-place variant: the out-of-place masked_fill would only rebind the local
    # name and the caller would never see the filtered logits.
    logits.masked_fill_(indices_to_remove, float('-inf'))
|
||||
|
||||
|
||||
def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Pick one token id per row of ``logits``.

    ``top_k == 1`` is greedy (argmax). ``top_k > 0`` restricts sampling to the k largest
    logits; any other ``top_k`` samples from the whole vocabulary. In the sampling paths the
    logits are divided by ``temperature`` and handed to ``modify_logits_for_top_p_filtering``.

    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    Returns:
        Tensor of shape (batch_size,) with the chosen token indices.
    """
    # Short-circuit for greedy decoding
    if top_k == 1:
        return logits.argmax(dim=-1)

    if top_p > 0.0:
        assert top_p <= 1.0, 'top-p should be in (0, 1].'

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # Safety check
        logits_top, indices = torch.topk(logits, top_k, dim=-1)
        logits_top /= temperature
        modify_logits_for_top_p_filtering(logits_top, top_p)
        rows = torch.arange(indices.shape[0], device=indices.device)
        probs = torch.softmax(logits_top, dim=-1)
        chosen = torch.multinomial(probs, num_samples=1).squeeze(dim=-1)
        # Map positions within the top-k slice back to vocabulary ids.
        return indices[rows, chosen]

    logits_top = logits / temperature
    modify_logits_for_top_p_filtering(logits_top, top_p)
    probs = torch.softmax(logits_top, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(dim=-1)
|
||||
|
||||
|
||||
def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
           eos_token_id=None, teacher_outputs=None, vocab_size=None, tensor_parallel=1,
           fused_ft_kernel=False, cg=False, timing=False):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
        eos_token_id (optional): decoding stops once every sequence in the batch emits it.
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
        vocab_size (optional): if given, logits are truncated to the first vocab_size entries.
        tensor_parallel: when > 1 and timing is on, ranks are synchronised with a barrier.
        fused_ft_kernel: stored on the InferenceParams handed to the model.
        cg: use the CUDA-graph cache kept in ``model._decoding_cache``
            (requires fused_ft_kernel=True).
        timing: print the wall-clock time of prompt processing + decoding.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, seq_len + number of generated tokens)
        scores: tuples of (batch, vocab_size), one per generated token
    Note: one token is produced from the prompt pass and the loop body always runs at least
    once, so at least two tokens are generated regardless of max_length.
    """
    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        assert fused_ft_kernel
        if not hasattr(model, '_decoding_cache'):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model, model._decoding_cache, batch_size, seqlen_og, max_length,
            tensor_parallel=tensor_parallel
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.max_sequence_len = max_length
        inference_params.max_batch_size = batch_size
        inference_params.sequence_len_offset = 0
    else:
        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
                                           fused_ft_kernel=fused_ft_kernel)
    scores = []
    with torch.inference_mode():
        if timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            start = time.time()
        # Prompt pass: process the whole prompt and pick the first new token.
        logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
        if vocab_size is not None:
            logits = logits[..., :vocab_size]
        # With CUDA graphs the logits buffer is reused on every replay, so keep a copy.
        scores.append(logits if not cg else logits.clone())
        if teacher_outputs is None or teacher_output_len <= seqlen_og:
            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        else:
            next_token = teacher_outputs[:, seqlen_og]
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
        while True:
            position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                      dtype=torch.long, device=input_ids.device)
            if not cg:
                logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                               inference_params=inference_params, last_token_only=True).logits
            else:
                logits = model._decoding_cache.run(rearrange(next_token, 'b -> b 1'), position_ids,
                                                   inference_params.sequence_len_offset)
            if vocab_size is not None:
                logits = logits[..., :vocab_size]
            scores.append(logits if not cg else logits.clone())
            if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1:
                # top_p must be forwarded here too; otherwise nucleus filtering would only be
                # requested for the very first generated token.
                next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
            else:
                next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
            sequences.append(next_token)
            inference_params.sequence_len_offset += 1
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break
            if inference_params.sequence_len_offset >= max_length - 1:
                break
        if timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
        scores=tuple(scores)
    )
|
||||
|
||||
|
||||
class GenerationMixin:
    """Mixin that adds a ``generate`` method built on top of ``decode``."""

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Hook for allocating the inference cache; this base version always raises."""
        raise NotImplementedError

    def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
                 return_dict_in_generate=False, output_scores=False, **kwargs):
        """Run ``decode`` with ``self`` as the model.

        Extra keyword arguments are forwarded to ``decode``. Scores are dropped from the output
        unless ``output_scores`` is set. The full output object is returned when
        ``return_dict_in_generate`` is truthy, otherwise only its ``sequences``.
        """
        result = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
                        temperature=temperature, **kwargs)
        if not output_scores:
            result.scores = None
        if return_dict_in_generate:
            return result
        return result.sequences
|
||||
|
||||
|
||||
def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence],
                             device, dtype=torch.float16):
    """Allocate uninitialised key/value cache tensors for each layer.

    Arguments:
        max_batch_size, max_seqlen, nheads, headdim: sizes used to build the cache shapes.
            headdim must be divisible by the pack size (4 for float32, 8 otherwise).
        layers: either the number of layers (ids ``0..layers-1``) or an explicit sequence of ids.
        device, dtype: where and how to allocate; dtype must be float16, bfloat16 or float32.
    Returns:
        dict mapping layer id -> ``(k_cache, v_cache)`` where
        k_cache has shape (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize) and
        v_cache has shape (max_batch_size, nheads, max_seqlen, headdim).
    """
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    if dtype == torch.float32:
        packsize = 4
    else:
        packsize = 8
    assert headdim % packsize == 0
    k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
    v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
    layer_ids = range(layers) if isinstance(layers, int) else layers
    kv_cache = {}
    for layer_id in layer_ids:
        k_cache = torch.empty(k_cache_shape, device=device, dtype=dtype)
        v_cache = torch.empty(v_cache_shape, device=device, dtype=dtype)
        kv_cache[layer_id] = (k_cache, v_cache)
    return kv_cache
|
||||
|
||||
|
||||
def seqlen_to_seqlen_type(seqlen: int) -> int:
    """Convert sequence length to a seqlen_type.
    This is used to determine which cuda graph to use.

    Arguments:
        seqlen: int
    Returns:
        0 for seqlen < 32, 1 for seqlen < 2048, otherwise 2.
    """
    if seqlen < 32:
        return 0
    if seqlen < 2048:
        return 1
    return 2
|
||||
|
||||
|
||||
def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
    """Return the sequence-length bound for a seqlen_type: 0 -> 32, 1 -> 2048, 2 -> 2**32.

    Any other seqlen_type fails the assertion.
    """
    assert seqlen_type in [0, 1, 2]
    if seqlen_type == 0:
        return 32
    if seqlen_type == 1:
        return 2048
    return 2**32
|
||||
|
||||
|
||||
@dataclass
class DecodingCGCache:
    """Mutable holder for the decoding state that ``update_graph_cache`` builds and reuses."""

    # Dataclass fields (their order defines the generated __init__ signature).
    max_batch_size: int = 0
    max_seqlen: int = 0
    callables: dict = field(default_factory=dict)
    inference_params: Optional[InferenceParams] = None
    run: Optional[Callable] = None

    # Unannotated on purpose: these are plain class-level defaults, not dataclass fields, so
    # they are not __init__ parameters; they are overwritten per instance by assignment.
    device = None
    dtype = None
    mempool = None
|
||||
|
||||
|
||||
@torch.inference_mode()
def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
                       dtype=None, n_warmups=2):
    """Create or refresh the decoding cache so it can serve this batch size / length.

    A new ``DecodingCGCache`` is created when ``cache`` is None. The cache is invalidated (graphs,
    memory pool and inference params rebuilt) when the device or dtype changed, or when
    ``batch_size`` / ``max_seqlen`` exceed what the cache was built for. Afterwards a graph is
    captured for every seqlen type between those of ``seqlen_og`` and ``max_seqlen`` that is
    not cached yet for this batch size, and ``cache.run`` is set to a dispatcher over them.

    Returns:
        The (possibly newly created) cache.
    """
    if cache is None:
        cache = DecodingCGCache()
    first_param = next(iter(model.parameters()))
    device = first_param.device
    if dtype is None:
        dtype = first_param.dtype
    needs_rebuild = (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    )
    if needs_rebuild:  # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        if hasattr(model, 'allocate_inference_cache'):
            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        else:
            config = model.config
            headdim = getattr(config, 'head_dim',
                              config.hidden_size // config.num_attention_heads)
            inf_cache = allocate_inference_cache(
                batch_size, max_seqlen, config.num_attention_heads // tensor_parallel, headdim,
                config.num_hidden_layers, device, dtype
            )
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_sequence_len=max_seqlen, max_batch_size=batch_size,
            sequence_len_offset=seqlen_og, key_value_memory_dict=inf_cache, fused_ft_kernel=True,
            lengths_per_sample=lengths_per_sample
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()

    first_type = seqlen_to_seqlen_type(seqlen_og)
    last_type = seqlen_to_seqlen_type(max_seqlen)
    for s_type in range(first_type, last_type + 1):
        if (batch_size, s_type) in cache.callables:
            continue
        # Bound for this bucket, clamped to [seqlen_og, max_seqlen].
        max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
        cache.callables[batch_size, s_type] = capture_graph(
            model, cache.inference_params, batch_size, max_seqlen_, mempool=cache.mempool,
            n_warmups=n_warmups
        )

    def dispatch(input_ids, position_ids, seqlen):
        # Pick the captured graph by the runtime batch size and the bucket of seqlen.
        batch_size = input_ids.shape[0]
        graph_fn = cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)]
        return graph_fn(input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
    return cache
|
||||
|
||||
|
||||
def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, n_warmups=2):
    """Capture a CUDA graph of one single-token decoding step and return a replay function.

    Static ``input_ids`` / ``position_ids`` buffers of shape (batch_size, 1) are allocated on the
    model's device; the model is warmed up ``n_warmups`` times on a side stream and then one
    forward pass is captured into a ``torch.cuda.CUDAGraph`` using ``mempool``.

    Returns:
        ``run(new_input_ids, new_position_ids, seqlen)``, which writes ``seqlen`` into
        ``inference_params.lengths_per_sample``, copies the new inputs into the static buffers,
        replays the graph and returns the captured logits tensor (the same tensor object on
        every call).
    """
    device = next(iter(model.parameters())).device
    input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    saved_offset = inference_params.sequence_len_offset
    # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
    # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
    inference_params.sequence_len_offset = max_seqlen - 1
    inference_params.lengths_per_sample[:] = max_seqlen - 1

    # Warmup before capture
    warmup_stream = torch.cuda.Stream()
    warmup_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(warmup_stream):
        for _ in range(n_warmups):
            logits = model(input_ids, position_ids=position_ids,
                           inference_params=inference_params, last_token_only=True).logits
        warmup_stream.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launch and non-captured launch to not overlap (I think,
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(warmup_stream)
    # Captures the graph
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(input_ids, position_ids=position_ids,
                       inference_params=inference_params, last_token_only=True).logits

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        return logits

    # Restore the offset that was temporarily overridden for warmup/capture.
    inference_params.sequence_len_offset = saved_offset
    return run
|
||||
37
pkgs/xformers/_flash_attn/utils/pretrained.py
Normal file
37
pkgs/xformers/_flash_attn/utils/pretrained.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import torch
|
||||
|
||||
from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
from transformers.utils import is_remote_url
|
||||
from transformers.modeling_utils import load_state_dict
|
||||
from transformers.utils.hub import cached_file, get_checkpoint_shard_files
|
||||
|
||||
|
||||
def state_dict_from_pretrained(model_name, device=None, dtype=None):
    """Load the state dict of a pretrained checkpoint, optionally casting and moving it.

    Looks for a single-file checkpoint (``WEIGHTS_NAME``) first and falls back to a sharded
    checkpoint index (``WEIGHTS_INDEX_NAME``).

    Arguments:
        model_name: passed to ``cached_file`` to resolve the checkpoint.
        device: final device of the returned tensors (None leaves them where they were loaded).
        dtype: if given, every tensor is converted to this dtype before moving to ``device``.
    Returns:
        dict mapping parameter names to tensors.
    Raises:
        EnvironmentError: if neither the weights file nor the sharded index can be resolved.
    """
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = 'cpu' if dtype not in [torch.float32, None] else device
    is_sharded = False
    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                        _raise_exceptions_for_missing_entries=False)
    if resolved_archive_file is None:
        resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
                                            _raise_exceptions_for_missing_entries=False)
        if resolved_archive_file is not None:
            is_sharded = True
    if resolved_archive_file is None:
        raise EnvironmentError(f"Model name {model_name} was not found.")
    # NOTE(security): torch.load unpickles the checkpoint; only use with trusted checkpoints.
    if is_sharded:
        # resolved_archive_file becomes a list of files that point to the different
        # checkpoint shards in this case.
        shard_files, _ = get_checkpoint_shard_files(model_name, resolved_archive_file)
        state_dict = {}
        for sharded_file in shard_files:
            state_dict.update(torch.load(sharded_file, map_location=mapped_device))
    else:
        # Use mapped_device here as well (as the sharded branch does), so a non-fp32 load is
        # staged on CPU instead of being materialised on the target device first. Reuse the
        # path that was already resolved above.
        state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
    # Convert dtype before moving to GPU to save memory
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict
|
||||
Reference in New Issue
Block a user