Files
enginex-mlu370-vllm/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py
2026-02-11 17:47:15 +08:00

819 lines
35 KiB
Python

import torch
from contextlib import contextmanager
from itertools import accumulate
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
from vllm import _mlu_ops as mlu_ops
from vllm.attention.backends.abstract import (AttentionMetadata,
AttentionType)
from vllm.attention.backends.utils import (
PAD_SLOT_ID, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args)
from vllm.forward_context import get_forward_context
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
make_tensor_with_pad)
from vllm.attention.backends.mlu_attn import (
MLUFlashAttentionBackend, MLUFlashAttentionMetadataBuilder,
MLUFlashAttentionMetadata, MLUFlashAttentionImpl,
MLUFlashAttentionState, _get_query_key_seq_metadata,
_get_causal_option)
from vllm_mlu._mlu_utils import USE_PAGED
from vllm_mlu.mlu_hijack_utils import MluHijackObject
class MLUFlashAttentionBackend_V2(MLUFlashAttentionBackend):
    """V2 backend wiring: plugs the _V2 impl/builder/state classes into the
    stock MLU flash-attention backend and adds KV-cache-scale handling for
    quantised caches."""

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [32, 64, 80, 96, 128, 160, 192, 224, 256, 512, 576]

    @staticmethod
    def get_impl_cls() -> Type["MLUFlashAttentionImpl_V2"]:
        return MLUFlashAttentionImpl_V2

    @staticmethod
    def get_builder_cls() -> Type["MLUFlashAttentionMetadataBuilder_V2"]:
        return MLUFlashAttentionMetadataBuilder_V2

    @staticmethod
    def get_state_cls() -> Type["MLUFlashAttentionState_V2"]:
        return MLUFlashAttentionState_V2

    @staticmethod
    def get_kv_cache_scale_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
    ) -> Tuple[int, ...]:
        # Leading 2 = separate scale planes for key and value caches.
        return (2, num_blocks, num_kv_heads, block_size)

    @staticmethod
    def copy_blocks(
        kv_caches: List[List[torch.Tensor]],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy cache blocks for the KV tensors and, when present, their
        quantisation scales."""
        key_caches = []
        value_caches = []
        scale_entries = []
        for cache_entry in kv_caches:
            key_caches.append(cache_entry[0][0])
            value_caches.append(cache_entry[0][1])
            scale_entries.append(cache_entry[1])
        mlu_ops.copy_blocks(key_caches, value_caches, src_to_dists)
        # Scales are only populated for quantised (e.g. int8) KV caches;
        # an empty scale tensor means there is nothing to mirror.
        if scale_entries and scale_entries[0].numel() > 0:
            mlu_ops.copy_blocks([s[0] for s in scale_entries],
                                [s[1] for s in scale_entries],
                                src_to_dists)
class MLUMLAFlashAttentionBackend(MLUFlashAttentionBackend):
    """MLA variant of the flash-attention backend: a single latent cache
    plays the role of both key and value (no separate value cache)."""

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        # Leading 1 = one combined latent cache instead of separate K/V.
        return (1, num_blocks, num_kv_heads, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        # Only the latent (key) cache exists for MLA.
        # NOTE(review): the destination cache is passed first — confirm
        # this matches the mlu_ops.swap_blocks argument convention.
        mlu_ops.swap_blocks(dst_kv_cache[0], src_kv_cache[0], src_to_dst)

    @staticmethod
    def get_kv_cache_scale_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
    ) -> Tuple[int, ...]:
        # Single scale plane, mirroring the single latent cache.
        return (1, num_blocks, num_kv_heads, block_size)

    @staticmethod
    def copy_blocks(
        kv_caches: List[List[torch.Tensor]],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy latent-cache blocks and, when present, their scales."""
        mlu_ops.copy_blocks([entry[0][0] for entry in kv_caches],
                            None, src_to_dists)
        scale_entries = [entry[1] for entry in kv_caches]
        # Scales are only populated for quantised KV caches.
        if scale_entries and scale_entries[0].numel() > 0:
            mlu_ops.copy_blocks([s[0] for s in scale_entries],
                                None, src_to_dists)
class MLUFlashAttentionState_V2(MLUFlashAttentionState):
    """Graph-capture state for the MLU flash-attention backend.

    Extends the base state with a second capture path for "context
    mlugraphs" (prefill-shaped graphs with a fixed batch size and sequence
    length) alongside the usual decode-shaped graph capture.
    """

    def __init__(self, runner: "ModelRunnerBase"):
        # Delegate to the base state; no extra instance state is added here.
        MLUFlashAttentionState.__init__(self, runner)

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Allocate static input buffers used while capturing decode graphs;
        they are deleted when the capture scope exits.

        NOTE(review): cleanup after `yield` is skipped if the body raises —
        consider try/finally if capture can fail mid-way.
        """
        self._is_graph_capturing = True
        # Padding value differs per cache layout: PAD_SLOT_ID for the paged
        # cache, 0 for the linear (unpaged) cache.
        self._graph_slot_mapping = torch.full((max_batch_size, ),
                                              PAD_SLOT_ID if USE_PAGED else 0,
                                              dtype=torch.int32,
                                              device=self.runner.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=self.runner.device)
        yield
        self._is_graph_capturing = False
        del self._graph_slot_mapping
        del self._graph_seq_lens
        del self._graph_block_tables

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        """Build decode-only attention metadata (batch_size tokens, one
        query token each) backed by the static capture buffers."""
        assert self._is_graph_capturing
        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            seq_lens=None,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_query_len=1,
            max_decode_query_len=1,
            max_prefill_seq_len=0,
            # Linear (unpaged) caches are capped at one block's worth of
            # context during capture.
            max_decode_seq_len=(self.runner.max_seq_len_to_capture if USE_PAGED
                                else min(self.runner.block_size,
                                         self.runner.max_seq_len_to_capture)),
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self._graph_block_tables[:batch_size],
            use_cuda_graph=True,
        )
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            assert self.runner.attn_backend.get_name() in\
                ["XFORMERS", "FLASH_ATTN"], \
                f"Expected attn_backend name to be either 'XFORMERS' or " \
                f"'FLASH_ATTN', but "\
                f"got '{self.runner.attn_backend.get_name()}'"
            self._update_captured_metadata_for_enc_dec_model(
                batch_size=batch_size, attn_metadata=attn_metadata)
        return attn_metadata

    def get_graph_input_buffers(
            self,
            attn_metadata: AttentionMetadata,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        """Return the tensors that are replayed as graph inputs, picking
        prefill or decode metadata depending on the batch contents."""
        input_buffers = {
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": None,
            "block_tables": None,
        }
        if attn_metadata.num_prefills > 0:
            input_buffers["seq_lens_tensor"] = attn_metadata.prefill_metadata.seq_lens_tensor
            input_buffers["block_tables"] = attn_metadata.prefill_metadata.block_tables
        else:
            input_buffers["seq_lens_tensor"] = attn_metadata.decode_metadata.seq_lens_tensor
            input_buffers["block_tables"] = attn_metadata.decode_metadata.block_tables
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            assert self.runner.attn_backend.get_name() in\
                ["XFORMERS", "FLASH_ATTN"], \
                f"Expected attn_backend name to be either 'XFORMERS' or "\
                f"'FLASH_ATTN', but "\
                f"got '{self.runner.attn_backend.get_name()}'"
            self._add_additonal_input_buffers_for_enc_dec_model(
                attn_metadata=attn_metadata, input_buffers=input_buffers)
        return input_buffers

    def prepare_graph_input_buffers(
            self,
            input_buffers: Dict[str, Any],
            attn_metadata: AttentionMetadata,
            is_encoder_decoder_model: bool = False) -> None:
        """Copy the current batch's metadata tensors into the static graph
        input buffers before a graph replay."""
        metadata = attn_metadata.prefill_metadata if \
            attn_metadata.num_prefills > 0 else attn_metadata.decode_metadata
        input_buffers["seq_lens_tensor"].copy_(
            metadata.seq_lens_tensor, non_blocking=True)
        input_buffers["block_tables"].copy_(
            metadata.block_tables, non_blocking=True)
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            assert self.runner.attn_backend.get_name() in\
                ["XFORMERS", "FLASH_ATTN"], \
                f"Expected attn_backend name to be either 'XFORMERS' or "\
                f"'FLASH_ATTN', but "\
                f"got '{self.runner.attn_backend.get_name()}'"
            self._prepare_input_buffers_for_enc_dec_model(
                attn_metadata, input_buffers)

    @contextmanager
    def graph_capture_with_context(
            self,
            ctx_graph_batch_size: int,
            max_batch_size: int,
            max_num_tokens: int
    ):
        """Like graph_capture, but also allocates the buffers needed for
        context (prefill) mlugraph capture.

        The slot mapping is sized in tokens (max_num_tokens) rather than
        sequences, since prefill graphs see multiple tokens per sequence.
        """
        self._is_graph_capturing = True
        self._graph_slot_mapping = torch.full((max_num_tokens, ),
                                              PAD_SLOT_ID if USE_PAGED else 0,
                                              dtype=torch.int32,
                                              device=self.runner.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        # block tables used for decode mlugraph input buffer
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=self.runner.device)
        # block tables used for context mlugraph input buffer
        # (zero columns: context graphs run without prefix block tables)
        self._ctx_graph_block_tables = torch.zeros((ctx_graph_batch_size, 0),
                                                   dtype=self._graph_block_tables.dtype,
                                                   device=self.runner.device)
        yield
        self._is_graph_capturing = False
        del self._graph_slot_mapping
        del self._graph_seq_lens
        del self._graph_block_tables
        del self._ctx_graph_block_tables

    def fill_seq_lens_tensor(
            self,
            seq_len: int
    ) -> None:
        # Set every captured sequence length to the fixed context length.
        self._graph_seq_lens.fill_(seq_len)

    def graph_capture_get_metadata_for_context(
            self,
            batch_size: int,
            seq_len: int,
            is_encoder_decoder_model: bool = False
    ) -> MLUFlashAttentionMetadata:
        """Build prefill-only metadata for capturing a context mlugraph with
        a fixed (batch_size, seq_len) shape."""
        assert self._is_graph_capturing
        query_start_loc = torch.zeros(batch_size + 1,
                                      dtype=torch.int32,
                                      device=self.runner.device)
        seq_start_loc = torch.zeros(batch_size + 1,
                                    dtype=torch.int32,
                                    device=self.runner.device)
        context_lens_tensor = torch.zeros(batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        # Cumulative offsets written into slot [1:] so slot [0] stays 0.
        torch.cumsum(self._graph_seq_lens[:batch_size],
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])
        torch.cumsum(self._graph_seq_lens[:batch_size],
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        num_tokens = batch_size * seq_len
        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=batch_size,
            num_prefill_tokens=num_tokens,
            num_decode_tokens=0,
            slot_mapping=self._graph_slot_mapping[:num_tokens],
            multi_modal_placeholder_index_maps=None,
            seq_lens=[seq_len] * batch_size,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_query_len=seq_len,
            max_decode_query_len=0,
            max_prefill_seq_len=seq_len,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=self._ctx_graph_block_tables,
            use_cuda_graph=True,
        )
        return attn_metadata
class MLUFlashAttentionMetadataBuilder_V2(MLUFlashAttentionMetadataBuilder):
    """Metadata builder that additionally detects batches eligible for the
    context (prefill) mlugraph replay path."""

    def build(
            self,
            seq_lens: List[int],
            query_lens: List[int],
            cuda_graph_pad_size: int,
            batch_size: int
    ) -> MLUFlashAttentionMetadata:
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: Use origin func if do not use context mlugraph.
        '''
        if not self.runner.model_config.use_context_mlugraph():
            return super().build(seq_lens,
                                 query_lens,
                                 cuda_graph_pad_size,
                                 batch_size)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        # True if any sequence in this batch reused cached prefix blocks.
        prefix_cache_hit = any([
            inter_data.prefix_cache_hit
            for inter_data in self.input_builder.inter_data_list
        ])
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled,
                                prefix_cache_hit)
        device = self.runner.device
        # -1 means "no graph padding requested", i.e. eager execution.
        use_captured_graph = cuda_graph_pad_size != -1
        max_query_len = max(query_lens)
        # Query lens are ordered prefills-first, so the tail is decodes.
        decode_query_lens = query_lens[self.num_prefills:]
        if len(decode_query_lens) > 0:
            max_decode_query_len = max(decode_query_lens)
        else:
            max_decode_query_len = 1
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))
        num_seqs = len(seq_lens)
        if use_captured_graph:
            # Pad slot mapping up to the captured graph's token count.
            self.slot_mapping.extend([
                (PAD_SLOT_ID if USE_PAGED else 0)] * cuda_graph_pad_size)
            # NOTE(review): `[] * n` is always the empty list, so this
            # extend is a no-op — mirrors upstream vLLM; confirm intended.
            self.block_tables.extend([] * cuda_graph_pad_size)
            num_decode_tokens = batch_size - self.num_prefill_tokens
            block_tables = self._get_graph_runner_block_tables(
                num_seqs, self.block_tables)
        else:
            if USE_PAGED:
                block_tables = make_tensor_with_pad(
                    self.block_tables,
                    pad=0,
                    dtype=torch.int,
                    device=device,
                )
            else:
                # Linear cache: rows are expected to be equal-length, so no
                # padding is applied (see make_tensor_without_pad below).
                block_tables = make_tensor_without_pad(
                    self.block_tables,
                    dtype=torch.int,
                    device=device,
                )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))
        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: Check if we can use context mlugraph for the given input.
        '''
        # A pure-prefill batch whose (batch, per-seq length) exactly match
        # the captured context graph can replay that graph.
        if num_decode_tokens == 0 and self.num_prefills > 0:
            ctx_graph_bs, ctx_graph_seq_len = (
                self.runner.model_config.get_context_mlugraph_bs_and_seq())
            use_captured_graph = len(seq_lens) == ctx_graph_bs and all(
                seq_len == ctx_graph_seq_len for seq_len in seq_lens)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        return MLUFlashAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
class MLUFlashAttentionImpl_V2(MLUFlashAttentionImpl):
    """Attention impl that validates inputs and dispatches to the
    registered unified_flash_attention_v2 custom op."""

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: List[torch.Tensor],
        attn_metadata: MLUFlashAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
        use_mla: bool = False,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        # FlashAttention has no FP8 KV-cache support, so scales must be 1.
        assert k_scale == 1.0 and v_scale == 1.0, (
            "key/v_scale is not supported in FlashAttention.")
        # Validate that the metadata matches the requested attention type.
        if attn_type == AttentionType.ENCODER:
            if not attn_metadata.is_all_encoder_attn_metadata_set:
                raise AttributeError("Encoder attention requires setting "
                                     "encoder metadata attributes.")
        elif attn_type == AttentionType.ENCODER_DECODER:
            if not attn_metadata.is_all_cross_attn_metadata_set:
                raise AttributeError("Encoder/decoder cross-attention "
                                     "requires setting cross-attention "
                                     "metadata attributes.")
        # All heavy lifting happens inside the registered custom op so the
        # call remains traceable by torch.compile / graph capture.
        return torch.ops.vllm.unified_flash_attention_v2(
            query,
            key,
            value,
            self.num_heads,
            self.head_size,
            self.num_kv_heads,
            kv_cache,
            self.kv_cache_dtype,
            k_scale,
            v_scale,
            self.scale,
            attn_type.value,
            self.sliding_window,
            self.alibi_slopes,
            self.logits_soft_cap,
            use_mla,
        )
def unified_flash_attention_v2(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: List[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    attn_type_int_val: int,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
    use_mla: bool = False,
) -> torch.Tensor:
    """Unified MLU flash-attention entry point (registered as a custom op).

    Writes the new key/value tokens into the KV cache (paged or linear
    layout, optionally int8-quantised), then runs prefill and/or decode
    attention through mlu_ops using the metadata in the forward context.

    Returns:
        [num_tokens, hidden_size], or [num_tokens, num_kv_heads * v_head_size]
        on the MLA prefill path where the value head size differs.
    """
    # Convert integer attn_type to enum
    try:
        attn_type = AttentionType(attn_type_int_val)
    except ValueError as err:
        raise AttributeError(
            f"Invalid attention type {str(attn_type_int_val)}") from err
    # Attention metadata travels through the forward context, not the op
    # signature, so the op schema stays tensor-only.
    current_metadata = get_forward_context()
    assert current_metadata is not None
    assert isinstance(current_metadata, MLUFlashAttentionMetadata)
    attn_metadata: MLUFlashAttentionMetadata = current_metadata
    num_tokens, hidden_size = query.shape
    # Reshape the query, key, and value tensors.
    query = query.view(-1, num_heads, head_size)
    # NOTE(review): `value` is dereferenced before the None check below;
    # cross-attention decode passes key/value as None — confirm that path
    # never reaches this line with value=None.
    v_head_size = value.size(1) // num_kv_heads
    if (key is not None) and (value is not None):
        key = key.view(-1, num_kv_heads, head_size)
        if use_mla and attn_metadata.prefill_metadata:
            # MLA prefill: value head size can differ from the key head size.
            value = value.view(-1, num_kv_heads, v_head_size)
        else:
            value = value.view(-1, num_kv_heads, head_size)
    if kv_cache[0].numel() > 0:
        kv_cache_, kv_cache_scale_ = kv_cache
        key_cache = kv_cache_[0]
        # MLA keeps one latent cache; there is no separate value cache.
        value_cache = None if use_mla else kv_cache_[1]
        key_cache_scale, value_cache_scale = None, None
        if kv_cache_scale_.numel() > 0:
            key_cache_scale = kv_cache_scale_[0]
            value_cache_scale = None if use_mla else kv_cache_scale_[1]
        # We skip updating the KV cache under two conditions:
        # a. When the Attention Type is ENCODER. In this phase, we compute
        #    only the encoder attention without updating the cache.
        # b. When both Key and Value are None. This occurs during
        #    cross-attention computation in the decoding phase, where the KV
        #    cache is already populated with the cross-attention tensor.
        #    Thus, we skip cache updates during this time.
        if (attn_type != AttentionType.ENCODER) and (key is not None) and (
                value is not None):
            if attn_type == AttentionType.ENCODER_DECODER:
                # Update cross-attention KV cache (prefill-only)
                updated_slot_mapping = attn_metadata.cross_slot_mapping
            else:
                # Update self-attention KV cache (prefill/decode)
                updated_slot_mapping = attn_metadata.slot_mapping
            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            if USE_PAGED:
                value_to_cache = None if use_mla else value
                if use_mla and attn_metadata.prefill_metadata:
                    # MLA save cache info in models before flashattn
                    pass
                else:
                    if kv_cache_dtype == 'int8':
                        mlu_ops.quant_to_paged_cache(key,
                                                     value_to_cache,
                                                     key_cache,
                                                     value_cache,
                                                     key_cache_scale,
                                                     value_cache_scale,
                                                     updated_slot_mapping.flatten())
                    else:
                        mlu_ops.reshape_paged_cache(key,
                                                    value_to_cache,
                                                    key_cache,
                                                    value_cache,
                                                    updated_slot_mapping.flatten())
            else:
                # unpaged (linear cache) path
                if use_mla:
                    # MLA: mirror the handling of the paged path above.
                    # key_cache: (num_blocks, 1, block_size, 576)
                    value_to_cache = None
                    if attn_metadata.prefill_metadata:
                        # MLA prefill cache was already written in
                        # forward_prefill; skip it here.
                        pass
                    else:
                        if kv_cache_dtype == 'int8':
                            mlu_ops.quant_to_paged_cache(
                                key, value_to_cache,
                                key_cache, value_cache,
                                key_cache_scale, value_cache_scale,
                                updated_slot_mapping.flatten())
                        else:
                            mlu_ops.reshape_paged_cache(
                                key, value_to_cache,
                                key_cache, value_cache,
                                updated_slot_mapping.flatten())
                else:
                    # FIXME: After TMO-1496 is completed, remove this code.
                    if key.stride() != value.stride():
                        key = key.contiguous()
                        value = value.contiguous()
                    if kv_cache_dtype == 'int8':
                        mlu_ops.quant_to_linear_cache(
                            key, value,
                            key_cache, value_cache,
                            key_cache_scale, value_cache_scale,
                            attn_metadata.cu_seq_lens,
                            attn_metadata.max_seq_len,
                            True, None,
                            attn_metadata.batch_ids,
                            attn_metadata.slot_mapping_unpaged)
                    else:
                        mlu_ops.reshape_linear_cache(
                            key, value,
                            key_cache, value_cache,
                            attn_metadata.cu_seq_lens,
                            attn_metadata.max_seq_len,
                            True, None,
                            attn_metadata.batch_ids,
                            attn_metadata.slot_mapping_unpaged)
    # MLA prefill output has the value head size, which may differ from the
    # query head size.
    if use_mla and attn_metadata.prefill_metadata:
        output = torch.empty(query.shape[0], query.shape[1], v_head_size,
                             dtype=query.dtype, device="mlu")
    else:
        output = torch.empty_like(query)
    (num_prefill_query_tokens, num_prefill_kv_tokens,
     num_decode_query_tokens) = \
        get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
    # Tokens are laid out prefills-first; split query accordingly.
    decode_query = query[num_prefill_query_tokens:]
    # QKV for prefill.
    query = query[:num_prefill_query_tokens]
    assert query.shape[0] == num_prefill_query_tokens
    assert decode_query.shape[0] == num_decode_query_tokens
    if prefill_meta := attn_metadata.prefill_metadata:
        alibi_slopes = None if alibi_slopes is None else \
            alibi_slopes.repeat(attn_metadata.num_prefills, 1)
        # Prompt run.
        if (kv_cache[0].numel() == 0 or prefill_meta.block_tables is None
                or prefill_meta.block_tables.numel() == 0):
            # normal attention
            # When block_tables are not filled, it means q and k are the
            # prompt, and they have the same length.
            q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
                _get_query_key_seq_metadata(prefill_meta, True, attn_type)
            key = key[:num_prefill_kv_tokens]
            value = value[:num_prefill_kv_tokens]
            mlu_ops.flash_attention(query,
                                    key,
                                    value,
                                    output[:num_prefill_query_tokens],
                                    q_seq_start_loc,
                                    k_seq_start_loc,
                                    alibi_slopes,
                                    None,
                                    q_seq_len,
                                    k_seq_len,
                                    softmax_scale,
                                    _get_causal_option(attn_type),
                                    -1 if window_size is None
                                    else window_size[0],
                                    -1 if window_size is None
                                    else window_size[1],
                                    torch.float,
                                    False)
        else:
            # prefix-enabled attention
            assert attn_type == AttentionType.DECODER, (
                "Only decoder-only models support prefix caching")
            assert prefill_meta.seq_lens is not None
            max_seq_len = max(prefill_meta.seq_lens)
            # NOTE(review): output is sliced by num_prefill_kv_tokens here
            # (query tokens elsewhere); for DECODER self-attention — which
            # the assert above guarantees — the two counts coincide.
            mlu_ops.flash_attention(query,
                                    key_cache,
                                    value_cache,
                                    output[:num_prefill_kv_tokens],
                                    prefill_meta.query_start_loc,
                                    prefill_meta.seq_start_loc,
                                    alibi_slopes,
                                    None,
                                    prefill_meta.max_query_len,
                                    max_seq_len,
                                    softmax_scale,
                                    True,
                                    -1 if window_size is None
                                    else window_size[0],
                                    -1 if window_size is None
                                    else window_size[1],
                                    torch.float,
                                    False,
                                    prefill_meta.block_tables)
    if decode_meta := attn_metadata.decode_metadata:
        # Decoding run.
        alibi_slopes = None if alibi_slopes is None \
            else alibi_slopes.repeat(attn_metadata.num_decode_tokens, 1)
        decode_query = decode_query.view(-1, 1, num_heads, head_size)
        decode_out = output[num_prefill_query_tokens:].view(-1, 1, num_heads, head_size)
        # Use flash_attn_varlen_func kernel for speculative decoding
        # because different queries might have different lengths.
        assert decode_meta.max_decode_query_len is not None
        if decode_meta.max_decode_query_len > 1:
            assert attn_type == AttentionType.DECODER, (
                "Only decoder-only models support max_decode_query_len > 1")
            mlu_ops.flash_attention(decode_query,
                                    key_cache,
                                    value_cache,
                                    decode_out,
                                    decode_meta.query_start_loc,
                                    decode_meta.seq_start_loc,
                                    alibi_slopes,
                                    None,
                                    decode_meta.max_decode_query_len,
                                    decode_meta.max_decode_seq_len,
                                    softmax_scale,
                                    True,
                                    -1 if window_size is None
                                    else window_size[0],
                                    -1 if window_size is None
                                    else window_size[1],
                                    torch.float,
                                    False,
                                    decode_meta.block_tables)
        else:
            # Use flash_attn_with_kvcache for normal decoding.
            (
                seq_lens_arg,
                max_context_len,
                block_tables_arg,
            ) = get_seq_len_block_table_args(decode_meta, False, attn_type)
            if use_mla:
                # MLA decode reads K and V out of the same latent cache.
                value_cache = key_cache
                value_cache_scale = key_cache_scale
            mlu_ops.single_query_cached_kv_attn(decode_query,
                                                key_cache,
                                                value_cache,
                                                decode_out,
                                                block_tables_arg,
                                                seq_lens_arg,
                                                key_cache_scale,
                                                value_cache_scale,
                                                alibi_slopes,
                                                max_context_len,
                                                -1 if window_size is None
                                                else window_size[0],
                                                -1 if window_size is None
                                                else window_size[1],
                                                softmax_scale)
    # Reshape the output tensor.
    if use_mla and attn_metadata.prefill_metadata:
        return output.view(num_tokens, num_kv_heads * v_head_size)
    return output.view(num_tokens, hidden_size)
def unified_flash_attention_v2_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: List[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    attn_type_int_val: int,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
    use_mla: bool = False,
) -> torch.Tensor:
    """Fake (meta/abstract) implementation of unified_flash_attention_v2.

    Used by torch during tracing/compilation to infer output shape and
    dtype without executing MLU kernels. The signature must mirror the
    real op exactly — `use_mla` was added to match the registered op so
    that abstract calls passing it do not fail.

    Returns an uninitialized tensor with the same shape/dtype as `query`;
    note this does not model the MLA-prefill output shape, where the real
    op returns a different last dimension.
    """
    return torch.empty_like(query)
# Register unified_flash_attention_v2 as a torch custom op so calls from the
# attention impl remain traceable (torch.compile / graph capture); kv_cache
# is declared as mutated because the op writes new tokens into it, and the
# fake impl supplies shape/dtype inference during abstract tracing.
direct_register_custom_op(
    op_name="unified_flash_attention_v2",
    op_func=unified_flash_attention_v2,
    mutates_args=["kv_cache"],
    fake_impl=unified_flash_attention_v2_fake,
)
def make_tensor_without_pad(
    x: List[List[int]],
    dtype: torch.dtype,
    device: Union[str, torch.device] = "mlu",
    pin_memory: bool = False,
) -> torch.Tensor:
    """Build a tensor directly from nested lists, with no row padding.

    Counterpart to vLLM's make_tensor_with_pad for the linear-cache path,
    where rows are already equal-length. Pinning is only honoured for CPU
    tensors, matching torch's pin_memory constraint.
    """
    pin = pin_memory and str(device) == "cpu"
    return torch.tensor(x, dtype=dtype, device=device, pin_memory=pin)
# Install the _V2 overrides onto the stock MLUFlashAttentionBackend so every
# existing call site transparently picks up the new behaviour without
# changing its imports.
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
                             MLUFlashAttentionBackend.get_supported_head_sizes,
                             MLUFlashAttentionBackend_V2.get_supported_head_sizes)
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
                             MLUFlashAttentionBackend.get_impl_cls,
                             MLUFlashAttentionBackend_V2.get_impl_cls)
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
                             MLUFlashAttentionBackend.get_builder_cls,
                             MLUFlashAttentionBackend_V2.get_builder_cls)
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
                             MLUFlashAttentionBackend.get_state_cls,
                             MLUFlashAttentionBackend_V2.get_state_cls)
# NOTE(review): passed by name (string), unlike the others — presumably the
# base class does not define get_kv_cache_scale_shape, so there is no
# attribute to reference; confirm against MluHijackObject.apply_hijack.
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
                             "get_kv_cache_scale_shape",
                             MLUFlashAttentionBackend_V2.get_kv_cache_scale_shape)
MluHijackObject.apply_hijack(MLUFlashAttentionBackend,
                             MLUFlashAttentionBackend.copy_blocks,
                             MLUFlashAttentionBackend_V2.copy_blocks)