Files
xc-llm-ascend/vllm_ascend/attention.py
Mengqing Cao 7006835977 [attn] fix device of tensors in attention (#25)
### What this PR does / why we need it?
Fix device of tensors created in `AscendAttentionBackendImpl`.

When a device other than card 0 is specified, a **device conflict**
occurs because some tensors (such as `attn_mask`) are created on card 0
by default.

This PR creates these tensors on the same card as the input tensors.

### Does this PR introduce _any_ user-facing change?
With this PR, users can specify the device by local rank. A corresponding
change in vLLM is also needed; it will be linked to this PR once created.

### How was this patch tested?
Tested locally with the following code. A test case will be added once
the corresponding vLLM change is complete.
```python
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
# Create an LLM.
llm = LLM(model="~/.cache/modelscope/hub/Qwen/Qwen2___5-7B-Instruct", device="npu:1")

# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

Signed-off-by: MengqingCao <cmq0113@163.com>
2025-02-10 19:20:29 +08:00

686 lines
28 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/attention/backends
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import torch
try:
import torch_npu # noqa: F401
except ImportError:
print("Failed to import torch_npu.")
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
CommonMetadataBuilder,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
if TYPE_CHECKING:
from vllm_ascend.model_runner import ModelInputForNPUBuilder
# Module-level caches for the boolean prefill masks built by `gen_input_mask`.
# Both start as None and are populated lazily on first use, then reused.
# Fixed (1, 1, 2048, 2048) mask used when a batch has more than 16384 tokens.
SHARE_MASK_TRIL_PREFIX_CACHE: Optional[torch.Tensor] = None
# (seq_len x seq_len) causal mask (True = masked); rebuilt when too small.
SHARE_MASK_TRIL: Optional[torch.Tensor] = None
class AscendAttentionBackend(AttentionBackend):
    """Attention backend entry point for Ascend NPUs.

    Wires together the implementation (`AscendAttentionBackendImpl`), the
    metadata type (`AscendMetadata`) and its builder
    (`AscendMetadataBuilder`), and defines the paged KV-cache layout together
    with the block swap/copy helpers.
    """

    @staticmethod
    def get_name() -> str:
        return "ASCEND"

    @staticmethod
    def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
        return AscendAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["AscendMetadata"]:
        return AscendMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the per-layer KV-cache shape.

        Index 0 of the leading dimension holds keys and index 1 holds values;
        the KV heads and the head dimension are flattened into the last axis.
        """
        return (2, num_blocks, block_size, num_kv_heads * head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: List[torch.Tensor],
        dst_kv_cache: List[torch.Tensor],
        src_to_dst: torch.Tensor,
    ) -> None:
        """Copy cache blocks from ``src_kv_cache`` into ``dst_kv_cache``.

        Args:
            src_kv_cache: ``[key_cache, value_cache]`` to read from.
            dst_kv_cache: ``[key_cache, value_cache]`` to write into.
            src_to_dst: 2-D tensor whose column 0 holds source block indices
                and column 1 the matching destination block indices.

        Source blocks are moved to the device of the destination cache they
        are written into, so source and destination may live on different
        devices.
        """
        src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
        dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
        src_indices = src_to_dst[:, 0]
        dst_indices = src_to_dst[:, 1]
        dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
            dst_key_cache.device)
        # Values go to the value cache's own device (not the key cache's).
        dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
            dst_value_cache.device)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks in place within every layer cache in ``kv_caches``.

        ``src_to_dists`` column 0 holds source block indices and column 1
        the destination block indices; both keys and values are copied.
        """
        src_indices = src_to_dists[:, 0]
        dst_indices = src_to_dists[:, 1]
        for kv_cache in kv_caches:
            key_caches = kv_cache[0]
            value_caches = kv_cache[1]
            key_caches[dst_indices] = key_caches[src_indices]
            value_caches[dst_indices] = value_caches[src_indices]

    @staticmethod
    def get_builder_cls() -> Type["AscendMetadataBuilder"]:
        return AscendMetadataBuilder

    @classmethod
    def make_metadata_builder(cls, *args, **kwargs) -> "AscendMetadataBuilder":
        """Instantiate the metadata builder with the given arguments."""
        return cls.get_builder_cls()(*args, **kwargs)
class AscendPagedAttention(PagedAttention):
    """PagedAttention variant whose cache writes use NPU scatter updates."""

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_indices: torch.Tensor,
    ) -> None:
        """Scatter ``key`` into ``key_cache`` and ``value`` into ``value_cache``.

        Both updates are done in place with
        `torch_npu.npu_scatter_nd_update_`, addressed by ``slot_indices``
        (keys first, then values).
        """
        pairs = ((key_cache, key), (value_cache, value))
        for cache, update in pairs:
            torch_npu.npu_scatter_nd_update_(cache, slot_indices, update)
@dataclass
class AscendMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for the Ascend attention backend.

    Modified from vLLM's XFormers backend metadata. Besides the usual
    sequence-length bookkeeping it carries NPU-specific prefill inputs
    (``attn_mask``, ``pse_shift``, ``sparse_mode``) that
    `AscendAttentionBackendImpl.forward` fills in lazily and then reuses.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # FIXME: It is for flash attn.
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]] = None

    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None

    # Self-attention prefill/decode metadata cache, filled on first access
    # of the `prefill_metadata` / `decode_metadata` properties.
    _cached_prefill_metadata: Optional["AscendMetadata"] = None
    _cached_decode_metadata: Optional["AscendMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Prefill attention mask (True = masked position). Built on the first
    # prefill by `AscendAttentionBackendImpl.forward` and reused afterwards.
    attn_mask: Optional[torch.Tensor] = None
    # ALiBi bias for prefill; built by `forward` only when alibi slopes exist.
    pse_shift: Optional[torch.Tensor] = None
    # Passed to `npu_prompt_flash_attention`; `forward` sets it to 2 when the
    # batch has more than 16384 tokens.
    sparse_mode: int = 0

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    # NOTE: `slot_mapping` is not declared here; it is read from a base class
    # (presumably `AttentionMetadata`). On Ascend it holds per-token
    # [block_number, block_offset] pairs (see AscendMetadataBuilder).

    @property
    def prefill_metadata(self) -> Optional["AscendMetadata"]:
        """Metadata restricted to the prefill part of the batch.

        Slices per-sequence fields to the first ``num_prefills`` sequences
        and ``slot_mapping`` to the first ``num_prefill_tokens`` tokens.
        Returns None when there are no prefills. The result is cached.
        ``attn_mask``, ``pse_shift`` and ``sparse_mode`` are NOT copied; the
        returned object gets their defaults.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Recover cached prefill-phase attention
            # metadata structure
            return self._cached_prefill_metadata

        assert ((self.seq_lens is not None)
                or (self.encoder_seq_lens is not None))
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[:self.num_prefills + 1])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])

        # Construct & cache prefill-phase attention metadata structure
        self._cached_prefill_metadata = AscendMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables,
            enable_kv_scales_calculation=False)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["AscendMetadata"]:
        """Metadata restricted to the decode part of the batch.

        Keeps the sequences after the first ``num_prefills`` and the tokens
        after the first ``num_prefill_tokens``. Returns None when there are
        no decode tokens. The result is cached. ``seq_lens`` is not copied
        (stays None) and ``context_lens_tensor`` is explicitly None.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Recover cached decode-phase attention
            # metadata structure
            return self._cached_decode_metadata
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])

        # Construct & cache decode-phase attention metadata structure
        self._cached_decode_metadata = AscendMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            # Batch may be composed of prefill|decodes, adjust query start
            # indices to refer to the start of decodes. E.g.
            # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
            query_start_loc=(self.query_start_loc[self.num_prefills:] -
                             self.query_start_loc[self.num_prefills])
            if self.query_start_loc is not None else None,
            seq_start_loc=self.seq_start_loc[self.num_prefills:]
            if self.seq_start_loc is not None else None,
            context_lens_tensor=None,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables,
            enable_kv_scales_calculation=False)
        return self._cached_decode_metadata
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
    """Builds `AscendMetadata` from the model runner's sequence-group data.

    Differs from vLLM's common builder mainly in the slot mapping. Each entry
    is a ``[block_number, block_offset]`` pair rather than a flat slot number,
    and every sequence's entries are padded to the batch's maximum query
    length.
    """

    _metadata_cls = AscendMetadata

    def __init__(self, input_builder: "ModelInputForNPUBuilder"):
        self.input_builder = input_builder
        self.runner = input_builder.runner
        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def compute_npu_slot_indices(self, is_profile_run, slot_indices, seq_id,
                                 seq_len, context_len, start_idx, block_size,
                                 block_tables, max_query_len):
        """Append the cache slot indices of one sequence to ``slot_indices``.

        Other vLLM backends store flat slot numbers
        (``block_number * block_size + block_offset``) in the slot mapping.
        The Ascend backend stores ``[block_number, block_offset]`` pairs, so
        the name ``slot_indices`` is used to tell them apart.

        Outside of profiling, exactly ``max_query_len`` entries are appended
        for the sequence (assuming its query fits in ``max_query_len``).
        Tokens masked out by the sliding window, and the padding up to
        ``max_query_len``, use ``[PAD_SLOT_ID, 0]``.

        Args:
            is_profile_run: True during memory profiling, when block tables
                are not initialized; ``seq_len`` dummy entries are emitted.
            slot_indices: list extended in place.
            seq_id: key into ``block_tables``.
            seq_len: total length (computed tokens + new tokens).
            context_len: number of tokens already computed.
            start_idx: first token not masked by the sliding window.
            block_size: number of tokens per KV-cache block.
            block_tables: mapping from sequence id to its block numbers.
            max_query_len: length every sequence is padded to.
        """
        if is_profile_run:
            # During memory profiling, the block tables are not
            # initialized yet. In this case, we just use a dummy
            # slot mapping.
            # In embeddings, the block tables are {seq_id: None}.
            slot_indices.extend([[PAD_SLOT_ID, 0]] * seq_len)
            return

        # Mask the [0, start_idx) tokens of the prompt with
        # [PAD_SLOT_ID, 0], where start_idx is max(0, seq_len -
        # sliding_window). For example, if the prompt len is 10,
        # sliding window is 8, and block size is 4, the first two
        # tokens are masked.
        padding_mask_len = max(0, start_idx - context_len)
        slot_indices.extend([[PAD_SLOT_ID, 0]] * padding_mask_len)

        range_start = max(start_idx, context_len)
        range_end = seq_len
        block_table = block_tables[seq_id]
        for i in range(range_start, range_end):
            slot_indices.append([block_table[i // block_size], i % block_size])

        # Pad up to max_query_len. The sliding-window padding written above
        # must be counted too; otherwise such sequences get padding_mask_len
        # extra entries and no longer line up with the padded tokens.
        num_written = padding_mask_len + (range_end - range_start)
        slot_indices.extend([[PAD_SLOT_ID, 0]] *
                            (max_query_len - num_written))

    def _add_seq_group(
            self, inter_data: "ModelInputForNPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        # These depend only on the whole batch / this group, not on the
        # individual sequence, so compute them once instead of per sequence.
        all_inter_data = self.input_builder.inter_data_list
        max_query_len = max(max(data.query_lens) for data in all_inter_data)
        prefix_cache_hit = any(data.prefix_cache_hit
                               for data in all_inter_data)
        is_profile_run = is_block_tables_empty(block_tables)

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table: List[int] = []
            if prefix_cache_hit:
                # NOTE(woosuk): For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                if block_tables is not None:
                    block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            self.compute_npu_slot_indices(is_profile_run, self.slot_mapping,
                                          seq_id, seq_len, context_len,
                                          start_idx, self.block_size,
                                          inter_data.block_tables,
                                          max_query_len)
class AscendAttentionBackendImpl(AttentionImpl):
    """Decoder self-attention on Ascend NPUs via torch_npu flash-attention ops.

    Prefill uses `npu_prompt_flash_attention`. Prefix-cached prefill and
    decode use `npu_incre_flash_attention` over the paged KV cache.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Store the attention configuration.

        ``num_kv_heads`` falls back to ``num_heads`` when None, and
        ``num_heads`` must be divisible by it. ``alibi_slopes`` is converted
        to a float32 tensor. ``blocksparse_params`` and ``logits_soft_cap``
        are accepted for interface compatibility but unused.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.attn_type = attn_type

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: List[torch.Tensor],
        attn_metadata: AscendMetadata,
        attn_type: str = AttentionType.DECODER,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Ascend attention.

        Args:
            layer: attention layer; its ``_k_scale``/``_v_scale`` must be 1.0.
            query: shape = [num_tokens, num_heads * head_size]
                num_tokens = batch_size * seq_len
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks, block_size,
                               num_kv_heads * head_size]
                key_cache = [num_blocks, block_size,
                             num_kv_heads * head_size]
                value_cache = [num_blocks, block_size,
                               num_kv_heads * head_size]
            attn_metadata: Metadata for attention. Its ``attn_mask``,
                ``sparse_mode`` and ``pse_shift`` are filled in on the first
                prefill call and reused afterwards.
            attn_type: ignored; ``self.attn_type`` is used instead.
            output: not written to; returned unchanged only when the batch
                has neither prefills nor decode tokens.

        Returns:
            The attention output. The non-prefix prefill path returns
            [num_tokens, num_heads * head_size]; decode returns one row of
            num_heads * head_size per sequence.

        Raises:
            NotImplementedError: for non-decoder attention types.
        """
        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
        attn_type = self.attn_type
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "AscendAttentionBackendImpl")
        # view q k v to BSH
        num_tokens = query.shape[0]

        # An empty kv_cache (e.g. during profiling) skips the cache write.
        if kv_cache is not None and len(kv_cache) >= 2:
            slot_indices = attn_metadata.slot_mapping
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            AscendPagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                slot_indices,
            )

        if attn_metadata.num_prefills > 0:
            # Build the mask (and ALiBi bias) once per metadata object; all
            # layers sharing this metadata reuse them.
            if attn_metadata.attn_mask is None:
                if num_tokens > 16384:
                    attn_metadata.sparse_mode = 2
                attention_mask = gen_input_mask(
                    attn_metadata.max_prefill_seq_len, self.sliding_window,
                    num_tokens, query.device)
                attn_metadata.attn_mask = attention_mask

            if (self.alibi_slopes is not None
                    and attn_metadata.pse_shift is None):
                attn_metadata.pse_shift = _make_alibi_bias(
                    self.alibi_slopes,
                    self.num_kv_heads,
                    dtype=query.dtype,
                    seq_len=attn_metadata.max_prefill_seq_len,
                    batch_size=num_tokens,
                    device=query.device,
                )

            if (len(kv_cache) == 0 or attn_metadata.block_tables is None
                    or attn_metadata.block_tables.numel() == 0):
                max_seq_len = attn_metadata.max_prefill_seq_len
                # shape of q/k/v [B,S*H] --> [B,S,N,D]
                query = query.view(-1, max_seq_len, self.num_heads,
                                   self.head_size).transpose(1, 2)
                key = key.view(-1, max_seq_len, self.num_kv_heads,
                               self.head_size).transpose(1, 2)
                value = value.view(-1, max_seq_len, self.num_kv_heads,
                                   self.head_size).transpose(1, 2)
                # FA for prefill phase. Use the layer's configured softmax
                # scale, consistent with the prefix-cache and decode paths.
                output = torch_npu.npu_prompt_flash_attention(
                    query,
                    key,
                    value,
                    pse_shift=attn_metadata.pse_shift,
                    atten_mask=attn_metadata.attn_mask,
                    num_heads=self.num_heads,
                    scale_value=self.scale,
                    input_layout="BNSD",
                    num_key_value_heads=self.num_kv_heads,
                    pre_tokens=65535,
                    next_tokens=0,
                    sparse_mode=attn_metadata.sparse_mode,
                )
                # reshape to [B,H]
                output = output.transpose(1, 2).reshape(
                    num_tokens, self.num_heads * self.head_size)
            else:
                # prefix-enabled attention
                assert attn_type == AttentionType.DECODER, (
                    "Only decoder-only models support prefix caching")
                assert attn_metadata.seq_lens is not None
                assert kv_cache is not None
                query = query.view(query.shape[0], -1,
                                   self.num_heads * self.head_size)
                output = torch.zeros(query.shape,
                                     device=query.device,
                                     dtype=query.dtype)
                # TODO (Mengqing Cao): torch_npu.npu_incre_flash_attention
                # support only when `S == 1`, OPTIMIZE ME when prefix caching
                # is supported in torch-npu ops.
                for i in range(query.shape[0]):
                    # FA for prefill phase
                    output[i] = torch_npu.npu_incre_flash_attention(
                        query[i].unsqueeze(0),
                        key_cache,
                        value_cache,
                        num_heads=self.num_heads,
                        num_key_value_heads=self.num_kv_heads,
                        scale_value=self.scale,
                        input_layout="BSH",
                        block_table=attn_metadata.block_tables,
                        block_size=key_cache.
                        shape[1],  # max val of block_size == 512
                        actual_seq_lengths=attn_metadata.seq_lens,
                    )
                # [B,S,H] --> [B,H]
                output = output.squeeze(1)

        elif attn_metadata.decode_metadata:
            # FA for decoding phase
            assert kv_cache is not None
            # shape of query [B,S*H] --> [B,S,H]
            query = query.view(
                -1,
                1,
                self.head_size * self.num_heads,
            )
            output = torch_npu.npu_incre_flash_attention(
                query,
                key_cache,
                value_cache,
                num_heads=self.num_heads,
                num_key_value_heads=self.num_kv_heads,
                scale_value=self.scale,
                input_layout="BSH",
                block_table=attn_metadata.block_tables,
                block_size=key_cache.shape[1],  # max val of block_size == 512
                actual_seq_lengths=attn_metadata.seq_lens,
            )
            # [B,S,H] --> [B,H]
            output = output.squeeze(1)
        return output
def _mask_is_reusable_on(mask, device):
    """Return True if the cached ``mask`` exists and lives on ``device``.

    A device given without an index (e.g. ``"npu"``) accepts a mask on any
    device of that type, so callers passing a bare device type still hit
    the cache.
    """
    if mask is None:
        return False
    target = torch.device(device)
    return (mask.device.type == target.type
            and (target.index is None or mask.device.index == target.index))


def gen_input_mask(seq_len, sliding_window, len, device):
    """Return the boolean attention mask used by the prefill phase.

    ``True`` marks a position that must NOT be attended to. Masks are cached
    in module-level globals and reused across calls. A cached mask is rebuilt
    when it is missing, when it lives on a different device than ``device``
    (avoiding cross-card device conflicts), or, for the causal mask, when it
    is smaller than ``seq_len``.

    Args:
        seq_len: maximum prefill sequence length; the causal mask covers
            ``seq_len x seq_len`` positions. A larger cached mask is returned
            as-is (not sliced).
        sliding_window: optional sliding-window size in tokens. When set,
            positions more than ``sliding_window - 1`` tokens in the past
            are masked as well.
        len: total number of tokens in the batch (the name shadows the
            builtin; kept for backward compatibility). Above 16384 tokens a
            fixed ``(1, 1, 2048, 2048)`` strictly-upper-triangular mask is
            returned instead, and ``seq_len``/``sliding_window`` are ignored.
        device: device on which the mask is created.

    Returns:
        A ``torch.bool`` tensor.
    """
    global SHARE_MASK_TRIL_PREFIX_CACHE, SHARE_MASK_TRIL
    if len > 16384:
        # improve computing performance on NPU when input tokens are huge
        if not _mask_is_reusable_on(SHARE_MASK_TRIL_PREFIX_CACHE, device):
            SHARE_MASK_TRIL_PREFIX_CACHE = torch.triu(
                torch.ones(1, 1, 2048, 2048, dtype=bool, device=device),
                diagonal=1,
            )
        return SHARE_MASK_TRIL_PREFIX_CACHE

    if (not _mask_is_reusable_on(SHARE_MASK_TRIL, device)
            or SHARE_MASK_TRIL.shape[0] < seq_len):
        SHARE_MASK_TRIL = ~torch.tril(
            torch.ones(seq_len, seq_len, dtype=bool, device=device))

    attention_mask = SHARE_MASK_TRIL
    if sliding_window is not None:
        # Flip to "allowed" (lower triangle), drop entries further than
        # sliding_window - 1 below the diagonal, then flip back to "masked".
        attention_mask = ~attention_mask
        attention_mask = torch.triu(attention_mask,
                                    diagonal=1 - sliding_window)
        attention_mask = ~attention_mask
    return attention_mask
def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_len: int,
    batch_size: int,
    device: torch.device,
):
    """Build the ALiBi bias used as ``pse_shift`` by the prefill kernel.

    Entry ``[0, h, i, j]`` equals ``alibi_slopes[h] * (j - i)``. The tensor
    is a view whose last axis is backed by storage padded up to a multiple
    of 8 elements. When the number of slopes (query heads) differs from
    ``num_kv_heads``, the head axis is split into
    ``(num_kv_heads, num_heads // num_kv_heads)``.

    ``batch_size`` is accepted for interface compatibility but unused.
    """
    slopes = alibi_slopes.to(device)
    num_heads = slopes.shape[0]

    # NOTE(zhuohan): HF uses `bias = bias[None, :].repeat(seq_len, 1)`.
    # Both give the same results, but the relative form below follows the
    # original ALiBi paper more closely: relative[i, j] = j - i.
    positions = torch.arange(seq_len, dtype=dtype, device=device)
    relative = positions.unsqueeze(0) - positions.unsqueeze(1)

    # Over-allocate the key axis to the next multiple of 8 and fill only the
    # first seq_len columns, keeping the padded memory layout.
    aligned_len = -(-seq_len // 8) * 8
    bias = torch.empty(1, num_heads, seq_len, aligned_len,
                       device=device, dtype=dtype)[..., :seq_len]
    bias.copy_(relative)
    bias.mul_(slopes[:, None, None])

    if num_heads == num_kv_heads:
        return bias
    return bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))