forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
This commit is contained in:
802
vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py
Normal file
802
vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py
Normal file
@@ -0,0 +1,802 @@
|
||||
import torch
|
||||
|
||||
from contextlib import contextmanager
|
||||
from itertools import accumulate
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm.attention.backends.abstract import (AttentionMetadata,
|
||||
AttentionType)
|
||||
from vllm.attention.backends.utils import (
|
||||
PAD_SLOT_ID, get_num_prefill_decode_query_kv_tokens,
|
||||
get_seq_len_block_table_args)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
|
||||
make_tensor_with_pad)
|
||||
from vllm.attention.backends.mlu_attn import (
|
||||
MLUFlashAttentionBackend, MLUFlashAttentionMetadataBuilder,
|
||||
MLUFlashAttentionMetadata, MLUFlashAttentionImpl,
|
||||
MLUFlashAttentionState, _get_query_key_seq_metadata,
|
||||
_get_causal_option)
|
||||
from vllm_mlu._mlu_utils import USE_PAGED
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
class MLUFlashAttentionBackend_V2(MLUFlashAttentionBackend):
    """Replacement static methods for ``MLUFlashAttentionBackend``.

    The methods here select the ``*_V2`` implementation, builder and state
    classes of this module and add handling for a separate KV-cache scale
    tensor that accompanies each layer's KV cache.
    """

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the list of attention head sizes accepted by the backend."""
        head_sizes = [32, 64, 80, 96, 128, 160, 192, 224, 256, 512, 576]
        return head_sizes

    @staticmethod
    def get_impl_cls() -> Type["MLUFlashAttentionImpl_V2"]:
        """Return the attention implementation class."""
        return MLUFlashAttentionImpl_V2

    @staticmethod
    def get_builder_cls() -> Type["MLUFlashAttentionMetadataBuilder_V2"]:
        """Return the attention metadata builder class."""
        return MLUFlashAttentionMetadataBuilder_V2

    @staticmethod
    def get_state_cls() -> Type["MLUFlashAttentionState_V2"]:
        """Return the attention state class."""
        return MLUFlashAttentionState_V2

    @staticmethod
    def get_kv_cache_scale_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
    ) -> Tuple[int, ...]:
        """Return the shape of the KV-cache scale tensor.

        The leading dimension of 2 holds one slice for keys and one for
        values (see how ``copy_blocks`` indexes ``[0]`` and ``[1]``).
        """
        scale_shape = (2, num_blocks, num_kv_heads, block_size)
        return scale_shape

    @staticmethod
    def copy_blocks(
        kv_caches: List[List[torch.Tensor]],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy cache blocks, and their scales when present, for every layer.

        Args:
            kv_caches: one entry per layer; ``entry[0]`` is indexed as
                ``[0]`` (key cache) and ``[1]`` (value cache), ``entry[1]``
                is the matching scale tensor indexed the same way.
            src_to_dists: block mapping forwarded to ``mlu_ops.copy_blocks``.
        """
        cache_tensors = [layer[0] for layer in kv_caches]
        mlu_ops.copy_blocks(
            [cache[0] for cache in cache_tensors],
            [cache[1] for cache in cache_tensors],
            src_to_dists,
        )

        scale_tensors = [layer[1] for layer in kv_caches]
        # Nothing more to do when there are no layers or the scale tensors
        # are empty placeholders.
        if not scale_tensors or scale_tensors[0].numel() == 0:
            return
        mlu_ops.copy_blocks(
            [scale[0] for scale in scale_tensors],
            [scale[1] for scale in scale_tensors],
            src_to_dists,
        )
|
||||
|
||||
|
||||
class MLUMLAFlashAttentionBackend(MLUFlashAttentionBackend):
    """Backend variant whose cache tensors have a leading dimension of 1.

    Only index ``[0]`` of the cache and scale tensors is ever used here; the
    value-side argument of the ``mlu_ops`` copy call is passed as ``None``.
    """

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the cache shape ``(1, blocks, kv_heads, block_size, head)``."""
        cache_shape = (1, num_blocks, num_kv_heads, block_size, head_size)
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap blocks between the first slices of two caches.

        Note the argument order expected by ``mlu_ops.swap_blocks``: the
        destination cache is passed first, then the source cache.
        """
        mlu_ops.swap_blocks(dst_kv_cache[0], src_kv_cache[0], src_to_dst)

    @staticmethod
    def get_kv_cache_scale_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
    ) -> Tuple[int, ...]:
        """Return the scale-tensor shape ``(1, blocks, kv_heads, block_size)``."""
        scale_shape = (1, num_blocks, num_kv_heads, block_size)
        return scale_shape

    @staticmethod
    def copy_blocks(
        kv_caches: List[List[torch.Tensor]],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy cache blocks, and their scales when present, for every layer.

        Args:
            kv_caches: one entry per layer; ``entry[0][0]`` is the cache
                tensor that gets copied and ``entry[1]`` the scale tensor.
            src_to_dists: block mapping forwarded to ``mlu_ops.copy_blocks``.
        """
        mlu_ops.copy_blocks(
            [layer[0][0] for layer in kv_caches],
            None,
            src_to_dists,
        )

        scale_tensors = [layer[1] for layer in kv_caches]
        # Skip the scale copy when there are no layers or the scale tensors
        # are empty placeholders.
        if not scale_tensors or scale_tensors[0].numel() == 0:
            return
        mlu_ops.copy_blocks(
            [scale[0] for scale in scale_tensors],
            None,
            src_to_dists,
        )
|
||||
|
||||
|
||||
class MLUFlashAttentionState_V2(MLUFlashAttentionState):
    """Attention state that owns the placeholder buffers used for graph capture.

    Two capture modes are provided:

    * ``graph_capture`` sizes the slot-mapping buffer by batch size (one
      query token per sequence, i.e. decode).
    * ``graph_capture_with_context`` sizes it by token count and adds an
      empty block table for capturing a prefill ("context") graph.
    """

    def __init__(self, runner: "ModelRunnerBase"):
        MLUFlashAttentionState.__init__(self, runner)

    def _assert_enc_dec_backend(self) -> None:
        """Assert that the runner's backend name is XFORMERS or FLASH_ATTN.

        Shared by the three encoder/decoder code paths below, which all
        performed this identical check inline.
        """
        backend_name = self.runner.attn_backend.get_name()
        assert backend_name in ["XFORMERS", "FLASH_ATTN"], (
            f"Expected attn_backend name to be either 'XFORMERS' or "
            f"'FLASH_ATTN', but "
            f"got '{backend_name}'")

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Allocate decode capture buffers for the duration of the ``with``.

        Args:
            max_batch_size: length of the slot-mapping and sequence-length
                buffers.

        The buffers are released and the capturing flag is reset even when
        the body of the ``with`` statement raises.
        """
        self._graph_slot_mapping = torch.full((max_batch_size, ),
                                              PAD_SLOT_ID if USE_PAGED else 0,
                                              dtype=torch.int32,
                                              device=self.runner.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=self.runner.device)
        # Raise the flag only once every buffer exists, so a failed
        # allocation cannot leave the state claiming a capture is running.
        self._is_graph_capturing = True
        try:
            yield
        finally:
            self._is_graph_capturing = False
            del self._graph_slot_mapping
            del self._graph_seq_lens
            del self._graph_block_tables

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        """Build decode-only metadata over the capture buffers.

        Args:
            batch_size: number of sequences; also the number of decode
                tokens, since every sequence has a query length of 1.
            is_encoder_decoder_model: when True, the backend name is
                asserted and the encoder/decoder fields are filled in.

        Returns:
            The metadata object produced by the runner's attention backend.
        """
        assert self._is_graph_capturing
        if USE_PAGED:
            max_decode_seq_len = self.runner.max_seq_len_to_capture
        else:
            max_decode_seq_len = min(self.runner.block_size,
                                     self.runner.max_seq_len_to_capture)
        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            seq_lens=None,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_query_len=1,
            max_decode_query_len=1,
            max_prefill_seq_len=0,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self._graph_block_tables[:batch_size],
            use_cuda_graph=True,
        )
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            self._assert_enc_dec_backend()
            self._update_captured_metadata_for_enc_dec_model(
                batch_size=batch_size, attn_metadata=attn_metadata)

        return attn_metadata

    def get_graph_input_buffers(
            self,
            attn_metadata: AttentionMetadata,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        """Collect the tensors that act as graph inputs.

        The sequence lengths and block tables come from the prefill metadata
        when the batch contains prefills, otherwise from the decode metadata.

        Returns:
            A dict with keys ``slot_mapping``, ``seq_lens_tensor`` and
            ``block_tables`` (plus encoder/decoder entries when requested).
        """
        if attn_metadata.num_prefills > 0:
            metadata = attn_metadata.prefill_metadata
        else:
            metadata = attn_metadata.decode_metadata
        input_buffers = {
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": metadata.seq_lens_tensor,
            "block_tables": metadata.block_tables,
        }
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            self._assert_enc_dec_backend()
            self._add_additonal_input_buffers_for_enc_dec_model(
                attn_metadata=attn_metadata, input_buffers=input_buffers)
        return input_buffers

    def prepare_graph_input_buffers(
            self,
            input_buffers: Dict[str, Any],
            attn_metadata: AttentionMetadata,
            is_encoder_decoder_model: bool = False) -> None:
        """Copy the current batch's metadata into the captured input buffers.

        Only ``seq_lens_tensor`` and ``block_tables`` are copied here; both
        copies are issued with ``non_blocking=True``.
        """
        if attn_metadata.num_prefills > 0:
            metadata = attn_metadata.prefill_metadata
        else:
            metadata = attn_metadata.decode_metadata

        input_buffers["seq_lens_tensor"].copy_(
            metadata.seq_lens_tensor, non_blocking=True)
        input_buffers["block_tables"].copy_(
            metadata.block_tables, non_blocking=True)
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            self._assert_enc_dec_backend()
            self._prepare_input_buffers_for_enc_dec_model(
                attn_metadata, input_buffers)

    @contextmanager
    def graph_capture_with_context(
        self,
        ctx_graph_batch_size: int,
        max_batch_size: int,
        max_num_tokens: int
    ):
        """Allocate capture buffers for both decode and context graphs.

        Args:
            ctx_graph_batch_size: number of rows of the (zero-column) block
                table used for the context graph.
            max_batch_size: length of the sequence-length buffer.
            max_num_tokens: length of the slot-mapping buffer.

        The buffers are released and the capturing flag is reset even when
        the body of the ``with`` statement raises.
        """
        self._graph_slot_mapping = torch.full((max_num_tokens, ),
                                              PAD_SLOT_ID if USE_PAGED else 0,
                                              dtype=torch.int32,
                                              device=self.runner.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        # block tables used for decode mlugraph input buffer
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=self.runner.device)
        # block tables used for context mlugraph input buffer
        self._ctx_graph_block_tables = torch.zeros(
            (ctx_graph_batch_size, 0),
            dtype=self._graph_block_tables.dtype,
            device=self.runner.device)
        # Raise the flag only once every buffer exists.
        self._is_graph_capturing = True
        try:
            yield
        finally:
            self._is_graph_capturing = False
            del self._graph_slot_mapping
            del self._graph_seq_lens
            del self._graph_block_tables
            del self._ctx_graph_block_tables

    def fill_seq_lens_tensor(
        self,
        seq_len: int
    ) -> None:
        """Set every entry of the captured sequence-length buffer to *seq_len*."""
        self._graph_seq_lens.fill_(seq_len)

    def graph_capture_get_metadata_for_context(
        self,
        batch_size: int,
        seq_len: int,
        is_encoder_decoder_model: bool = False
    ) -> MLUFlashAttentionMetadata:
        """Build prefill-only metadata over the capture buffers.

        Args:
            batch_size: number of prefill sequences.
            seq_len: length assumed for every sequence; the total token
                count is ``batch_size * seq_len``.
            is_encoder_decoder_model: accepted for signature parity; not
                read by this method.

        The start-location tensors are cumulative sums of the captured
        sequence-length buffer, so that buffer should already hold the
        intended lengths (see ``fill_seq_lens_tensor``).
        """
        assert self._is_graph_capturing

        device = self.runner.device
        graph_seq_lens = self._graph_seq_lens[:batch_size]

        # Leading element stays 0; the remainder receives the running sum.
        query_start_loc = torch.zeros(batch_size + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(batch_size + 1,
                                    dtype=torch.int32,
                                    device=device)
        context_lens_tensor = torch.zeros(batch_size,
                                          dtype=torch.int32,
                                          device=device)
        torch.cumsum(graph_seq_lens,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])
        torch.cumsum(graph_seq_lens,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        num_tokens = batch_size * seq_len
        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=batch_size,
            num_prefill_tokens=num_tokens,
            num_decode_tokens=0,
            slot_mapping=self._graph_slot_mapping[:num_tokens],
            multi_modal_placeholder_index_maps=None,
            seq_lens=[seq_len] * batch_size,
            seq_lens_tensor=graph_seq_lens,
            max_query_len=seq_len,
            max_decode_query_len=0,
            max_prefill_seq_len=seq_len,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=self._ctx_graph_block_tables,
            use_cuda_graph=True,
        )
        return attn_metadata
|
||||
|
||||
|
||||
class MLUFlashAttentionMetadataBuilder_V2(MLUFlashAttentionMetadataBuilder):
    """Metadata builder that can mark prefill batches as graph-replayable."""

    def build(
        self,
        seq_lens: List[int],
        query_lens: List[int],
        cuda_graph_pad_size: int,
        batch_size: int
    ) -> MLUFlashAttentionMetadata:
        """Build attention metadata for the current batch.

        When the model config does not enable the context mlugraph, this
        defers entirely to the parent implementation.

        Args:
            seq_lens: sequence length of every sequence in the batch.
            query_lens: query length of every sequence in the batch.
            cuda_graph_pad_size: number of padding slots appended for graph
                replay; ``-1`` means no captured graph is used.
            batch_size: batch size used to derive the decode token count
                when a captured graph is used.

        Returns:
            The populated ``MLUFlashAttentionMetadata``.
        """
        # vllm_mlu modification: use the original function when the context
        # mlugraph is not in use.
        if not self.runner.model_config.use_context_mlugraph():
            return super().build(seq_lens,
                                 query_lens,
                                 cuda_graph_pad_size,
                                 batch_size)

        inter_data_list = self.input_builder.inter_data_list
        prefix_cache_hit = any(inter_data.prefix_cache_hit
                               for inter_data in inter_data_list)
        for inter_data in inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled,
                                prefix_cache_hit)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        # Sequences after the prefills are the decodes.
        decode_query_lens = query_lens[self.num_prefills:]
        max_decode_query_len = max(decode_query_lens, default=1)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        num_seqs = len(seq_lens)
        if use_captured_graph:
            pad_slot = PAD_SLOT_ID if USE_PAGED else 0
            self.slot_mapping.extend([pad_slot] * cuda_graph_pad_size)
            # The original also ran
            # `self.block_tables.extend([] * cuda_graph_pad_size)`, which is
            # a no-op because `[] * n` is always `[]`; it is dropped here.
            # NOTE(review): if per-sequence padding rows were intended,
            # confirm against `_get_graph_runner_block_tables` before
            # adding them.
            num_decode_tokens = batch_size - self.num_prefill_tokens
            block_tables = self._get_graph_runner_block_tables(
                num_seqs, self.block_tables)
        elif USE_PAGED:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        else:
            block_tables = make_tensor_without_pad(
                self.block_tables,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, f"query_lens: {query_lens}"

        assert device is not None
        pin_memory = self.runner.pin_memory
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
                                               device, pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc,
                                                  torch.int32, device,
                                                  pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        # vllm_mlu modification: check whether the context mlugraph can be
        # used for this input. A prefill-only batch replays the captured
        # graph only when it has exactly the captured batch size and every
        # sequence has exactly the captured length.
        if num_decode_tokens == 0 and self.num_prefills > 0:
            ctx_graph_bs, ctx_graph_seq_len = (
                self.runner.model_config.get_context_mlugraph_bs_and_seq())
            use_captured_graph = len(seq_lens) == ctx_graph_bs and all(
                seq_len == ctx_graph_seq_len for seq_len in seq_lens)

        return MLUFlashAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
|
||||
|
||||
|
||||
class MLUFlashAttentionImpl_V2(MLUFlashAttentionImpl):
    """Attention implementation that dispatches to the registered
    ``unified_flash_attention_v2`` custom op."""

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: List[torch.Tensor],
        attn_metadata: MLUFlashAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
        use_mla: bool = False,
    ) -> torch.Tensor:
        """Validate the inputs and run the unified flash-attention op.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: list of tensors forwarded unchanged to the op.
            attn_metadata: Metadata for attention; only inspected here to
                validate encoder / cross-attention fields.
            k_scale: must be 1.0.
            v_scale: must be 1.0.
            attn_type: kind of attention to compute.
            use_mla: forwarded unchanged to the op.

        Returns:
            The tensor returned by the op.

        Raises:
            AttributeError: if encoder or cross-attention is requested while
                the corresponding metadata attributes are not all set.
        """
        # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
        scales_are_identity = k_scale == 1.0 and v_scale == 1.0
        assert scales_are_identity, (
            "key/v_scale is not supported in FlashAttention.")

        if attn_type == AttentionType.ENCODER:
            if not attn_metadata.is_all_encoder_attn_metadata_set:
                raise AttributeError("Encoder attention requires setting "
                                     "encoder metadata attributes.")
        elif attn_type == AttentionType.ENCODER_DECODER:
            if not attn_metadata.is_all_cross_attn_metadata_set:
                raise AttributeError("Encoder/decoder cross-attention "
                                     "requires setting cross-attention "
                                     "metadata attributes.")

        # The op takes the attention type as a plain value (it is converted
        # back to the enum inside the op), so the enum's `.value` is passed.
        return torch.ops.vllm.unified_flash_attention_v2(
            query,
            key,
            value,
            self.num_heads,
            self.head_size,
            self.num_kv_heads,
            kv_cache,
            self.kv_cache_dtype,
            k_scale,
            v_scale,
            self.scale,
            attn_type.value,
            self.sliding_window,
            self.alibi_slopes,
            self.logits_soft_cap,
            use_mla,
        )
|
||||
|
||||
|
||||
def unified_flash_attention_v2(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: List[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    attn_type_int_val: int,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
    use_mla: bool = False,
) -> torch.Tensor:
    """Write new keys/values into the KV cache and compute attention.

    The attention metadata is taken from the current forward context rather
    than passed as an argument.

    Args:
        query: 2-D tensor ``[num_tokens, hidden_size]``; viewed as
            ``[-1, num_heads, head_size]``.
        key: 2-D tensor viewed as ``[-1, num_kv_heads, head_size]``, or
            ``None``.
        value: 2-D tensor, or ``None``. Its per-head width
            (``value.size(1) // num_kv_heads``) is used for the output when
            ``use_mla`` is set and the batch has prefills.
        num_heads: number of query heads.
        head_size: per-head width of query and key.
        num_kv_heads: number of key/value heads.
        kv_cache: two-element sequence ``(cache, cache_scale)``.
            ``cache[0]`` is the key cache and ``cache[1]`` the value cache
            (``None`` when ``use_mla``); ``cache_scale`` is indexed the same
            way when it is non-empty.
        kv_cache_dtype: ``'int8'`` selects the quantizing cache writers.
        k_scale: not read by this function.
        v_scale: not read by this function.
        softmax_scale: scale forwarded to the attention kernels.
        attn_type_int_val: value convertible to ``AttentionType``.
        window_size: optional ``(left, right)`` pair; ``None`` maps to
            ``-1`` for both sides.
        alibi_slopes: optional slopes tensor, repeated per prefill sequence
            and per decode token before being passed to the kernels.
        logits_soft_cap: not read by this function.
        use_mla: enables the MLA-specific paths.

    Returns:
        The attention output reshaped to ``[num_tokens, hidden_size]``, or to
        ``[num_tokens, num_heads * v_head_size]`` for an MLA prefill.

    Raises:
        AttributeError: if ``attn_type_int_val`` is not a valid
            ``AttentionType`` value.
    """
    # Convert integer attn_type to enum
    try:
        attn_type = AttentionType(attn_type_int_val)
    except ValueError as err:
        raise AttributeError(
            f"Invalid attention type {str(attn_type_int_val)}") from err

    current_metadata = get_forward_context()
    assert current_metadata is not None
    assert isinstance(current_metadata, MLUFlashAttentionMetadata)
    attn_metadata: MLUFlashAttentionMetadata = current_metadata

    num_tokens, hidden_size = query.shape
    is_mla_prefill = bool(use_mla and attn_metadata.prefill_metadata)
    if window_size is None:
        window_left, window_right = -1, -1
    else:
        window_left, window_right = window_size[0], window_size[1]

    # Reshape the query, key, and value tensors.
    query = query.view(-1, num_heads, head_size)
    # `value` may legitimately be None (see the cache-update comment below),
    # so its head width must not be read unconditionally.
    v_head_size = (head_size if value is None else
                   value.size(1) // num_kv_heads)
    if (key is not None) and (value is not None):
        key = key.view(-1, num_kv_heads, head_size)
        if is_mla_prefill:
            value = value.view(-1, num_kv_heads, v_head_size)
        else:
            value = value.view(-1, num_kv_heads, head_size)

    if kv_cache[0].numel() > 0:
        kv_cache_, kv_cache_scale_ = kv_cache
        key_cache = kv_cache_[0]
        value_cache = None if use_mla else kv_cache_[1]
        key_cache_scale, value_cache_scale = None, None
        if kv_cache_scale_.numel() > 0:
            key_cache_scale = kv_cache_scale_[0]
            value_cache_scale = None if use_mla else kv_cache_scale_[1]

        # We skip updating the KV cache under two conditions:
        #  a. When the Attention Type is ENCODER. In this phase, we compute
        #     only the encoder attention without updating the cache.
        #  b. When both Key and Value are None. This occurs during
        #     cross-attention computation in the decoding phase, where the KV
        #     cache is already populated with the cross-attention tensor.
        #     Thus, we skip cache updates during this time.
        if (attn_type != AttentionType.ENCODER) and (key is not None) and (
                value is not None):
            if attn_type == AttentionType.ENCODER_DECODER:
                # Update cross-attention KV cache (prefill-only)
                updated_slot_mapping = attn_metadata.cross_slot_mapping
            else:
                # Update self-attention KV cache (prefill/decode)
                updated_slot_mapping = attn_metadata.slot_mapping

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            if USE_PAGED:
                value_to_cache = None if use_mla else value
                if is_mla_prefill:
                    # MLA save cache info in models before flashattn
                    pass
                elif kv_cache_dtype == 'int8':
                    mlu_ops.quant_to_paged_cache(
                        key,
                        value_to_cache,
                        key_cache,
                        value_cache,
                        key_cache_scale,
                        value_cache_scale,
                        updated_slot_mapping.flatten())
                else:
                    mlu_ops.reshape_paged_cache(
                        key,
                        value_to_cache,
                        key_cache,
                        value_cache,
                        updated_slot_mapping.flatten())
            else:
                # FIXME: After TMO-1496 is completed, remove this code.
                if key.stride() != value.stride():
                    key = key.contiguous()
                    value = value.contiguous()
                if kv_cache_dtype == 'int8':
                    mlu_ops.quant_to_linear_cache(
                        key,
                        value,
                        key_cache,
                        value_cache,
                        key_cache_scale,
                        value_cache_scale,
                        attn_metadata.cu_seq_lens,
                        attn_metadata.max_seq_len,
                        True,  # packed
                        None,  # context_seq_offset
                        attn_metadata.batch_ids,
                        attn_metadata.slot_mapping_unpaged)
                else:
                    mlu_ops.reshape_linear_cache(
                        key,
                        value,
                        key_cache,
                        value_cache,
                        attn_metadata.cu_seq_lens,
                        attn_metadata.max_seq_len,
                        True,  # packed
                        None,  # context_seq_offset
                        attn_metadata.batch_ids,
                        attn_metadata.slot_mapping_unpaged)

    if is_mla_prefill:
        # Allocate next to the query instead of on a hard-coded device name.
        output = torch.empty(query.shape[0],
                             query.shape[1],
                             v_head_size,
                             dtype=query.dtype,
                             device=query.device)
    else:
        output = torch.empty_like(query)
    (num_prefill_query_tokens, num_prefill_kv_tokens,
     num_decode_query_tokens) = \
        get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
    decode_query = query[num_prefill_query_tokens:]
    # QKV for prefill.
    query = query[:num_prefill_query_tokens]
    assert query.shape[0] == num_prefill_query_tokens
    assert decode_query.shape[0] == num_decode_query_tokens

    if prefill_meta := attn_metadata.prefill_metadata:
        # Keep the repeated slopes in their own variable: the decode branch
        # below must start again from the caller's tensor, not from this
        # already-repeated one.
        prefill_alibi_slopes = None if alibi_slopes is None else \
            alibi_slopes.repeat(attn_metadata.num_prefills, 1)
        # Prompt run.
        if (kv_cache[0].numel() == 0 or prefill_meta.block_tables is None
                or prefill_meta.block_tables.numel() == 0):
            # normal attention
            # When block_tables are not filled, it means q and k are the
            # prompt, and they have the same length.
            q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
                _get_query_key_seq_metadata(prefill_meta, True, attn_type)

            key = key[:num_prefill_kv_tokens]
            value = value[:num_prefill_kv_tokens]
            mlu_ops.flash_attention(query,
                                    key,
                                    value,
                                    output[:num_prefill_query_tokens],
                                    q_seq_start_loc,
                                    k_seq_start_loc,
                                    prefill_alibi_slopes,
                                    None,
                                    q_seq_len,
                                    k_seq_len,
                                    softmax_scale,
                                    _get_causal_option(attn_type),
                                    window_left,
                                    window_right,
                                    torch.float,
                                    False)
        else:
            # prefix-enabled attention
            assert attn_type == AttentionType.DECODER, (
                "Only decoder-only models support prefix caching")
            assert prefill_meta.seq_lens is not None
            max_seq_len = max(prefill_meta.seq_lens)
            # The output buffer is indexed by query tokens, so it is sliced
            # by the query-token count like the branch above.
            mlu_ops.flash_attention(query,
                                    key_cache,
                                    value_cache,
                                    output[:num_prefill_query_tokens],
                                    prefill_meta.query_start_loc,
                                    prefill_meta.seq_start_loc,
                                    prefill_alibi_slopes,
                                    None,
                                    prefill_meta.max_query_len,
                                    max_seq_len,
                                    softmax_scale,
                                    True,
                                    window_left,
                                    window_right,
                                    torch.float,
                                    False,
                                    prefill_meta.block_tables)

    if decode_meta := attn_metadata.decode_metadata:
        # Decoding run.
        decode_alibi_slopes = None if alibi_slopes is None \
            else alibi_slopes.repeat(attn_metadata.num_decode_tokens, 1)
        decode_query = decode_query.view(-1, 1, num_heads, head_size)
        decode_out = output[num_prefill_query_tokens:].view(
            -1, 1, num_heads, head_size)
        # Use flash_attn_varlen_func kernel for speculative decoding
        # because different queries might have different lengths.
        assert decode_meta.max_decode_query_len is not None
        if decode_meta.max_decode_query_len > 1:
            assert attn_type == AttentionType.DECODER, (
                "Only decoder-only models support max_decode_query_len > 1")
            mlu_ops.flash_attention(decode_query,
                                    key_cache,
                                    value_cache,
                                    decode_out,
                                    decode_meta.query_start_loc,
                                    decode_meta.seq_start_loc,
                                    decode_alibi_slopes,
                                    None,
                                    decode_meta.max_decode_query_len,
                                    decode_meta.max_decode_seq_len,
                                    softmax_scale,
                                    True,
                                    window_left,
                                    window_right,
                                    torch.float,
                                    False,
                                    decode_meta.block_tables)
        else:
            # Use flash_attn_with_kvcache for normal decoding.
            (
                seq_lens_arg,
                max_context_len,
                block_tables_arg,
            ) = get_seq_len_block_table_args(decode_meta, False, attn_type)
            if use_mla:
                # With MLA there is no separate value cache; the key cache
                # (and its scale) is passed for both arguments.
                value_cache = key_cache
                value_cache_scale = key_cache_scale
            mlu_ops.single_query_cached_kv_attn(decode_query,
                                                key_cache,
                                                value_cache,
                                                decode_out,
                                                block_tables_arg,
                                                seq_lens_arg,
                                                key_cache_scale,
                                                value_cache_scale,
                                                decode_alibi_slopes,
                                                max_context_len,
                                                window_left,
                                                window_right,
                                                softmax_scale)

    # Reshape the output tensor.
    if is_mla_prefill:
        # The buffer was allocated with query.shape[1] == num_heads heads,
        # so flatten with num_heads (identical to the former num_kv_heads
        # whenever the two counts are equal).
        return output.view(num_tokens, num_heads * v_head_size)
    return output.view(num_tokens, hidden_size)
|
||||
|
||||
|
||||
def unified_flash_attention_v2_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: List[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    attn_type_int_val: int,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
    use_mla: bool = False,
) -> torch.Tensor:
    """Shape-only stand-in for ``unified_flash_attention_v2``.

    The parameter list mirrors the real op, including ``use_mla``, so the
    fake can be invoked with exactly the arguments the real op receives.
    Only ``query`` is read.

    Returns:
        An uninitialized tensor with the shape, dtype and device of
        ``query``.
    """
    return torch.empty_like(query)
|
||||
|
||||
|
||||
# Register the attention function as the custom op that
# MLUFlashAttentionImpl_V2.forward reaches through
# torch.ops.vllm.unified_flash_attention_v2, with its fake implementation
# alongside. kv_cache is the only argument declared as mutated.
direct_register_custom_op(
    op_name="unified_flash_attention_v2",
    op_func=unified_flash_attention_v2,
    mutates_args=["kv_cache"],
    fake_impl=unified_flash_attention_v2_fake,
)
|
||||
|
||||
|
||||
def make_tensor_without_pad(
    x: List[List[int]],
    dtype: torch.dtype,
    device: Union[str, torch.device] = "mlu",
    pin_memory: bool = False,
) -> torch.Tensor:
    """Build a tensor directly from a nested list, adding no padding.

    Args:
        x: nested list handed to ``torch.tensor`` as-is, so all rows must
            already have the same length.
        dtype: dtype of the resulting tensor.
        device: target device.
        pin_memory: request pinned memory; honoured only for CPU targets.

    Returns:
        The tensor created by ``torch.tensor``.
    """
    # Compare only the device type so an indexed CPU device ("cpu:0") is
    # recognised as CPU too; a plain string comparison with "cpu" misses it.
    targets_cpu = str(device).split(":", 1)[0] == "cpu"
    return torch.tensor(x,
                        dtype=dtype,
                        device=device,
                        pin_memory=pin_memory and targets_cpu)
|
||||
|
||||
|
||||
# Patch MLUFlashAttentionBackend with the V2 static methods defined in this
# module. Each call names the owning class, the original attribute and its
# replacement.
MluHijackObject.apply_hijack(
    MLUFlashAttentionBackend,
    MLUFlashAttentionBackend.get_supported_head_sizes,
    MLUFlashAttentionBackend_V2.get_supported_head_sizes,
)
MluHijackObject.apply_hijack(
    MLUFlashAttentionBackend,
    MLUFlashAttentionBackend.get_impl_cls,
    MLUFlashAttentionBackend_V2.get_impl_cls,
)
MluHijackObject.apply_hijack(
    MLUFlashAttentionBackend,
    MLUFlashAttentionBackend.get_builder_cls,
    MLUFlashAttentionBackend_V2.get_builder_cls,
)
MluHijackObject.apply_hijack(
    MLUFlashAttentionBackend,
    MLUFlashAttentionBackend.get_state_cls,
    MLUFlashAttentionBackend_V2.get_state_cls,
)
# This one is identified by name rather than by attribute access.
# NOTE(review): presumably because the original class does not define
# get_kv_cache_scale_shape - confirm against MLUFlashAttentionBackend.
MluHijackObject.apply_hijack(
    MLUFlashAttentionBackend,
    "get_kv_cache_scale_shape",
    MLUFlashAttentionBackend_V2.get_kv_cache_scale_shape,
)
MluHijackObject.apply_hijack(
    MLUFlashAttentionBackend,
    MLUFlashAttentionBackend.copy_blocks,
    MLUFlashAttentionBackend_V2.copy_blocks,
)
|
||||
Reference in New Issue
Block a user