初始化项目,由ModelHub XC社区提供模型
Model: Finnish-NLP/Ahma-7B Source: Original Platform
This commit is contained in:
228
EasyLM/bpt.py
Normal file
228
EasyLM/bpt.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
An implementation of Blockwise parallel transformer https://arxiv.org/abs/2305.19370
|
||||
Also include a reference implementation of memory-efficient transformer https://arxiv.org/abs/2112.05682
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import NamedTuple
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.lax as lax
|
||||
import jax.numpy as jnp
|
||||
from einops import rearrange
|
||||
|
||||
"""
|
||||
Computing ffn blockwise without materializing the large hidden tensor, training
|
||||
4x longer sequences than the memory-efficient transformer.
|
||||
Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023
|
||||
"""
|
||||
def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True):
|
||||
# remat_ffn: a rematerialized ffn with policy jax.checkpoint_policies.nothing_saveable()
|
||||
# inputs: (batch, seq_len, dim)
|
||||
# chunk_size: the chunk size to split the sequence
|
||||
inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)
|
||||
def scan_ffn(remat_ffn, carry, hidden_states):
|
||||
outputs = remat_ffn(hidden_states, deterministic=deterministic)
|
||||
return carry, outputs
|
||||
scan_axis = inputs.ndim - 2
|
||||
_, res = nn.scan(
|
||||
scan_ffn,
|
||||
variable_broadcast="params",
|
||||
split_rngs={"params": False, "dropout": True},
|
||||
in_axes=scan_axis,
|
||||
out_axes=scan_axis,
|
||||
)(remat_ffn, None, inputs)
|
||||
res = rearrange(res, 'b c n d -> b (c n) d')
|
||||
return res
|
||||
|
||||
|
||||
"""
|
||||
Compute attention blockwise without materializing the full attention matrix,
|
||||
initially proposed in memory-efficient transformer https://arxiv.org/abs/2112.05682 Rabe et al. 2021;
|
||||
flash attention https://arxiv.org/abs/2205.14135 Dao et al. 2022 proposes a CUDA
|
||||
efficient implementation; blockwise parallel transformer https://arxiv.org/abs/2305.19370
|
||||
Liu et al. 2023 proposes blockwise computing both attention and FFN, enabling 4x
|
||||
longer sequences than memory-efficient/flash-attention and fusion of attention and FFN.
|
||||
"""
|
||||
def blockwise_attn(
|
||||
query, key, value,
|
||||
bias=None,
|
||||
deterministic=True,
|
||||
dropout_rng=None,
|
||||
attn_pdrop=0.0,
|
||||
causal=True,
|
||||
query_chunk_size=2048,
|
||||
key_chunk_size=2048,
|
||||
dtype=jnp.float32,
|
||||
policy=jax.checkpoint_policies.nothing_saveable(),
|
||||
precision=None,
|
||||
float32_logits=True,
|
||||
prevent_cse=True,
|
||||
):
|
||||
# query, key, value: (batch, seq_len, num_heads, dim_per_head)
|
||||
# bias: (batch, seq_len) can be used to mask out attention (e.g. padding)
|
||||
# causal: whether to use causal mask
|
||||
# policy: one of jax.checkpoint_policies
|
||||
query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
|
||||
if float32_logits:
|
||||
query = query.astype(jnp.float32)
|
||||
key = key.astype(jnp.float32)
|
||||
|
||||
batch, q_len, num_heads, dim_per_head = query.shape
|
||||
batch, kv_len, num_heads, dim_per_head = key.shape
|
||||
batch, kv_len, num_heads, dim_per_head = value.shape
|
||||
|
||||
num_q = q_len // query_chunk_size
|
||||
num_kv = kv_len // key_chunk_size
|
||||
query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
|
||||
key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
|
||||
value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
|
||||
|
||||
query = jnp.moveaxis(query, 1, 0)
|
||||
key = jnp.moveaxis(key, 1, 0)
|
||||
value = jnp.moveaxis(value, 1, 0)
|
||||
|
||||
if bias is not None:
|
||||
for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
|
||||
assert bias_dim == 1 or bias_dim == broadcast_dim
|
||||
if not deterministic and attn_pdrop > 0.0:
|
||||
attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
|
||||
attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
|
||||
else:
|
||||
attn_dropout = None
|
||||
|
||||
_chunk_bias_fn = functools.partial(
|
||||
_chunk_attention_bias,
|
||||
query_chunk_size, key_chunk_size, bias, deterministic,
|
||||
attn_dropout, attn_pdrop, causal, dtype)
|
||||
|
||||
def scan_attention(args):
|
||||
query_chunk, query_chunk_idx = args
|
||||
|
||||
@functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
|
||||
def scan_kv_block(carry, args):
|
||||
key_chunk, value_chunk, key_chunk_idx = args
|
||||
(numerator, denominator, prev_max_score) = carry
|
||||
attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision)
|
||||
bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
|
||||
bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
|
||||
attn_weights = attn_weights + bias_chunk
|
||||
|
||||
max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
|
||||
max_score = jnp.maximum(prev_max_score, max_score)
|
||||
max_score = jax.lax.stop_gradient(max_score)
|
||||
exp_weights = jnp.exp(attn_weights - max_score)
|
||||
exp_values = jnp.einsum(
|
||||
'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
|
||||
)
|
||||
correction = jnp.exp(prev_max_score - max_score)
|
||||
numerator = numerator * correction + exp_values
|
||||
denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True)
|
||||
return Carry(numerator, denominator, max_score), None
|
||||
|
||||
def skip_upper_half(carry, args):
|
||||
key_chunk, value_chunk, key_chunk_idx = args
|
||||
skip_block = jnp.array(False)
|
||||
if causal:
|
||||
skip_block = query_chunk_idx < key_chunk_idx
|
||||
return jax.lax.cond(
|
||||
skip_block,
|
||||
lambda carry, args: (carry, None),
|
||||
scan_kv_block,
|
||||
carry,
|
||||
args,
|
||||
)
|
||||
|
||||
init_carry = Carry(
|
||||
jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
|
||||
jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
|
||||
(-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype),
|
||||
)
|
||||
(numerator, denominator, max_score), _ = lax.scan(
|
||||
skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
|
||||
)
|
||||
outputs = (numerator / denominator).astype(dtype)
|
||||
return outputs
|
||||
|
||||
_, res = lax.scan(
|
||||
lambda _, x: ((), scan_attention(x)),
|
||||
(), xs=(query, jnp.arange(0, num_q))
|
||||
)
|
||||
res = rearrange(res, 'n b c h d -> b (n c) h d')
|
||||
return res
|
||||
|
||||
|
||||
class Carry(NamedTuple):
|
||||
numerator: jax.Array
|
||||
denominator: jax.Array
|
||||
max_so_far: jax.Array
|
||||
|
||||
|
||||
def _chunk_attention_bias(query_chunk_size, key_chunk_size,
|
||||
bias, deterministic, attn_dropout, attn_pdrop, causal,
|
||||
dtype, query_chunk_idx, key_chunk_idx):
|
||||
query_offset = query_chunk_idx * query_chunk_size
|
||||
key_offset = key_chunk_idx * key_chunk_size
|
||||
chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
|
||||
if bias is not None:
|
||||
chunk_bias = lax.dynamic_slice(
|
||||
bias,
|
||||
start_indices=(0, 0, query_offset, key_offset),
|
||||
slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
|
||||
)
|
||||
|
||||
if causal:
|
||||
query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
|
||||
key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
|
||||
offset = query_offset - key_offset
|
||||
query_idx += offset
|
||||
causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
|
||||
chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)
|
||||
|
||||
if not deterministic and attn_pdrop > 0.0:
|
||||
attn_dropout_slice = lax.dynamic_slice(
|
||||
attn_dropout,
|
||||
start_indices=(0, 0, query_offset, key_offset),
|
||||
slice_sizes=(
|
||||
*attn_dropout.shape[:2],
|
||||
min(attn_dropout.shape[-2], query_chunk_size),
|
||||
min(attn_dropout.shape[-1], key_chunk_size),
|
||||
),
|
||||
)
|
||||
chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
|
||||
return chunk_bias.astype(dtype)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test
|
||||
def reference_attn(query, key, value, causal, dtype):
|
||||
query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
|
||||
logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
|
||||
if causal:
|
||||
mask_value = jnp.finfo(logits.dtype).min
|
||||
_, q_seq_len, _, _ = query.shape
|
||||
_, kv_seq_len, _, _ = key.shape
|
||||
mask_shape = (q_seq_len, kv_seq_len)
|
||||
row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
|
||||
col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
|
||||
causal_mask = (row_ids < col_ids)[None, None, :, :]
|
||||
logits = logits + jnp.where(causal_mask, mask_value, 0.0)
|
||||
weights = jax.nn.softmax(logits, axis=-1)
|
||||
out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
|
||||
return out
|
||||
|
||||
# random inputs
|
||||
shape = (1, 32, 8, 64)
|
||||
query = jax.random.normal(jax.random.PRNGKey(0), shape)
|
||||
key = jax.random.normal(jax.random.PRNGKey(1), shape)
|
||||
value = jax.random.normal(jax.random.PRNGKey(2), shape)
|
||||
|
||||
causal = True
|
||||
chunk_size = 4
|
||||
policy = jax.checkpoint_policies.nothing_saveable()
|
||||
|
||||
blockwise = blockwise_attn(query, key, value, None, False, None, 0.0, causal, chunk_size, chunk_size, jnp.float32, policy, 'float32', True, False)
|
||||
reference = reference_attn(query, key, value, causal, 'float32')
|
||||
|
||||
assert jnp.allclose(reference, blockwise, atol=1e-6)
|
||||
Reference in New Issue
Block a user