First commit
This commit is contained in:
0
pkgs/xformers/_flash_attn/utils/__init__.py
Normal file
0
pkgs/xformers/_flash_attn/utils/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
146
pkgs/xformers/_flash_attn/utils/benchmark.py
Normal file
146
pkgs/xformers/_flash_attn/utils/benchmark.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# Copyright (c) 2022, Tri Dao.
|
||||
""" Useful functions for writing test code. """
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
|
||||
def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False,
                      amp_dtype=torch.float16, **kwinputs):
    """Time the forward pass of ``fn`` with ``torch.utils.benchmark``.

    ``fn(*inputs, **kwinputs)`` runs inside a CUDA ``torch.autocast`` context
    that is only active when ``amp`` is true. The call is executed ``repeats``
    times as a warmup and then measured with ``Timer.timeit(repeats)``.

    Returns:
        The ``(timer, measurement)`` pair.
    """
    if verbose:
        print(desc, '- Forward pass')

    def run_forward(*args, **kwargs):
        # The output is discarded on purpose; only the execution is of interest.
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            fn(*args, **kwargs)

    # Warmup runs before the timer is created.
    for _warmup in range(repeats):
        run_forward(*inputs, **kwinputs)

    timer_namespace = {'fn_amp': run_forward, 'inputs': inputs, 'kwinputs': kwinputs}
    timer = benchmark.Timer(
        stmt='fn_amp(*inputs, **kwinputs)',
        globals=timer_namespace,
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
|
||||
|
||||
|
||||
def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
                       amp_dtype=torch.float16, **kwinputs):
    """Use Pytorch Benchmark on the backward pass of an arbitrary function.

    The forward pass ``fn(*inputs, **kwinputs)`` is run once (under CUDA
    autocast when ``amp`` is true). If it returns a tuple, including tuple
    subclasses such as namedtuples, only the first element is differentiated.
    ``y.backward(grad, retain_graph=True)`` is then called ``repeats`` times
    as a warmup and measured with ``Timer.timeit(repeats)``.

    Arguments:
        grad: gradient w.r.t. the output; a random one is drawn with
            ``torch.randn_like`` when omitted.

    Returns:
        The ``(timer, measurement)`` pair.

    Raises:
        RuntimeError: if ``grad`` is given and its shape differs from the output.
    """
    if verbose:
        print(desc, '- Backward pass')
    with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        # isinstance (not an exact type check) so namedtuple-style outputs are
        # unwrapped too instead of failing later in randn_like.
        if isinstance(y, tuple):
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    elif grad.shape != y.shape:
        raise RuntimeError('Grad shape does not match output shape')
    # retain_graph=True keeps the graph alive so backward can be replayed.
    for _ in range(repeats):  # warmup
        y.backward(grad, retain_graph=True)
    t = benchmark.Timer(
        stmt='y.backward(grad, retain_graph=True)',
        globals={'y': y, 'grad': grad},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
|
||||
|
||||
|
||||
def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
                       amp_dtype=torch.float16, **kwinputs):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.

    Each timed iteration runs ``fn(*inputs, **kwinputs)`` (under CUDA autocast
    when ``amp`` is true) followed by ``backward``. If ``fn`` returns a tuple,
    including tuple subclasses such as namedtuples, only the first element is
    differentiated. The step is run ``repeats`` times as a warmup and then
    measured with ``Timer.timeit(repeats)``.

    Arguments:
        grad: gradient w.r.t. the output; when omitted a random one is drawn
            with ``torch.randn_like`` on every iteration.

    Returns:
        The ``(timer, measurement)`` pair.

    Raises:
        RuntimeError: if ``grad`` is given and its shape differs from the output.
    """
    if verbose:
        print(desc, '- Forward + Backward pass')

    def f(grad, *inputs, **kwinputs):
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            y = fn(*inputs, **kwinputs)
            # isinstance (not an exact type check) so namedtuple-style outputs
            # are unwrapped too instead of failing later in randn_like.
            if isinstance(y, tuple):
                y = y[0]
        if grad is None:
            grad = torch.randn_like(y)
        elif grad.shape != y.shape:
            raise RuntimeError('Grad shape does not match output shape')
        y.backward(grad, retain_graph=True)

    for _ in range(repeats):  # warmup
        f(grad, *inputs, **kwinputs)
    t = benchmark.Timer(
        stmt='f(grad, *inputs, **kwinputs)',
        globals={'f': f, 'fn': fn, 'inputs': inputs, 'grad': grad, 'kwinputs': kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
|
||||
|
||||
|
||||
def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
                  amp_dtype=torch.float16, **kwinputs):
    """Benchmark the forward, backward and forward+backward passes of ``fn``.

    Returns:
        A 3-tuple holding, in this order, the results of ``benchmark_forward``,
        ``benchmark_backward`` and ``benchmark_combined``; each one is itself a
        ``(timer, measurement)`` pair.
    """
    # Options shared by all three benchmarks.
    common = dict(repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype)
    forward_result = benchmark_forward(fn, *inputs, **common, **kwinputs)
    backward_result = benchmark_backward(fn, *inputs, grad=grad, **common, **kwinputs)
    combined_result = benchmark_combined(fn, *inputs, grad=grad, **common, **kwinputs)
    return forward_result, backward_result, combined_result
|
||||
|
||||
|
||||
def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False,
                     amp_dtype=torch.float16, cpu=False, verbose=True, **kwinputs):
    """ Wrap benchmark functions in Pytorch profiler to see CUDA information.

    Runs 30 warmup iterations of ``fn(*inputs, **kwinputs)`` (plus the backward
    pass when ``backward`` is true), then profiles one more iteration.

    Arguments:
        trace_filename: if not None, the Chrome trace is exported to this path.
        backward: also run ``out.backward(g)`` with a random ``g`` shaped like
            the output of ``fn``.
        amp / amp_dtype: enable CUDA autocast with the given dtype.
        cpu: additionally record CPU activity (CUDA activity is always recorded).
        verbose: print the profiler's key-averages table (at most 50 rows).
    """
    if backward:
        # One extra forward pass, only to obtain a gradient of the right shape/dtype.
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            g = torch.randn_like(fn(*inputs, **kwinputs))
    for _ in range(30):   # Warm up
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            if backward:
                # Clear gradients of tensor inputs before each backward pass.
                for x in inputs:
                    if isinstance(x, torch.Tensor):
                        x.grad = None
            # fn(*inputs, **kwinputs) if not backward else fn(*inputs, **kwinputs).backward(g)
            out = fn(*inputs, **kwinputs)
        # Backward should be done outside autocast
        if backward:
            out.backward(g)
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [torch.profiler.ProfilerActivity.CUDA]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        # profile_memory=True,
        with_stack=True,
    ) as prof:
        # The profiled iteration mirrors one warmup iteration.
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            if backward:
                for x in inputs:
                    if isinstance(x, torch.Tensor):
                        x.grad = None
            out = fn(*inputs, **kwinputs)
        if backward: out.backward(g)
    if verbose:
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)
|
||||
|
||||
|
||||
def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs):
    """Report the peak CUDA memory allocated while running ``fn`` once.

    The CUDA cache is emptied and the peak-memory statistics are reset before
    ``fn(*inputs, **kwinputs)`` is called; the cache is emptied again afterwards.

    Returns:
        ``torch.cuda.max_memory_allocated()`` divided by ``2**20 * 1000``
        (the value that is printed with a "GB" suffix when ``verbose``).
    """
    # NOTE: the divisor mixes a binary MiB with a decimal factor of 1000.
    divisor = (2 ** 20) * 1000

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

    fn(*inputs, **kwinputs)

    # Wait for the work launched by fn before reading the statistics.
    torch.cuda.synchronize()
    peak_bytes = torch.cuda.max_memory_allocated()
    mem = peak_bytes / divisor
    if verbose:
        print(f'{desc} max memory: {mem}GB')
    torch.cuda.empty_cache()
    return mem
|
||||
127
pkgs/xformers/_flash_attn/utils/distributed.py
Normal file
127
pkgs/xformers/_flash_attn/utils/distributed.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
|
||||
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
|
||||
# version of PyTorch. The following 4 lines are for backward compatibility with
|
||||
# older PyTorch.
|
||||
# Alias the underscore-prefixed collectives under the newer public names when
# the installed PyTorch does not provide them yet.
if not hasattr(torch.distributed, "all_gather_into_tensor"):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if not hasattr(torch.distributed, "reduce_scatter_tensor"):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
|
||||
|
||||
|
||||
# Raw operation, does not support autograd, but does support async
|
||||
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    """All-gather ``input_`` across ``process_group`` along the first dimension.

    The output buffer has ``world_size * input_.shape[0]`` rows and otherwise
    the shape, dtype and device of ``input_``.

    Returns:
        ``(output, handle)`` where ``handle`` is whatever
        ``torch.distributed.all_gather_into_tensor`` returned for ``async_op``.
    """
    n_ranks = torch.distributed.get_world_size(process_group)
    gathered_shape = (n_ranks * input_.shape[0], *input_.shape[1:])
    gathered = torch.empty(gathered_shape, dtype=input_.dtype, device=input_.device)
    work = torch.distributed.all_gather_into_tensor(
        gathered, input_.contiguous(), group=process_group, async_op=async_op
    )
    return gathered, work
|
||||
|
||||
|
||||
# Raw operation, does not support autograd, but does support async
|
||||
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    """Reduce-scatter ``input_`` across ``process_group`` along the first dimension.

    The first dimension of ``input_`` must be divisible by the world size; the
    output buffer has ``input_.shape[0] // world_size`` rows and otherwise the
    shape, dtype and device of ``input_``.

    Returns:
        ``(output, handle)`` where ``handle`` is whatever
        ``torch.distributed.reduce_scatter_tensor`` returned for ``async_op``.
    """
    n_ranks = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % n_ranks == 0
    scattered_shape = (input_.shape[0] // n_ranks, *input_.shape[1:])
    scattered = torch.empty(scattered_shape, dtype=input_.dtype, device=input_.device)
    work = torch.distributed.reduce_scatter_tensor(
        scattered, input_.contiguous(), group=process_group, async_op=async_op
    )
    return scattered, work
|
||||
|
||||
|
||||
# Raw operation, does not support autograd, but does support async
|
||||
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    """All-reduce ``input_`` across ``process_group``.

    The reduction is applied to ``input_.contiguous()``, and that tensor is
    what gets returned.

    Returns:
        ``(tensor, handle)`` where ``handle`` is whatever
        ``torch.distributed.all_reduce`` returned for ``async_op``.
    """
    reduced = input_.contiguous()
    work = torch.distributed.all_reduce(reduced, group=process_group, async_op=async_op)
    return reduced, work
|
||||
|
||||
|
||||
class AllGatherFunc(torch.autograd.Function):
    """Autograd-aware all-gather along the first dimension.

    Forward all-gathers the input over the process group; backward
    reduce-scatters the incoming gradient over the same group.
    """

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        # Remember the group: backward has to communicate over the same one.
        ctx.process_group = process_group
        gathered, _handle = all_gather_raw(input_, process_group)
        return gathered

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        scattered_grad, _handle = reduce_scatter_raw(grad_output, ctx.process_group)
        # The process group argument receives no gradient.
        return scattered_grad, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply
|
||||
|
||||
|
||||
class ReduceScatterFunc(torch.autograd.Function):
    """Autograd-aware reduce-scatter along the first dimension.

    Forward reduce-scatters the input over the process group; backward
    all-gathers the incoming gradient over the same group.
    """

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        # Remember the group: backward has to communicate over the same one.
        ctx.process_group = process_group
        scattered, _handle = reduce_scatter_raw(input_, process_group)
        return scattered

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        gathered_grad, _handle = all_gather_raw(grad_output, ctx.process_group)
        # The process group argument receives no gradient.
        return gathered_grad, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply
|
||||
|
||||
|
||||
class AllReduceFunc(torch.autograd.Function):
    """Autograd-aware all-reduce.

    Forward all-reduces the input over the process group; backward passes the
    incoming gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        reduced, _handle = all_reduce_raw(input_, process_group)
        return reduced

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        # Gradient flows straight through; the process group receives none.
        return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply
|
||||
|
||||
|
||||
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    """Broadcast every parameter flagged ``_shared_params`` from group rank 0.

    Parameters are visited in sorted-name order so that all ranks issue the
    broadcasts in the same order, even if they own a different number of
    parameters (e.g., only rank 0 has bias).
    """
    shared = {
        name: param
        for name, param in model.named_parameters()
        if getattr(param, '_shared_params', False)
    }
    with torch.no_grad():
        for name in sorted(shared):
            # Broadcast needs src to be global rank, not group rank
            src_rank = torch.distributed.get_global_rank(process_group, 0)
            torch.distributed.broadcast(shared[name], src=src_rank, group=process_group)
|
||||
|
||||
|
||||
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
|
||||
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    """All-reduce the gradients of parameters flagged ``_sequence_parallel``.

    The gradients are collected in sorted-name order (so every rank uses the
    same order), flattened into a single buffer, all-reduced over
    ``process_group`` and copied back into the original ``.grad`` tensors.
    Nothing happens when no parameter carries the flag.
    """
    seq_parallel = {
        name: param
        for name, param in model.named_parameters()
        if getattr(param, '_sequence_parallel', False)
    }
    grads = [seq_parallel[name].grad for name in sorted(seq_parallel)]
    if not grads:
        return
    with torch.no_grad():
        # One collective over a flat buffer instead of one per gradient.
        flat = torch._utils._flatten_dense_tensors(grads)
        torch.distributed.all_reduce(flat, group=process_group)
        synced_views = torch._utils._unflatten_dense_tensors(flat, grads)
        for grad_buf, synced in zip(grads, synced_views):
            grad_buf.copy_(synced)
|
||||
302
pkgs/xformers/_flash_attn/utils/generation.py
Normal file
302
pkgs/xformers/_flash_attn/utils/generation.py
Normal file
@@ -0,0 +1,302 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
|
||||
from typing import Optional, Union, Sequence, Callable
|
||||
import gc
|
||||
import time
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.profiler import profile, record_function, ProfilerActivity
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
|
||||
|
||||
|
||||
@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""
    # Maximum sequence length; decode() passes its max_length here.
    max_sequence_len: int
    # Maximum batch size; decode() passes the batch dimension of input_ids.
    max_batch_size: int
    # Number of positions already processed; decode() sets it to the prompt
    # length after the first forward pass and increments it per generated token.
    sequence_len_offset: int = 0
    # Not read or written elsewhere in this file.
    batch_size_offset: int = 0
    # update_graph_cache() stores the pre-allocated inference cache here;
    # otherwise it starts out as an empty dict.
    key_value_memory_dict: dict = field(default_factory=dict)
    # Forwarded from decode(); always True for the CUDA-graph path.
    # NOTE(review): presumably selects a fused attention kernel in the model - confirm.
    fused_ft_kernel: bool = False
    # Set by update_graph_cache() to an int32 tensor of shape (batch_size,).
    lengths_per_sample: Optional[Tensor] = None
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
||||
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
|
||||
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for non top-p values to -inf, in place.

    Arguments:
        logits: Tensor whose last dimension indexes the candidate tokens. It is
            modified in place.
        top_p: nucleus probability mass to keep. Values <= 0 or >= 1 leave
            ``logits`` untouched.

    Returns:
        None; callers must read the result from ``logits``.
    """
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    # Must be the in-place variant: rebinding the local name would leave the
    # caller's tensor unfiltered, since nothing is returned.
    logits.masked_fill_(indices_to_remove, float('-inf'))
|
||||
|
||||
|
||||
def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Pick one token id per row of ``logits``.

    ``top_k == 1`` is greedy decoding (argmax). Otherwise the logits are
    divided by ``temperature``, handed to the top-p filter and a token is drawn
    from the softmax; with ``top_k > 0`` only the ``top_k`` largest logits are
    candidates, with ``top_k <= 0`` the whole vocabulary is.

    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    # Short-circuit for greedy decoding
    if top_k == 1:
        return logits.argmax(dim=-1)

    if top_p > 0.0:
        assert top_p <= 1.0, 'top-p should be in (0, 1].'

    if top_k <= 0:
        # No top-k restriction: sample over the full vocabulary.
        scaled = logits / temperature
        modify_logits_for_top_p_filtering(scaled, top_p)
        probs = torch.softmax(scaled, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(dim=-1)

    k = min(top_k, logits.size(-1))  # Safety check
    top_logits, top_indices = torch.topk(logits, k, dim=-1)
    top_logits /= temperature
    modify_logits_for_top_p_filtering(top_logits, top_p)
    rows = torch.arange(top_indices.shape[0], device=top_indices.device)
    probs = torch.softmax(top_logits, dim=-1)
    picked = torch.multinomial(probs, num_samples=1).squeeze(dim=-1)
    # Map the position inside the top-k set back to a vocabulary id.
    return top_indices[rows, picked]
|
||||
|
||||
|
||||
def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
           eos_token_id=None, teacher_outputs=None, vocab_size=None, tensor_parallel=1,
           fused_ft_kernel=False, cg=False, timing=False):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
        eos_token_id (optional): decoding stops early once every sequence in the batch
            produced this token in the same step.
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
        vocab_size (optional): if given, logits are truncated to the first vocab_size entries.
        tensor_parallel: tensor-parallel world size; used for the timing barrier and for sizing
            the CUDA-graph inference cache.
        fused_ft_kernel: forwarded to InferenceParams; must be True when cg is True.
        cg: decode through the cached CUDA graphs stored on model._decoding_cache.
        timing: print the wall-clock time of prompt processing + decoding.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuples of (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        assert fused_ft_kernel
        if not hasattr(model, '_decoding_cache'):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model, model._decoding_cache, batch_size, seqlen_og, max_length,
            tensor_parallel=tensor_parallel
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.max_sequence_len = max_length
        inference_params.max_batch_size = batch_size
        inference_params.sequence_len_offset = 0
    else:
        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
                                           fused_ft_kernel=fused_ft_kernel)
    scores = []
    with torch.inference_mode():
        if timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            start = time.time()
        # Prompt processing: one forward pass over the whole prompt.
        logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
        if vocab_size is not None:
            logits = logits[..., :vocab_size]
        # The CUDA-graph path returns the same output buffer on every replay, so keep a copy.
        scores.append(logits if not cg else logits.clone())
        if teacher_outputs is None or teacher_output_len <= seqlen_og:
            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        else:
            next_token = teacher_outputs[:, seqlen_og]
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
        # Incremental decoding: feed one token per step.
        while True:
            position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                      dtype=torch.long, device=input_ids.device)
            if not cg:
                logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                               inference_params=inference_params, last_token_only=True).logits
            else:
                logits = model._decoding_cache.run(rearrange(next_token, 'b -> b 1'), position_ids,
                                                   inference_params.sequence_len_offset)
            if vocab_size is not None:
                logits = logits[..., :vocab_size]
            scores.append(logits if not cg else logits.clone())
            if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1:
                # top_p must be forwarded here as well; otherwise nucleus sampling would
                # only apply to the first generated token.
                next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
            else:
                next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
            sequences.append(next_token)
            inference_params.sequence_len_offset += 1
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break
            if inference_params.sequence_len_offset >= max_length - 1:
                break
        if timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
        scores=tuple(scores)
    )
|
||||
|
||||
|
||||
class GenerationMixin:
    """Mixin that adds a ``generate`` method built on top of ``decode``."""

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Hook for allocating the inference cache; not implemented by the mixin."""
        raise NotImplementedError()

    def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
                 return_dict_in_generate=False, output_scores=False, **kwargs):
        """Run ``decode`` on this model.

        Extra keyword arguments are forwarded to ``decode``. The scores are
        dropped unless ``output_scores`` is true. Returns the full output
        object when ``return_dict_in_generate`` is true, otherwise only its
        ``sequences``.
        """
        result = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
                        temperature=temperature, **kwargs)
        if not output_scores:
            result.scores = None
        if return_dict_in_generate:
            return result
        return result.sequences
|
||||
|
||||
|
||||
def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence],
                             device, dtype=torch.float16):
    """Allocate uninitialized key/value cache tensors for each layer.

    Keys use a packed layout of shape
    ``(max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)``
    with ``packsize`` 4 for float32 and 8 for float16/bfloat16; values have
    shape ``(max_batch_size, nheads, max_seqlen, headdim)``.

    Arguments:
        layers: either the number of layers (ids ``0..layers-1`` are used) or
            an explicit sequence of layer ids.

    Returns:
        Dict mapping each layer id to its ``(k_cache, v_cache)`` pair.
    """
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    if dtype == torch.float32:
        packsize = 4
    else:
        packsize = 8
    assert headdim % packsize == 0

    layer_ids = range(layers) if isinstance(layers, int) else layers
    k_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
    v_shape = (max_batch_size, nheads, max_seqlen, headdim)

    cache = {}
    for layer_id in layer_ids:
        k_cache = torch.empty(k_shape, device=device, dtype=dtype)
        v_cache = torch.empty(v_shape, device=device, dtype=dtype)
        cache[layer_id] = (k_cache, v_cache)
    return cache
|
||||
|
||||
|
||||
def seqlen_to_seqlen_type(seqlen: int) -> int:
    """Map a sequence length to its seqlen_type bucket.

    The bucket decides which cuda graph is used: 0 for lengths below 32,
    1 for lengths below 2048, and 2 for everything else.

    Arguments:
        seqlen: int
    """
    if seqlen < 32:
        return 0
    if seqlen < 2048:
        return 1
    return 2
|
||||
|
||||
|
||||
def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
    """Return the upper sequence-length bound of a seqlen_type bucket.

    Bucket 0 maps to 32, bucket 1 to 2048 and bucket 2 to 2**32.
    """
    assert seqlen_type in [0, 1, 2]
    if seqlen_type == 0:
        return 32
    if seqlen_type == 1:
        return 2048
    return 2**32
|
||||
|
||||
|
||||
@dataclass
class DecodingCGCache:
    # State kept between decode() calls for CUDA-graph decoding; it is created
    # and refreshed by update_graph_cache().

    # Largest batch size / sequence length the current cache was built for.
    max_batch_size: int = 0
    max_seqlen: int = 0
    # NOTE: `device`, `dtype` and `mempool` carry no type annotation, so they are
    # plain class attributes rather than dataclass fields (they are not
    # parameters of the generated __init__).
    device = None
    dtype = None
    # Captured graph runners, keyed by (batch_size, seqlen_type).
    callables: dict = field(default_factory=dict)
    # Graph memory pool handle passed to every graph capture.
    mempool = None
    # InferenceParams instance used while capturing and replaying the graphs.
    inference_params: Optional[InferenceParams] = None
    # Dispatch function run(input_ids, position_ids, seqlen) installed by
    # update_graph_cache().
    run: Optional[Callable] = None
|
||||
|
||||
|
||||
@torch.inference_mode()
def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
                       dtype=None, n_warmups=2):
    """Create or refresh the CUDA-graph decoding cache for ``model``.

    Arguments:
        cache: an existing DecodingCGCache, or None to create a new one.
        batch_size: batch size the graphs are captured for.
        seqlen_og: prompt length; selects the first seqlen_type bucket to capture.
        max_seqlen: maximum total sequence length; selects the last bucket.
        tensor_parallel: divides the number of attention heads when the
            module-level allocate_inference_cache is used.
        dtype: cache dtype; defaults to the dtype of the model's first parameter.
        n_warmups: number of warmup iterations before each graph capture.

    Returns:
        The (possibly newly created) cache, with ``run`` set to a dispatch function.
    """
    if cache is None:
        cache = DecodingCGCache()
    # Device and default dtype are taken from the first model parameter.
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if ((device, dtype) != (cache.device, cache.dtype) or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen):  # Invalidate the cache
        # Drop the references to the old graphs and buffers, then collect garbage
        # before the replacements are allocated.
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        # Prefer the model's own allocator; otherwise derive shapes from model.config.
        if hasattr(model, 'allocate_inference_cache'):
            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        else:
            headdim = getattr(model.config, 'head_dim',
                              model.config.hidden_size // model.config.num_attention_heads)
            inf_cache = allocate_inference_cache(
                batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
                model.config.num_hidden_layers, device, dtype
            )
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_sequence_len=max_seqlen, max_batch_size=batch_size,
            sequence_len_offset=seqlen_og, key_value_memory_dict=inf_cache, fused_ft_kernel=True,
            lengths_per_sample=lengths_per_sample
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    # Capture one graph per seqlen_type bucket between the prompt length and
    # max_seqlen, skipping the (batch_size, bucket) pairs that already exist.
    for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
        if (batch_size, s_type) not in cache.callables:
            # Bucket limit, raised to at least seqlen_og and capped at max_seqlen.
            max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
            cache.callables[batch_size, s_type] = capture_graph(
                model, cache.inference_params, batch_size, max_seqlen_, mempool=cache.mempool,
                n_warmups=n_warmups
            )

    def dispatch(input_ids, position_ids, seqlen):
        # Select the captured graph by the actual batch size and the bucket of seqlen.
        batch_size = input_ids.shape[0]
        return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
    return cache
|
||||
|
||||
|
||||
def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, n_warmups=2):
    """Capture a CUDA graph of a single-token forward pass of ``model``.

    Static ``input_ids`` / ``position_ids`` buffers of shape (batch_size, 1)
    are allocated on the model's device. The returned function copies new
    values into them, replays the graph and returns the logits tensor produced
    during capture.

    Arguments:
        inference_params: shared InferenceParams; its ``sequence_len_offset`` is
            temporarily changed and restored, and ``lengths_per_sample`` is
            overwritten in place.
        max_seqlen: sequence length the graph is captured for.
        mempool: memory pool handle passed to ``torch.cuda.graph``.
        n_warmups: number of warmup forward passes run before capture.

    Returns:
        ``run(new_input_ids, new_position_ids, seqlen)``.
    """
    device = next(iter(model.parameters())).device
    input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    sequence_len_offset_og = inference_params.sequence_len_offset
    # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
    # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
    inference_params.sequence_len_offset = max_seqlen - 1
    inference_params.lengths_per_sample[:] = max_seqlen - 1

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
                           last_token_only=True).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launch and non-captured launch to not overlap (I think,
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Captures the graph
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
                       last_token_only=True).logits

    def run(new_input_ids, new_position_ids, seqlen):
        # Update the static buffers in place, then replay the captured graph.
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        # Same tensor object on every call; callers that keep it must clone it.
        return logits

    # Restore the offset that was overridden for the capture.
    inference_params.sequence_len_offset = sequence_len_offset_og
    return run
|
||||
37
pkgs/xformers/_flash_attn/utils/pretrained.py
Normal file
37
pkgs/xformers/_flash_attn/utils/pretrained.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import torch
|
||||
|
||||
from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
from transformers.utils import is_remote_url
|
||||
from transformers.modeling_utils import load_state_dict
|
||||
from transformers.utils.hub import cached_file, get_checkpoint_shard_files
|
||||
|
||||
|
||||
def state_dict_from_pretrained(model_name, device=None, dtype=None):
    """Load the checkpoint weights of ``model_name`` into a state dict.

    Both single-file checkpoints (``WEIGHTS_NAME``) and sharded checkpoints
    (``WEIGHTS_INDEX_NAME``) are supported.

    Arguments:
        model_name: model identifier or path handed to ``cached_file``.
        device: device the returned tensors are moved to.
        dtype: if not None, every tensor is converted to this dtype.

    Returns:
        Dict mapping parameter names to tensors.

    Raises:
        EnvironmentError: if neither the weights file nor the index file is found.
    """
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = 'cpu' if dtype not in [torch.float32, None] else device
    is_sharded = False
    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                        _raise_exceptions_for_missing_entries=False)
    if resolved_archive_file is None:
        resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
                                            _raise_exceptions_for_missing_entries=False)
        if resolved_archive_file is not None:
            is_sharded = True
    if resolved_archive_file is None:
        raise EnvironmentError(f"Model name {model_name} was not found.")
    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # sources you trust.
    if is_sharded:
        # resolved_archive_file becomes a list of files that point to the different
        # checkpoint shards in this case.
        resolved_archive_file, _ = get_checkpoint_shard_files(
            model_name, resolved_archive_file
        )
        state_dict = {}
        for sharded_file in resolved_archive_file:
            state_dict.update(torch.load(sharded_file, map_location=mapped_device))
    else:
        # Use mapped_device here too (not `device`), so non-fp32 checkpoints are
        # staged on the CPU like the sharded ones; reuse the already-resolved path.
        state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
    # Convert dtype before moving to GPU to save memory
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict
|
||||
Reference in New Issue
Block a user