初始化项目,由ModelHub XC社区提供模型
Model: Finnish-NLP/Ahma-7B Source: Original Platform
This commit is contained in:
150
EasyLM/scripts/benchmark_attention.py
Normal file
150
EasyLM/scripts/benchmark_attention.py
Normal file
@@ -0,0 +1,150 @@
|
||||
from functools import partial
|
||||
from time import time
|
||||
import os
|
||||
import numpy as np
|
||||
import jax
|
||||
import jax.flatten_util
|
||||
import jax.numpy as jnp
|
||||
import mlxu
|
||||
from EasyLM.bpt import blockwise_attn
|
||||
from EasyLM.jax_utils import (
|
||||
get_float_dtype_by_name, set_random_seed, next_rng, JaxRNG
|
||||
)
|
||||
|
||||
|
||||
FLAGS, _ = mlxu.define_flags_with_default(
|
||||
seed=42,
|
||||
dtype='fp32',
|
||||
embed_dim=2048,
|
||||
n_heads=16,
|
||||
ref_attn_seq_len=2048,
|
||||
eff_attn_seq_len=16384,
|
||||
batch_size=1,
|
||||
query_chunk_size=2048,
|
||||
key_chunk_size=2048,
|
||||
warmup_steps=40,
|
||||
steps=200,
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
|
||||
def random_kqv(rng_key, seq_len):
|
||||
rng_generator = JaxRNG(rng_key)
|
||||
kqv = []
|
||||
for i in range(3):
|
||||
kqv.append(
|
||||
jax.random.normal(
|
||||
rng_generator(),
|
||||
(FLAGS.batch_size, seq_len, FLAGS.n_heads, FLAGS.embed_dim // FLAGS.n_heads),
|
||||
dtype=get_float_dtype_by_name(FLAGS.dtype)
|
||||
)
|
||||
)
|
||||
return tuple(kqv)
|
||||
|
||||
def reference_attn(query, key, value):
|
||||
dtype = get_float_dtype_by_name(FLAGS.dtype)
|
||||
query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
|
||||
logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
|
||||
mask_value = jnp.finfo(logits.dtype).min
|
||||
_, q_seq_len, _, _ = query.shape
|
||||
_, kv_seq_len, _, _ = key.shape
|
||||
mask_shape = (q_seq_len, kv_seq_len)
|
||||
row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
|
||||
col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
|
||||
causal_mask = (row_ids < col_ids)[None, None, :, :]
|
||||
logits = logits + jnp.where(causal_mask, mask_value, 0.0)
|
||||
weights = jax.nn.softmax(logits, axis=-1)
|
||||
out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
|
||||
return out
|
||||
|
||||
def efficient_attention(query, key, value):
|
||||
dtype = get_float_dtype_by_name(FLAGS.dtype)
|
||||
return blockwise_attn(
|
||||
query, key, value,
|
||||
bias=None,
|
||||
deterministic=True,
|
||||
dropout_rng=None,
|
||||
attn_pdrop=0.0,
|
||||
causal=True,
|
||||
query_chunk_size=FLAGS.query_chunk_size,
|
||||
key_chunk_size=FLAGS.key_chunk_size,
|
||||
dtype=get_float_dtype_by_name(FLAGS.dtype),
|
||||
policy=jax.checkpoint_policies.nothing_saveable(),
|
||||
precision=None,
|
||||
float32_logits=True,
|
||||
prevent_cse=True,
|
||||
)
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(1,))
|
||||
def reference_attn_forward_backward(rng_key, seq_len):
|
||||
@partial(jax.grad, argnums=(0, 1, 2))
|
||||
@partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable())
|
||||
def grad_fn(query, key, value):
|
||||
out = reference_attn(query, key, value)
|
||||
return jnp.mean(out)
|
||||
|
||||
query, key, value = random_kqv(rng_key, seq_len)
|
||||
return jax.flatten_util.ravel_pytree(
|
||||
grad_fn(query, key, value)[1]
|
||||
)[0].mean()
|
||||
|
||||
@partial(jax.jit, static_argnums=(1,))
|
||||
def efficient_attn_forward_backward(rng_key, seq_len):
|
||||
@partial(jax.grad, argnums=(0, 1, 2))
|
||||
def grad_fn(query, key, value):
|
||||
out = efficient_attention(query, key, value)
|
||||
return jnp.mean(out)
|
||||
|
||||
query, key, value = random_kqv(rng_key, seq_len)
|
||||
return jax.flatten_util.ravel_pytree(
|
||||
grad_fn(query, key, value)[1]
|
||||
)[0].mean()
|
||||
|
||||
|
||||
set_random_seed(FLAGS.seed)
|
||||
|
||||
jax.block_until_ready(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
|
||||
jax.block_until_ready(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
|
||||
|
||||
all_results = []
|
||||
for i in range(FLAGS.warmup_steps):
|
||||
all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
|
||||
jax.block_until_ready(all_results)
|
||||
|
||||
start_time = time()
|
||||
all_results = []
|
||||
for i in range(FLAGS.steps):
|
||||
all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
|
||||
|
||||
jax.block_until_ready(all_results)
|
||||
elapsed_time_ref_attn = time() - start_time
|
||||
print(f'Reference attention: {elapsed_time_ref_attn:.3f} seconds')
|
||||
|
||||
|
||||
all_results = []
|
||||
for i in range(FLAGS.warmup_steps):
|
||||
all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
|
||||
jax.block_until_ready(all_results)
|
||||
|
||||
|
||||
start_time = time()
|
||||
all_results = []
|
||||
for i in range(FLAGS.steps):
|
||||
all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
|
||||
|
||||
jax.block_until_ready(all_results)
|
||||
elapsed_time_efficient_attn = time() - start_time
|
||||
print(f'Efficient attention: {elapsed_time_efficient_attn:.3f} seconds')
|
||||
|
||||
flops_ratio = (FLAGS.eff_attn_seq_len / FLAGS.ref_attn_seq_len) ** 2
|
||||
efficiency = elapsed_time_ref_attn / elapsed_time_efficient_attn * flops_ratio
|
||||
print(f'Efficiency: {efficiency:.3f}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
mlxu.run(main)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user