# vllm_mlu hijack: V2 MLU flash-attention backend — adds KV-cache scale
# tensors (int8 cache), an MLA single-latent-cache backend variant, and
# context-mlugraph (prefill graph) capture support.
import torch
|
|
|
|
from contextlib import contextmanager
|
|
from itertools import accumulate
|
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
|
|
|
import torch
|
|
|
|
from vllm import _mlu_ops as mlu_ops
|
|
from vllm.attention.backends.abstract import (AttentionMetadata,
|
|
AttentionType)
|
|
from vllm.attention.backends.utils import (
|
|
PAD_SLOT_ID, get_num_prefill_decode_query_kv_tokens,
|
|
get_seq_len_block_table_args)
|
|
from vllm.forward_context import get_forward_context
|
|
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
|
|
make_tensor_with_pad)
|
|
from vllm.attention.backends.mlu_attn import (
|
|
MLUFlashAttentionBackend, MLUFlashAttentionMetadataBuilder,
|
|
MLUFlashAttentionMetadata, MLUFlashAttentionImpl,
|
|
MLUFlashAttentionState, _get_query_key_seq_metadata,
|
|
_get_causal_option)
|
|
from vllm_mlu._mlu_utils import USE_PAGED
|
|
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
|
|
|
|
|
class MLUFlashAttentionBackend_V2(MLUFlashAttentionBackend):
    """V2 MLU flash-attention backend.

    Extends the base backend with per-block KV-cache scale tensors (used by
    the int8 cache path) and routes the impl/builder/state hooks to their
    V2 counterparts defined in this module.
    """

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        # Head sizes accepted by the MLU attention kernels.
        return [32, 64, 80, 96, 128, 160, 192, 224, 256, 512, 576]

    @staticmethod
    def get_impl_cls() -> Type["MLUFlashAttentionImpl_V2"]:
        return MLUFlashAttentionImpl_V2

    @staticmethod
    def get_builder_cls() -> Type["MLUFlashAttentionMetadataBuilder_V2"]:
        return MLUFlashAttentionMetadataBuilder_V2

    @staticmethod
    def get_state_cls() -> Type["MLUFlashAttentionState_V2"]:
        return MLUFlashAttentionState_V2

    @staticmethod
    def get_kv_cache_scale_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
    ) -> Tuple[int, ...]:
        # Leading dim 2: one scale plane for K and one for V.
        return (2, num_blocks, num_kv_heads, block_size)

    @staticmethod
    def copy_blocks(
        kv_caches: List[List[torch.Tensor]],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy cache blocks (and their scales, when present) per the
        src->dst mapping."""
        key_caches = []
        value_caches = []
        for kv_cache in kv_caches:
            key_caches.append(kv_cache[0][0])
            value_caches.append(kv_cache[0][1])
        mlu_ops.copy_blocks(key_caches, value_caches, src_to_dists)

        # Scale tensors only exist for the quantized cache; skip otherwise.
        scale_pairs = [kv_cache[1] for kv_cache in kv_caches]
        if scale_pairs and scale_pairs[0].numel() > 0:
            mlu_ops.copy_blocks([scales[0] for scales in scale_pairs],
                                [scales[1] for scales in scale_pairs],
                                src_to_dists)
|
|
|
|
|
|
class MLUMLAFlashAttentionBackend(MLUFlashAttentionBackend):
    """MLA variant of the MLU flash-attention backend.

    MLA keeps a single latent cache instead of separate K/V planes, so the
    cache and scale shapes carry a leading dimension of 1, and all block
    operations touch only that single cache.
    """

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        # One combined cache plane (no separate V).
        return (1, num_blocks, num_kv_heads, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap the (single) latent cache blocks between devices."""
        mlu_ops.swap_blocks(dst_kv_cache[0], src_kv_cache[0], src_to_dst)

    @staticmethod
    def get_kv_cache_scale_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
    ) -> Tuple[int, ...]:
        # Single scale plane, matching the single latent cache.
        return (1, num_blocks, num_kv_heads, block_size)

    @staticmethod
    def copy_blocks(
        kv_caches: List[List[torch.Tensor]],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy latent-cache blocks (and scales, if present); no V plane."""
        latent_caches = [cache[0][0] for cache in kv_caches]
        mlu_ops.copy_blocks(latent_caches, None, src_to_dists)

        scale_tensors = [cache[1] for cache in kv_caches]
        if scale_tensors and scale_tensors[0].numel() > 0:
            mlu_ops.copy_blocks([scale[0] for scale in scale_tensors],
                                None, src_to_dists)
|
|
|
|
|
|
class MLUFlashAttentionState_V2(MLUFlashAttentionState):
    """Attention state managing persistent device buffers for MLU graph
    capture, including an extra capture path for prefill ("context") graphs.
    """

    def __init__(self, runner: "ModelRunnerBase"):
        MLUFlashAttentionState.__init__(self, runner)

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        # Allocate the persistent device buffers that every captured graph
        # reads from, then release them after capture finishes.
        # NOTE(review): no try/finally — cleanup is skipped if the capture
        # body raises; confirm whether that matters for the runner.
        self._is_graph_capturing = True
        self._graph_slot_mapping = torch.full((max_batch_size, ),
                                              PAD_SLOT_ID if USE_PAGED else 0,
                                              dtype=torch.int32,
                                              device=self.runner.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=self.runner.device)
        yield
        self._is_graph_capturing = False
        del self._graph_slot_mapping
        del self._graph_seq_lens
        del self._graph_block_tables

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        """Build decode-only placeholder metadata (views into the persistent
        capture buffers) for capturing a graph at this batch size."""
        assert self._is_graph_capturing
        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            seq_lens=None,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_query_len=1,
            max_decode_query_len=1,
            max_prefill_seq_len=0,
            # Unpaged (linear) cache caps the capturable seq len at one block.
            max_decode_seq_len=(self.runner.max_seq_len_to_capture if USE_PAGED
                else min(self.runner.block_size, self.runner.max_seq_len_to_capture)),
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self._graph_block_tables[:batch_size],
            use_cuda_graph=True,
        )
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            assert self.runner.attn_backend.get_name() in\
                ["XFORMERS", "FLASH_ATTN"], \
                f"Expected attn_backend name to be either 'XFORMERS' or " \
                f"'FLASH_ATTN', but "\
                f"got '{self.runner.attn_backend.get_name()}'"
            self._update_captured_metadata_for_enc_dec_model(
                batch_size=batch_size, attn_metadata=attn_metadata)

        return attn_metadata

    def get_graph_input_buffers(
            self,
            attn_metadata: AttentionMetadata,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        """Return the tensors a captured graph reads, so the runner can copy
        fresh values into them before each replay."""
        input_buffers = {
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": None,
            "block_tables": None,
        }
        # Context (prefill) graphs and decode graphs keep their tensors in
        # different metadata sub-structures; pick the active one.
        if attn_metadata.num_prefills > 0:
            input_buffers["seq_lens_tensor"] = attn_metadata.prefill_metadata.seq_lens_tensor
            input_buffers["block_tables"] = attn_metadata.prefill_metadata.block_tables
        else:
            input_buffers["seq_lens_tensor"] = attn_metadata.decode_metadata.seq_lens_tensor
            input_buffers["block_tables"] = attn_metadata.decode_metadata.block_tables
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            assert self.runner.attn_backend.get_name() in\
                ["XFORMERS", "FLASH_ATTN"], \
                f"Expected attn_backend name to be either 'XFORMERS' or "\
                f"'FLASH_ATTN', but "\
                f"got '{self.runner.attn_backend.get_name()}'"
            self._add_additonal_input_buffers_for_enc_dec_model(
                attn_metadata=attn_metadata, input_buffers=input_buffers)
        return input_buffers

    def prepare_graph_input_buffers(
            self,
            input_buffers: Dict[str, Any],
            attn_metadata: AttentionMetadata,
            is_encoder_decoder_model: bool = False) -> None:
        """Copy the current step's metadata into the captured graph's
        persistent input buffers (async copies)."""
        metadata = attn_metadata.prefill_metadata if \
            attn_metadata.num_prefills > 0 else attn_metadata.decode_metadata

        input_buffers["seq_lens_tensor"].copy_(
            metadata.seq_lens_tensor, non_blocking=True)
        input_buffers["block_tables"].copy_(
            metadata.block_tables, non_blocking=True)
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            assert self.runner.attn_backend.get_name() in\
                ["XFORMERS", "FLASH_ATTN"], \
                f"Expected attn_backend name to be either 'XFORMERS' or "\
                f"'FLASH_ATTN', but "\
                f"got '{self.runner.attn_backend.get_name()}'"
            self._prepare_input_buffers_for_enc_dec_model(
                attn_metadata, input_buffers)

    @contextmanager
    def graph_capture_with_context(
        self,
        ctx_graph_batch_size: int,
        max_batch_size: int,
        max_num_tokens: int
    ):
        # Like graph_capture(), but slot mapping is sized by tokens (prefill
        # can have many tokens per sequence) and an extra block-table buffer
        # is allocated for the context (prefill) graph.
        self._is_graph_capturing = True
        self._graph_slot_mapping = torch.full((max_num_tokens, ),
                                              PAD_SLOT_ID if USE_PAGED else 0,
                                              dtype=torch.int32,
                                              device=self.runner.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        # block tables used for decode mlugraph input buffer
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=self.runner.device)
        # block tables used for context mlugraph input buffer
        self._ctx_graph_block_tables = torch.zeros((ctx_graph_batch_size, 0),
                                                   dtype=self._graph_block_tables.dtype,
                                                   device=self.runner.device)
        yield
        self._is_graph_capturing = False
        del self._graph_slot_mapping
        del self._graph_seq_lens
        del self._graph_block_tables
        del self._ctx_graph_block_tables

    def fill_seq_lens_tensor(
        self,
        seq_len: int
    ) -> None:
        # In-place fill of the captured per-sequence lengths buffer.
        self._graph_seq_lens.fill_(seq_len)

    def graph_capture_get_metadata_for_context(
        self,
        batch_size: int,
        seq_len: int,
        is_encoder_decoder_model: bool = False
    ) -> MLUFlashAttentionMetadata:
        """Build prefill-only placeholder metadata for capturing a context
        (prefill) graph at the given batch size and sequence length."""
        assert self._is_graph_capturing

        query_start_loc = torch.zeros(batch_size + 1,
                                      dtype=torch.int32,
                                      device=self.runner.device)
        seq_start_loc = torch.zeros(batch_size + 1,
                                    dtype=torch.int32,
                                    device=self.runner.device)
        context_lens_tensor = torch.zeros(batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        # Cumulative sums of the captured seq lens, written past the leading
        # zero so entry 0 stays 0 (standard varlen start-loc layout).
        torch.cumsum(self._graph_seq_lens[:batch_size],
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])
        torch.cumsum(self._graph_seq_lens[:batch_size],
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        num_tokens = batch_size * seq_len
        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=batch_size,
            num_prefill_tokens=num_tokens,
            num_decode_tokens=0,
            slot_mapping=self._graph_slot_mapping[:num_tokens],
            multi_modal_placeholder_index_maps=None,
            seq_lens=[seq_len] * batch_size,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_query_len=seq_len,
            max_decode_query_len=0,
            max_prefill_seq_len=seq_len,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=self._ctx_graph_block_tables,
            use_cuda_graph=True,
        )
        return attn_metadata
|
|
|
|
|
|
class MLUFlashAttentionMetadataBuilder_V2(MLUFlashAttentionMetadataBuilder):
    """Metadata builder that can additionally mark prefill-only batches as
    graph-replayable when they match the configured context-mlugraph shape."""

    def build(
        self,
        seq_lens: List[int],
        query_lens: List[int],
        cuda_graph_pad_size: int,
        batch_size: int
    ) -> MLUFlashAttentionMetadata:
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: Use origin func if do not use context mlugraph.
        '''
        if not self.runner.model_config.use_context_mlugraph():
            return super().build(seq_lens,
                                 query_lens,
                                 cuda_graph_pad_size,
                                 batch_size)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        prefix_cache_hit = any([
            inter_data.prefix_cache_hit
            for inter_data in self.input_builder.inter_data_list
        ])
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled,
                                prefix_cache_hit)

        device = self.runner.device
        # -1 is the sentinel for "no decode graph padding requested".
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        decode_query_lens = query_lens[self.num_prefills:]
        if len(decode_query_lens) > 0:
            max_decode_query_len = max(decode_query_lens)
        else:
            max_decode_query_len = 1
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        # Varlen start offsets: leading 0 followed by running sums.
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        num_seqs = len(seq_lens)
        if use_captured_graph:
            # Pad slot mapping out to the captured graph's batch size.
            self.slot_mapping.extend([
                (PAD_SLOT_ID if USE_PAGED else 0)] * cuda_graph_pad_size)
            # NOTE(review): `[] * n` is always `[]`, so this extend is a
            # no-op — likely intended `[[]] * cuda_graph_pad_size`; confirm
            # against _get_graph_runner_block_tables' expectations.
            self.block_tables.extend([] * cuda_graph_pad_size)
            num_decode_tokens = batch_size - self.num_prefill_tokens
            block_tables = self._get_graph_runner_block_tables(
                num_seqs, self.block_tables)
        else:
            if USE_PAGED:
                block_tables = make_tensor_with_pad(
                    self.block_tables,
                    pad=0,
                    dtype=torch.int,
                    device=device,
                )
            else:
                # Linear (unpaged) cache: rows are uniform, no padding needed.
                block_tables = make_tensor_without_pad(
                    self.block_tables,
                    dtype=torch.int,
                    device=device,
                )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: Check if we can use context mlugraph for the given input.
        '''
        # A prefill-only batch can replay the captured context graph only if
        # it exactly matches the configured (batch size, seq len) shape.
        if num_decode_tokens == 0 and self.num_prefills > 0:
            ctx_graph_bs, ctx_graph_seq_len = (
                self.runner.model_config.get_context_mlugraph_bs_and_seq())
            use_captured_graph = len(seq_lens) == ctx_graph_bs and all(
                seq_len == ctx_graph_seq_len for seq_len in seq_lens)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

        return MLUFlashAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
|
|
|
|
|
|
class MLUFlashAttentionImpl_V2(MLUFlashAttentionImpl):
    """Attention impl that dispatches through the registered V2 custom op."""

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: List[torch.Tensor],
        attn_metadata: MLUFlashAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
        use_mla: bool = False,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
        assert k_scale == 1.0 and v_scale == 1.0, (
            "key/v_scale is not supported in FlashAttention.")

        # Validate that the metadata matches the requested attention type
        # before handing off to the custom op.
        if attn_type == AttentionType.ENCODER:
            if not attn_metadata.is_all_encoder_attn_metadata_set:
                raise AttributeError("Encoder attention requires setting "
                                     "encoder metadata attributes.")
        elif attn_type == AttentionType.ENCODER_DECODER:
            if not attn_metadata.is_all_cross_attn_metadata_set:
                raise AttributeError("Encoder/decoder cross-attention "
                                     "requires setting cross-attention "
                                     "metadata attributes.")

        # The enum is passed as its integer value so the op schema stays
        # torch-serializable.
        return torch.ops.vllm.unified_flash_attention_v2(
            query,
            key,
            value,
            self.num_heads,
            self.head_size,
            self.num_kv_heads,
            kv_cache,
            self.kv_cache_dtype,
            k_scale,
            v_scale,
            self.scale,
            attn_type.value,
            self.sliding_window,
            self.alibi_slopes,
            self.logits_soft_cap,
            use_mla,
        )
|
|
|
|
|
|
def unified_flash_attention_v2(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: List[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    attn_type_int_val: int,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
    use_mla: bool = False,
) -> torch.Tensor:
    """Custom-op body: write new K/V into the cache, then run the prefill
    and/or decode attention kernels for the current batch.

    Attention metadata is fetched from the forward context rather than
    passed as an argument (keeps the op schema tensor-only).

    Args:
        query/key/value: flat [num_tokens, heads * head_size] tensors.
        kv_cache: [cache_tensor, cache_scale_tensor] pair; empty during
            the initial memory-profiling run.
        attn_type_int_val: integer value of an AttentionType member.
        use_mla: MLA mode — single latent cache, V-sized prefill output.
    Returns:
        Flat [num_tokens, hidden] output (V-sized hidden for MLA prefill).
    """
    # Convert integer attn_type to enum
    try:
        attn_type = AttentionType(attn_type_int_val)
    except ValueError as err:
        raise AttributeError(
            f"Invalid attention type {str(attn_type_int_val)}") from err

    current_metadata = get_forward_context()
    assert current_metadata is not None
    assert isinstance(current_metadata, MLUFlashAttentionMetadata)
    attn_metadata: MLUFlashAttentionMetadata = current_metadata

    num_tokens, hidden_size = query.shape

    # Reshape the query, key, and value tensors.
    query = query.view(-1, num_heads, head_size)
    # NOTE(review): value.size(1) is read before the None-check below, so
    # this assumes value is never None here (cross-attn decode may pass
    # None K/V) — confirm.
    v_head_size = value.size(1) // num_kv_heads
    if (key is not None) and (value is not None):
        key = key.view(-1, num_kv_heads, head_size)
        if use_mla and attn_metadata.prefill_metadata:
            # MLA prefill: V has its own (smaller) head size than K.
            value = value.view(-1, num_kv_heads, v_head_size)
        else:
            value = value.view(-1, num_kv_heads, head_size)

    if kv_cache[0].numel() > 0:
        kv_cache_, kv_cache_scale_ = kv_cache
        key_cache = kv_cache_[0]
        # MLA keeps a single latent cache: no separate V cache or V scale.
        value_cache = None if use_mla else kv_cache_[1]
        key_cache_scale, value_cache_scale = None, None
        if kv_cache_scale_.numel() > 0:
            key_cache_scale = kv_cache_scale_[0]
            value_cache_scale = None if use_mla else kv_cache_scale_[1]

        # We skip updating the KV cache under two conditions:
        # a. When the Attention Type is ENCODER. In this phase, we compute
        # only the encoder attention without updating the cache.
        # b. When both Key and Value are None. This occurs during
        # cross-attention computation in the decoding phase, where the KV
        # cache is already populated with the cross-attention tensor.
        # Thus, we skip cache updates during this time.
        if (attn_type != AttentionType.ENCODER) and (key is not None) and (
                value is not None):
            if attn_type == AttentionType.ENCODER_DECODER:
                # Update cross-attention KV cache (prefill-only)
                updated_slot_mapping = attn_metadata.cross_slot_mapping
            else:
                # Update self-attention KV cache (prefill/decode)
                updated_slot_mapping = attn_metadata.slot_mapping

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            if USE_PAGED:
                value_to_cache = None if use_mla else value
                if use_mla and attn_metadata.prefill_metadata:
                    # MLA save cache info in models before flashattn
                    pass
                else:
                    # int8 cache quantizes on write and records scales.
                    if kv_cache_dtype == 'int8':
                        mlu_ops.quant_to_paged_cache(key,
                                                     value_to_cache,
                                                     key_cache,
                                                     value_cache,
                                                     key_cache_scale,
                                                     value_cache_scale,
                                                     updated_slot_mapping.flatten())
                    else:
                        mlu_ops.reshape_paged_cache(key,
                                                    value_to_cache,
                                                    key_cache,
                                                    value_cache,
                                                    updated_slot_mapping.flatten())
            else:
                # unpaged (linear cache) path
                if use_mla:
                    # MLA: mirror the handling of the paged path.
                    # key_cache: (num_blocks, 1, block_size, 576)
                    value_to_cache = None
                    if attn_metadata.prefill_metadata:
                        # MLA prefill cache was already written in
                        # forward_prefill, so skip it here.
                        pass
                    else:
                        if kv_cache_dtype == 'int8':
                            mlu_ops.quant_to_paged_cache(
                                key, value_to_cache,
                                key_cache, value_cache,
                                key_cache_scale, value_cache_scale,
                                updated_slot_mapping.flatten())
                        else:
                            mlu_ops.reshape_paged_cache(
                                key, value_to_cache,
                                key_cache, value_cache,
                                updated_slot_mapping.flatten())
                else:
                    # FIXME: After TMO-1496 is completed, remove this code.
                    if key.stride() != value.stride():
                        key = key.contiguous()
                        value = value.contiguous()
                    if kv_cache_dtype == 'int8':
                        mlu_ops.quant_to_linear_cache(
                            key, value,
                            key_cache, value_cache,
                            key_cache_scale, value_cache_scale,
                            attn_metadata.cu_seq_lens,
                            attn_metadata.max_seq_len,
                            True, None,
                            attn_metadata.batch_ids,
                            attn_metadata.slot_mapping_unpaged)
                    else:
                        mlu_ops.reshape_linear_cache(
                            key, value,
                            key_cache, value_cache,
                            attn_metadata.cu_seq_lens,
                            attn_metadata.max_seq_len,
                            True, None,
                            attn_metadata.batch_ids,
                            attn_metadata.slot_mapping_unpaged)
    if use_mla and attn_metadata.prefill_metadata:
        # MLA prefill output uses the (smaller) V head size, so it cannot
        # reuse query's layout via empty_like.
        output = torch.empty(query.shape[0], query.shape[1], v_head_size, dtype=query.dtype, device="mlu")
    else:
        output = torch.empty_like(query)
    (num_prefill_query_tokens, num_prefill_kv_tokens,
    num_decode_query_tokens) = \
        get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
    decode_query = query[num_prefill_query_tokens:]
    # QKV for prefill.
    query = query[:num_prefill_query_tokens]
    assert query.shape[0] == num_prefill_query_tokens
    assert decode_query.shape[0] == num_decode_query_tokens

    if prefill_meta := attn_metadata.prefill_metadata:
        # ALiBi slopes are replicated per prefill sequence.
        # NOTE(review): this rebinds alibi_slopes; if a decode batch follows
        # in the same call, the decode branch repeats the already-repeated
        # tensor — confirm mixed prefill+decode batches are handled upstream.
        alibi_slopes = None if alibi_slopes is None else \
            alibi_slopes.repeat(attn_metadata.num_prefills, 1)
        # Prompt run.
        if (kv_cache[0].numel() == 0 or prefill_meta.block_tables is None
                or prefill_meta.block_tables.numel() == 0):
            # normal attention
            # When block_tables are not filled, it means q and k are the
            # prompt, and they have the same length.
            q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
                _get_query_key_seq_metadata(prefill_meta, True, attn_type)

            key = key[:num_prefill_kv_tokens]
            value = value[:num_prefill_kv_tokens]
            mlu_ops.flash_attention(query,
                                    key,
                                    value,
                                    output[:num_prefill_query_tokens],
                                    q_seq_start_loc,
                                    k_seq_start_loc,
                                    alibi_slopes,
                                    None,
                                    q_seq_len,
                                    k_seq_len,
                                    softmax_scale,
                                    _get_causal_option(attn_type),
                                    -1 if window_size is None \
                                        else window_size[0],
                                    -1 if window_size is None \
                                        else window_size[1],
                                    torch.float,
                                    False)
        else:
            # prefix-enabled attention
            assert attn_type == AttentionType.DECODER, (
                "Only decoder-only models support prefix caching")
            assert prefill_meta.seq_lens is not None
            max_seq_len = max(prefill_meta.seq_lens)
            # NOTE(review): the output slice uses num_prefill_kv_tokens; for
            # decoder-only (asserted above) this equals
            # num_prefill_query_tokens, but the other branches slice by
            # query tokens — confirm intent.
            mlu_ops.flash_attention(query,
                                    key_cache,
                                    value_cache,
                                    output[:num_prefill_kv_tokens],
                                    prefill_meta.query_start_loc,
                                    prefill_meta.seq_start_loc,
                                    alibi_slopes,
                                    None,
                                    prefill_meta.max_query_len,
                                    max_seq_len,
                                    softmax_scale,
                                    True,
                                    -1 if window_size is None \
                                        else window_size[0],
                                    -1 if window_size is None \
                                        else window_size[1],
                                    torch.float,
                                    False,
                                    prefill_meta.block_tables)

    if decode_meta := attn_metadata.decode_metadata:
        # Decoding run.
        alibi_slopes = None if alibi_slopes is None \
            else alibi_slopes.repeat(attn_metadata.num_decode_tokens, 1)
        decode_query = decode_query.view(-1, 1, num_heads, head_size)
        decode_out = output[num_prefill_query_tokens:].view(-1, 1, num_heads, head_size)
        # Use flash_attn_varlen_func kernel for speculative decoding
        # because different queries might have different lengths.
        assert decode_meta.max_decode_query_len is not None
        if decode_meta.max_decode_query_len > 1:
            assert attn_type == AttentionType.DECODER, (
                "Only decoder-only models support max_decode_query_len > 1")
            mlu_ops.flash_attention(decode_query,
                                    key_cache,
                                    value_cache,
                                    decode_out,
                                    decode_meta.query_start_loc,
                                    decode_meta.seq_start_loc,
                                    alibi_slopes,
                                    None,
                                    decode_meta.max_decode_query_len,
                                    decode_meta.max_decode_seq_len,
                                    softmax_scale,
                                    True,
                                    -1 if window_size is None \
                                        else window_size[0],
                                    -1 if window_size is None \
                                        else window_size[1],
                                    torch.float,
                                    False,
                                    decode_meta.block_tables)
        else:
            # Use flash_attn_with_kvcache for normal decoding.
            (
                seq_lens_arg,
                max_context_len,
                block_tables_arg,
            ) = get_seq_len_block_table_args(decode_meta, False, attn_type)
            if use_mla:
                # MLA decode reads V out of the same latent cache as K.
                value_cache = key_cache
                value_cache_scale = key_cache_scale
            mlu_ops.single_query_cached_kv_attn(decode_query,
                                                key_cache,
                                                value_cache,
                                                decode_out,
                                                block_tables_arg,
                                                seq_lens_arg,
                                                key_cache_scale,
                                                value_cache_scale,
                                                alibi_slopes,
                                                max_context_len,
                                                -1 if window_size is None \
                                                    else window_size[0],
                                                -1 if window_size is None \
                                                    else window_size[1],
                                                softmax_scale)

    # Reshape the output tensor.
    if use_mla and attn_metadata.prefill_metadata:
        return output.view(num_tokens, num_kv_heads * v_head_size)
    return output.view(num_tokens, hidden_size)
|
|
|
|
|
|
def unified_flash_attention_v2_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: List[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    attn_type_int_val: int,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
    use_mla: bool = False,
) -> torch.Tensor:
    """Fake (meta) implementation used for tracing/compilation.

    Fix: the signature now mirrors unified_flash_attention_v2 exactly,
    including the trailing `use_mla` parameter — a fake impl must accept
    the same arguments as the real op, otherwise a traced call that passes
    `use_mla` positionally fails against the registered fake.

    Returns an uninitialized tensor shaped like `query`; no cache writes
    are modeled.
    NOTE(review): the real op returns a V-head-sized output on the MLA
    prefill path, which this fake does not model — confirm whether that
    path is ever traced.
    """
    return torch.empty_like(query)
|
|
|
|
|
|
# Register the unified attention entry as a torch custom op under
# torch.ops.vllm; kv_cache is declared as mutated since the op writes the
# cache in place, and the fake impl serves shape inference during tracing.
direct_register_custom_op(
    op_name="unified_flash_attention_v2",
    op_func=unified_flash_attention_v2,
    mutates_args=["kv_cache"],
    fake_impl=unified_flash_attention_v2_fake,
)
|
|
|
|
|
|
def make_tensor_without_pad(
    x: List[List[int]],
    dtype: torch.dtype,
    device: Union[str, torch.device] = "mlu",
    pin_memory: bool = False,
) -> torch.Tensor:
    """Build a tensor from nested int lists without any row padding.

    Pinning is only honored when the target device is the CPU, since
    pinned allocation does not apply to device tensors.
    """
    use_pinned = pin_memory and str(device) == "cpu"
    return torch.tensor(x, dtype=dtype, device=device, pin_memory=use_pinned)
|
|
|
|
|
|
# Install the V2 overrides onto the base backend class so existing call
# sites referencing MLUFlashAttentionBackend transparently pick up the V2
# behavior.
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
                             MLUFlashAttentionBackend.get_supported_head_sizes,
                             MLUFlashAttentionBackend_V2.get_supported_head_sizes)
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
                             MLUFlashAttentionBackend.get_impl_cls,
                             MLUFlashAttentionBackend_V2.get_impl_cls)
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
                             MLUFlashAttentionBackend.get_builder_cls,
                             MLUFlashAttentionBackend_V2.get_builder_cls)
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
                             MLUFlashAttentionBackend.get_state_cls,
                             MLUFlashAttentionBackend_V2.get_state_cls)
# Passed by attribute name (string) rather than reference — presumably
# because the base class may not define this method; verify against
# MluHijackObject.apply_hijack's contract.
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
                             "get_kv_cache_scale_shape",
                             MLUFlashAttentionBackend_V2.get_kv_cache_scale_shape)
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
                             MLUFlashAttentionBackend.copy_blocks,
                             MLUFlashAttentionBackend_V2.copy_blocks)