init
This commit is contained in:
0
vllm_vacc/vllm/attention/backends/__init__.py
Normal file
0
vllm_vacc/vllm/attention/backends/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
0
vllm_vacc/vllm/attention/backends/mla/__init__.py
Normal file
0
vllm_vacc/vllm/attention/backends/mla/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
202
vllm_vacc/vllm/attention/backends/mla/common.py
Normal file
202
vllm_vacc/vllm/attention/backends/mla/common.py
Normal file
@@ -0,0 +1,202 @@
|
||||
|
||||
import functools
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
|
||||
Type, TypeVar)
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import AttentionLayer
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionState, MLAAttentionImpl)
|
||||
from vllm.attention.backends.mla.common import MLACommonMetadata,triton_attention
|
||||
|
||||
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
try:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
is_vllm_fa = True
|
||||
except ImportError:
|
||||
is_vllm_fa = False
|
||||
try:
|
||||
# For rocm use upstream flash attention
|
||||
from vllm.attention.backends.flash_attn import flash_attn_varlen_func
|
||||
except ImportError:
|
||||
flash_attn_varlen_func = None
|
||||
|
||||
T = TypeVar("T", bound="MLACommonMetadata")
|
||||
|
||||
|
||||
class MLACommonImpl():
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
# blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
q_lora_rank: Optional[int],
|
||||
kv_lora_rank: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
rotary_emb: RotaryEmbedding,
|
||||
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
|
||||
# attention backend perspective we rely on the layer to pass in the
|
||||
# correct matrix
|
||||
q_proj: ColumnParallelLinear,
|
||||
kv_b_proj: ColumnParallelLinear,
|
||||
o_proj: RowParallelLinear,
|
||||
positions: torch.Tensor = None,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
|
||||
self.rotary_emb = rotary_emb
|
||||
self.use_yarn_rope = isinstance(rotary_emb,
|
||||
DeepseekScalingRotaryEmbedding)
|
||||
self.q_proj = q_proj
|
||||
self.kv_b_proj = kv_b_proj
|
||||
self.o_proj = o_proj
|
||||
self.positions = positions
|
||||
|
||||
self.triton_fa_func = triton_attention
|
||||
# Handle the differences between the flash_attn_varlen from flash_attn
|
||||
# and the one from vllm_flash_attn. The former is used on RoCM and the
|
||||
# latter has an additional parameter to control FA2 vs FA3
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
if self.vllm_flash_attn_version is not None:
|
||||
self.flash_attn_varlen_func = \
|
||||
functools.partial(flash_attn_varlen_func,
|
||||
fa_version=self.vllm_flash_attn_version)
|
||||
|
||||
# For MLA the v head dim is smaller than qk head dim so we pad out
|
||||
# v with 0s to match the qk head dim for attention backends that do
|
||||
# not support different headdims
|
||||
# We don't need to pad V if we are on a hopper system with FA3
|
||||
self._pad_v = self.vllm_flash_attn_version is None or not (
|
||||
self.vllm_flash_attn_version == 3
|
||||
and current_platform.get_device_capability()[0] == 9)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if output is not None:
|
||||
raise NotImplementedError(
|
||||
"output is not yet supported for MLAImplBase")
|
||||
|
||||
# if attn_metadata.is_profile_run and \
|
||||
# attn_metadata.context_chunk_workspace is not None:
|
||||
# # During the profile run try to simulate to worse case output size
|
||||
# # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
|
||||
# # since this can be large
|
||||
# _ = torch.empty(
|
||||
# (attn_metadata.context_chunk_workspace.shape[0],
|
||||
# self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
|
||||
# device=k_c_normed.device,
|
||||
# dtype=k_c_normed.dtype,
|
||||
# )
|
||||
|
||||
has_decode = attn_metadata.decode_metadata is not None
|
||||
has_prefill = attn_metadata.prefill_metadata is not None
|
||||
|
||||
# Restore head dim (for rotary embedding)
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
# assert hasattr(attn_metadata, "input_positions")
|
||||
if self.positions is not None:
|
||||
positions = self.positions
|
||||
elif hasattr(attn_metadata, "input_positions"):
|
||||
positions = attn_metadata.input_positions
|
||||
else:
|
||||
raise ValueError('no positions')
|
||||
|
||||
|
||||
num_prefill_tokens: int = attn_metadata.num_prefill_tokens
|
||||
|
||||
decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:]
|
||||
decode_k_pe = k_pe[num_prefill_tokens:]
|
||||
decode_input_positions = \
|
||||
positions[num_prefill_tokens:]
|
||||
|
||||
prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens]
|
||||
prefill_k_pe = k_pe[:num_prefill_tokens]
|
||||
prefill_input_positions = \
|
||||
positions[:num_prefill_tokens]
|
||||
prefill_k_c_normed = k_c_normed[:num_prefill_tokens]
|
||||
|
||||
if has_decode:
|
||||
decode_ql_nope, decode_q_pe = \
|
||||
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
||||
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
||||
decode_input_positions, decode_q_pe, decode_k_pe)
|
||||
|
||||
if has_prefill:
|
||||
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
||||
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
||||
prefill_input_positions, prefill_q_pe, prefill_k_pe)
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
ops.concat_and_cache_mla(
|
||||
k_c_normed,
|
||||
k_pe.squeeze(1),
|
||||
kv_cache,
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
scale=layer._k_scale,
|
||||
)
|
||||
|
||||
# output = torch.empty(attn_metadata.num_prefill_tokens +
|
||||
# attn_metadata.num_decode_tokens,
|
||||
# self.o_proj.output_size,
|
||||
# device=hidden_states_or_q_c.device,
|
||||
# dtype=hidden_states_or_q_c.dtype)
|
||||
if has_prefill:
|
||||
return self._forward_prefill(
|
||||
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
|
||||
attn_metadata)
|
||||
|
||||
if has_decode:
|
||||
return self._forward_decode(
|
||||
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
|
||||
|
||||
assert False, "mla forward need prefill or decode function"
|
||||
return None
|
||||
390
vllm_vacc/vllm/attention/backends/mla/utils.py
Normal file
390
vllm_vacc/vllm/attention/backends/mla/utils.py
Normal file
@@ -0,0 +1,390 @@
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generic, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl, T)
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsW8A8Fp8)
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||
# from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
# apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
scaled_dequantize, scaled_quantize)
|
||||
import os
|
||||
|
||||
W_Q_W_QR_WUV_WUK_USE_FP8 = True
|
||||
|
||||
class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
|
||||
def is_layer_fp8(layer: LinearBase) -> bool:
|
||||
return isinstance(layer.quant_method, Fp8LinearMethod) or\
|
||||
(isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
|
||||
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
|
||||
|
||||
def quantization_scheme_supported(layer: LinearBase) -> bool:
|
||||
return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
|
||||
is_layer_fp8(layer)
|
||||
|
||||
# TODO(lucas) This is very gross, we need a more wide scale refactor of
|
||||
# all the FP8 code with a more standard way of
|
||||
# defining schemes/group-shapes, we should also potentially force
|
||||
# quant_methods to support a decompress function
|
||||
#
|
||||
# returns input_group_shape, weight_group_shape
|
||||
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
|
||||
Tuple[Tuple[int, int], Tuple[int, int]]:
|
||||
if isinstance(layer.quant_method, Fp8LinearMethod):
|
||||
if layer.quant_method.block_quant is not None:
|
||||
weight_block_size = \
|
||||
layer.quant_method.quant_config.weight_block_size
|
||||
# per-token-group (1, X), block-quantized (X, Y)
|
||||
return (1, weight_block_size[-1]), weight_block_size
|
||||
else:
|
||||
return (-1, -1), (-1, -1) # per-tensor, per-tensor
|
||||
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
|
||||
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# this is hacky but we always assume the for
|
||||
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
|
||||
# we ignore if it is static-per-tensor since we are going to
|
||||
# requantize after later anyways
|
||||
strategy = layer.scheme.strategy
|
||||
if strategy == QuantizationStrategy.TENSOR:
|
||||
return (1, -1), (-1, -1) # per-token, per-tensor
|
||||
elif strategy == QuantizationStrategy.CHANNEL:
|
||||
return (1, -1), (-1, 1) # per-token, per-channel
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"QuantizationStrategy.{strategy} is not supported for "
|
||||
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Can't determine scale group shapes for "
|
||||
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
|
||||
)
|
||||
|
||||
def get_scales(layer: LinearBase) -> torch.Tensor:
|
||||
if hasattr(layer, "weight_scale_inv"):
|
||||
return layer.weight_scale_inv
|
||||
return layer.weight_scale
|
||||
|
||||
def get_fp8_layer_weight(layer: LinearBase):
|
||||
if is_layer_fp8(layer):
|
||||
if isinstance(layer.quant_method, \
|
||||
CompressedTensorsLinearMethod) and \
|
||||
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
|
||||
# seems to store weights as (input, output) instead of
|
||||
# (output, input) so we need to transpose
|
||||
weight = layer.weight.T # standardize to (output, input)
|
||||
else:
|
||||
weight = layer.weight
|
||||
_, weight_scale_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(layer)
|
||||
scales = get_scales(layer) # 已经expand过了
|
||||
weight_scale_group_shape=weight_scale_group_shape.copy() #config中读出来的[128,128], 需要 .copy(), 否则会把config改掉
|
||||
|
||||
# 重新校准一下 weight_scale_group_shape
|
||||
if weight.shape[0] // scales.shape[0] != weight_scale_group_shape[0]:
|
||||
weight_scale_group_shape[0] = weight.shape[0] // scales.shape[0]
|
||||
|
||||
if weight.shape[1] // scales.shape[1] != weight_scale_group_shape[1]:
|
||||
weight_scale_group_shape[1] = weight.shape[1] // scales.shape[1]
|
||||
|
||||
return weight, scales
|
||||
else:
|
||||
return layer.weight, None
|
||||
|
||||
def get_fp8_layer_weight_test(layer: LinearBase):
|
||||
if is_layer_fp8(layer):
|
||||
if isinstance(layer.quant_method, \
|
||||
CompressedTensorsLinearMethod) and \
|
||||
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
|
||||
# seems to store weights as (input, output) instead of
|
||||
# (output, input) so we need to transpose
|
||||
weight = layer.weight.T # standardize to (output, input)
|
||||
else:
|
||||
weight = layer.weight
|
||||
_, weight_scale_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(layer)
|
||||
scales = get_scales(layer) # 已经expand过了
|
||||
weight_scale_group_shape=weight_scale_group_shape.copy() #config中读出来的[128,128], 需要 .copy(), 否则会把config改掉
|
||||
|
||||
# 重新校准一下 weight_scale_group_shape
|
||||
if weight.shape[0] // scales.shape[0] != weight_scale_group_shape[0]:
|
||||
weight_scale_group_shape[0] = weight.shape[0] // scales.shape[0]
|
||||
|
||||
if weight.shape[1] // scales.shape[1] != weight_scale_group_shape[1]:
|
||||
weight_scale_group_shape[1] = weight.shape[1] // scales.shape[1]
|
||||
|
||||
# for test
|
||||
weight = scaled_dequantize(weight, scales, weight_scale_group_shape)
|
||||
# print(f'{weight.shape}, {scales.shape}, {weight_scale_group_shape}')
|
||||
return weight, scales
|
||||
else:
|
||||
return layer.weight, None
|
||||
|
||||
def check_eq(name, tensor0, tensor1):
|
||||
assert tensor0.shape == tensor1.shape
|
||||
isEqual = torch.equal(tensor0.reshape([-1]).float(), tensor1.reshape([-1]).float())
|
||||
print(f"{os.getpid()} check {name} {tensor0.shape} equal: {isEqual}")
|
||||
return isEqual
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if is_layer_fp8(layer):
|
||||
if isinstance(layer.quant_method, \
|
||||
CompressedTensorsLinearMethod) and \
|
||||
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
|
||||
# seems to store weights as (input, output) instead of
|
||||
# (output, input) so we need to transpose
|
||||
weight = layer.weight.T # standardize to (output, input)
|
||||
else:
|
||||
weight = layer.weight
|
||||
_, weight_scale_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(layer)
|
||||
scales = get_scales(layer) # 已经expand过了
|
||||
weight_scale_group_shape=weight_scale_group_shape.copy() #config中读出来的[128,128], 需要 .copy(), 否则会把config改掉
|
||||
|
||||
# 重新校准一下 weight_scale_group_shape
|
||||
if weight.shape[0] // scales.shape[0] != weight_scale_group_shape[0]:
|
||||
weight_scale_group_shape[0] = weight.shape[0] // scales.shape[0]
|
||||
|
||||
if weight.shape[1] // scales.shape[1] != weight_scale_group_shape[1]:
|
||||
weight_scale_group_shape[1] = weight.shape[1] // scales.shape[1]
|
||||
|
||||
return scaled_dequantize(weight, scales,
|
||||
weight_scale_group_shape)
|
||||
else:
|
||||
return layer.weight
|
||||
|
||||
if not (quantization_scheme_supported(self.kv_b_proj) and\
|
||||
quantization_scheme_supported(self.q_proj) and\
|
||||
quantization_scheme_supported(self.o_proj)):
|
||||
raise NotImplementedError(
|
||||
"Only FP8 and UnquantizedLinearMethod are supported for MLA"
|
||||
", please run with VLLM_MLA_DISABLE=1")
|
||||
|
||||
weight_dtype = self.kv_b_proj.weight.dtype
|
||||
assert self.o_proj.weight.dtype == weight_dtype
|
||||
assert self.q_proj.weight.dtype == weight_dtype
|
||||
|
||||
if W_Q_W_QR_WUV_WUK_USE_FP8: #and not envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
# 512,1024(=4x256)
|
||||
kv_b_proj_weight, kv_b_proj_scale = \
|
||||
[t.T for t in get_fp8_layer_weight(self.kv_b_proj)]
|
||||
|
||||
# kv_b_proj_weight = kv_b_proj_weight.transpose(-1,-2).contiguous().transpose(-1,-2)
|
||||
N, K = kv_b_proj_weight.shape[0], kv_b_proj_weight.shape[1]
|
||||
|
||||
# 512,1024 => 512,4,256
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
kv_b_proj_scale = kv_b_proj_scale.view(
|
||||
kv_b_proj_scale.shape[0] * self.kv_lora_rank // N,
|
||||
self.num_heads,
|
||||
kv_b_proj_scale.shape[1] * N // (self.kv_lora_rank * self.num_heads),
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
W_UK = W_UK.contiguous()
|
||||
|
||||
scale_0 = kv_b_proj_scale.shape[-1] * self.qk_nope_head_dim // (self.qk_nope_head_dim + self.v_head_dim)
|
||||
scale_1 = kv_b_proj_scale.shape[-1] - scale_0
|
||||
|
||||
W_UK_scale, W_UV_scale = kv_b_proj_scale.split(
|
||||
[scale_0, scale_1], dim=-1)
|
||||
W_UK_scale = W_UK_scale.view(W_UK_scale.shape[0], -1).unsqueeze(-1).contiguous()
|
||||
W_UV_scale = W_UV_scale.view(W_UV_scale.shape[0], -1).unsqueeze(-1)
|
||||
|
||||
# weight: [1536, 768] scale: 12,6
|
||||
q_proj_weight, q_proj_scale = \
|
||||
[t.T for t in get_fp8_layer_weight(self.q_proj)]
|
||||
|
||||
#self.W_Q_QR = q_proj_weight.contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
#self.W_Q_QR_scales = q_proj_scale.reshape(12, 6, 1).repeat(1, 1, 4).reshape(12, -1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
|
||||
q_proj_weight = q_proj_weight\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
# w_q[1536, 512] + w_qr[1536, 256]
|
||||
W_Q = q_proj_weight[..., :self.qk_nope_head_dim].flatten(start_dim=1)
|
||||
W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
# w_q_scale 12,16 + w_qr_scale 12,8
|
||||
# expand: 12,6(4+2) -> 12,24(16+8)
|
||||
# Q_scale: [s0x4, s1x2, s2x2, s3x4, s4x2, s5x2]
|
||||
repeat_pattern = torch.tensor([4, 2, 2, 4, 2, 2], device=q_proj_scale.device)
|
||||
W_Q_scale = torch.repeat_interleave(q_proj_scale, repeat_pattern, dim=1)
|
||||
# Q_R_scale: [s1x2, s2x2, s4x2, s5x2]
|
||||
selected_indices = [1, 2, 4, 5]
|
||||
repeat_times = 2
|
||||
selected = q_proj_scale[:, selected_indices]
|
||||
W_QR_scale = selected.repeat_interleave(repeat_times, dim=1)
|
||||
|
||||
# temp_WQ_Scale = W_Q_scale.reshape(12, 4, -1).contiguous()
|
||||
# temp_W_QR_scale = W_QR_scale.reshape(12, 4, -1).contiguous()
|
||||
# temp_scale = torch.cat([temp_WQ_Scale, temp_W_QR_scale], dim=2).contiguous().reshape(12, -1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
# self.W_Q_QR_scales = temp_scale
|
||||
# print("W_Q_scale:", W_Q_scale.shape)
|
||||
# print("W_QR_scale:", W_QR_scale.shape)
|
||||
# print("temp_scale:", temp_scale.shape)
|
||||
# exit(0)
|
||||
|
||||
# Note: to be vnnl compatible
|
||||
# 1. expand w_uv scale for core split friendly
|
||||
if W_UV.shape[-1] % 4 == 0:
|
||||
W_UV_scale = W_UV_scale.expand((W_UV_scale.shape[0], W_UV_scale.shape[1], 4))
|
||||
# 2. change w_q, w_qr, w_uv weight&scale to K-contiguous (shape unchanged)
|
||||
W_Q = W_Q.transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
W_Q_scale = W_Q_scale.transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
W_QR = W_QR.transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
W_QR_scale = W_QR_scale.transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
|
||||
W_UV = W_UV.permute(2,1,0).contiguous().permute(2,1,0)
|
||||
W_UV_scale = W_UV_scale.permute(2,1,0).contiguous().permute(2,1,0)
|
||||
|
||||
self.W_Q = W_Q
|
||||
self.W_Q_scales = W_Q_scale
|
||||
|
||||
self.W_QR = W_QR
|
||||
self.W_QR_scales = W_QR_scale
|
||||
|
||||
# temp_Q_scale = self.W_Q_scales.contiguous()
|
||||
# temp_W_QR_scale = self.W_QR_scales.contiguous()
|
||||
# self.W_Q_QR = q_proj_weight.reshape(1536, -1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
# self.W_Q_QR_scales = torch.concat([temp_Q_scale,temp_W_QR_scale],dim=1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
#self.W_Q_QR = torch.concat([self.W_Q.contiguous(),self.W_QR.contiguous()],dim=1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
#self.W_Q_QR_scales = torch.concat([W_Q_scale,W_QR_scale],dim=1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
|
||||
self.W_UV = W_UV
|
||||
self.W_UV_scales = W_UV_scale
|
||||
|
||||
self.W_UK = W_UK
|
||||
self.W_UK_scales = W_UK_scale
|
||||
return
|
||||
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}")
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
|
||||
# can be W_Q or W_UQ depending q_lora_rank, the former if
|
||||
# q_lora_rank is None, the latter otherwise. From the Attention backend
|
||||
# perspective though we call these both W_Q and rely on the layer
|
||||
# to pass in the correct matrix
|
||||
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
|
||||
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
# W_QR is small so for simplicity we dont bother requantizing it
|
||||
self.W_QR = self.W_QR.to(act_dtype)
|
||||
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
assert False, "please set VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0"
|
||||
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
# This assumes it wise to requantize using the same group shapes
|
||||
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
|
||||
# weights were originally quantized
|
||||
requant_input_group_shape, requant_weight_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(self.q_proj)
|
||||
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
|
||||
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||
== get_scale_group_shapes_for_fp8(self.o_proj)
|
||||
self.reqaunt_input_group_shape = requant_input_group_shape
|
||||
self.reqaunt_weight_group_shape = requant_weight_group_shape
|
||||
|
||||
#
|
||||
# Perform matrix-absorption following
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/551
|
||||
# for decode, as a result we end up with absorbed weights for decode
|
||||
# and another copy of raw weights for prefill.
|
||||
#
|
||||
self.W_UK, self.W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
|
||||
# depending q_lora_rank, the former if q_lora_rank is None, the
|
||||
# latter otherwise
|
||||
# basically if q_lora_rank is none we are absorbing into q_proj
|
||||
# instead of UQ
|
||||
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
W_Q_UK, W_Q_UK_scales = scaled_quantize(
|
||||
W_Q_UK,
|
||||
self.reqaunt_weight_group_shape,
|
||||
quant_dtype=current_platform_fp8_dtype)
|
||||
# For FP8 save the transpose so we can use
|
||||
# `apply_w8a8_block_fp8_linear` directly
|
||||
self.W_Q_UK = W_Q_UK.T.contiguous()
|
||||
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
|
||||
else:
|
||||
self.W_Q_UK = W_Q_UK.to(act_dtype)
|
||||
|
||||
W_O = get_and_maybe_dequant_weights(self.o_proj)\
|
||||
.view(-1, self.num_heads, self.v_head_dim)
|
||||
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
|
||||
.flatten(start_dim=0, end_dim=1).contiguous()
|
||||
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
W_UV_O, W_UV_O_scales = scaled_quantize(
|
||||
W_UV_O,
|
||||
self.reqaunt_weight_group_shape,
|
||||
quant_dtype=current_platform_fp8_dtype)
|
||||
# For FP8 save the transpose so we can use
|
||||
# `apply_w8a8_block_fp8_linear` directly
|
||||
self.W_UV_O = W_UV_O.T.contiguous()
|
||||
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
|
||||
else:
|
||||
self.W_UV_O = W_UV_O.to(act_dtype)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
else:
|
||||
# print('W_UV', W_UV.dtype) #float32
|
||||
#if is_fp8(weight_dtype):
|
||||
# raise NotImplementedError(
|
||||
# "Currently fp8 requires matrix absorption")
|
||||
# self.W_UV = W_UV
|
||||
# self.W_UK = W_UK
|
||||
self.W_UV = W_UV.to(act_dtype) # fp32 to bfp16
|
||||
self.W_UK = W_UK.to(act_dtype)
|
||||
W_Q = W_Q.to(act_dtype)
|
||||
self.W_Q = W_Q.flatten(start_dim=1)
|
||||
726
vllm_vacc/vllm/attention/backends/vacc_attn.py
Normal file
726
vllm_vacc/vllm/attention/backends/vacc_attn.py
Normal file
@@ -0,0 +1,726 @@
|
||||
""" Attention layer with torch scaled_dot_product_attention
|
||||
and PagedAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
# from vllm.attention.backends.utils import CommonAttentionState
|
||||
# from vllm.attention.ops.paged_attn import PagedAttentionMetadata
|
||||
|
||||
from vllm_vacc.vllm.attention.ops.vacc_paged_attn import VaccPagedAttention as PagedAttention
|
||||
|
||||
# from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
# AttentionLayer,
|
||||
# AttentionMetadata,
|
||||
# AttentionMetadataBuilder,
|
||||
# AttentionType)
|
||||
# from vllm.attention.backends.utils import CommonAttentionState
|
||||
# from vllm.attention.ops.ipex_attn import PagedAttention
|
||||
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import make_tensor_with_pad
|
||||
from vllm_vacc.vllm.v1.worker.vacc_model_runner import ModelInputForVACCBuilder
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
import os
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
|
||||
class VACCAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TORCH_VACC"
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["VACCAttentionBackendImpl"]:
|
||||
return VACCAttentionBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||
return VACCAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["VACCMetadataBuilder"]:
|
||||
return VACCMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
||||
num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VACCAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
"""Metadata for VACCAttentionMetadata.
|
||||
"""
|
||||
# Currently, input sequences can only contain all prompts
|
||||
# or all decoding. True if all sequences are prompts.
|
||||
chunked_prefill: bool
|
||||
seq_lens: Optional[List[int]] = None # For non-chunked prefill
|
||||
# For chunked prefill only
|
||||
max_query_len: Optional[int] = None
|
||||
max_kv_len: Optional[int] = None
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
kv_start_loc: Optional[torch.Tensor] = None
|
||||
prefill_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
# Begin encoder attn & enc/dec cross-attn fields...
|
||||
# Encoder sequence lengths representation
|
||||
encoder_seq_lens: Optional[List[int]] = None
|
||||
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum sequence length among encoder sequences
|
||||
max_encoder_seq_len: Optional[int] = None
|
||||
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
cross_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Set during the execution of the first attention op.
|
||||
# It is a list because it is needed to set per prompt
|
||||
# when alibi slopes is used. It is because of the limitation
|
||||
# from xformer API.
|
||||
# will not appear in the __repr__ and __init__
|
||||
self.attn_bias: Optional[List[torch.Tensor]] = None
|
||||
self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
|
||||
self.cross_attn_bias: Optional[List[torch.Tensor]] = None
|
||||
|
||||
@property
|
||||
def is_all_encoder_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for encoder attention is set.
|
||||
'''
|
||||
return ((self.encoder_seq_lens is not None)
|
||||
and (self.encoder_seq_lens_tensor is not None)
|
||||
and (self.max_encoder_seq_len is not None))
|
||||
|
||||
@property
|
||||
def is_all_cross_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for enc/dec cross-attention is set.
|
||||
|
||||
Superset of encoder attention required metadata.
|
||||
'''
|
||||
return (self.is_all_encoder_attn_metadata_set
|
||||
and (self.cross_slot_mapping is not None)
|
||||
and (self.cross_block_tables is not None))
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["VACCAttentionMetadata"]:
|
||||
# Currently chunked prefill is not supported
|
||||
if self.num_prefill_tokens == 0:
|
||||
return None
|
||||
return self
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["VACCAttentionMetadata"]:
|
||||
# Currently chunked prefill is not supported
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
return self
|
||||
|
||||
def get_seq_lens(
|
||||
self,
|
||||
attn_type: AttentionType,
|
||||
):
|
||||
'''
|
||||
Extract appropriate sequence lengths from attention metadata
|
||||
according to attention type.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
* Appropriate sequence lengths tensor for query
|
||||
* Appropriate sequence lengths tensor for key & value
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
seq_lens_q = self.seq_lens
|
||||
seq_lens_kv = self.seq_lens
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
seq_lens_q = self.encoder_seq_lens
|
||||
seq_lens_kv = self.encoder_seq_lens
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
seq_lens_q = self.seq_lens
|
||||
seq_lens_kv = self.encoder_seq_lens
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
return seq_lens_q, seq_lens_kv
|
||||
|
||||
def get_attn_bias(
|
||||
self,
|
||||
attn_type: AttentionType,
|
||||
) -> Optional[List[torch.Tensor]]:
|
||||
'''
|
||||
Extract appropriate attention bias from attention metadata
|
||||
according to attention type.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
* Appropriate attention bias value given the attention type
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
return self.attn_bias
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
return self.encoder_attn_bias
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
return self.cross_attn_bias
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
def set_attn_bias(
|
||||
self,
|
||||
attn_bias: List[torch.Tensor],
|
||||
attn_type: AttentionType,
|
||||
) -> None:
|
||||
'''
|
||||
Update appropriate attention bias field of attention metadata,
|
||||
according to attention type.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* attn_bias: The desired attention bias value
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
self.attn_bias = attn_bias
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
self.encoder_attn_bias = attn_bias
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
self.cross_attn_bias = attn_bias
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
def get_seq_len_block_table_args(
|
||||
self,
|
||||
attn_type: str,
|
||||
) -> tuple:
|
||||
'''
|
||||
The particular choice of sequence-length- and block-table-related
|
||||
attributes which should be extracted from attn_metadata is dependent
|
||||
on the type of attention operation.
|
||||
|
||||
Decoder attn -> select entirely decoder self-attention-related fields
|
||||
Encoder/decoder cross-attn -> select encoder sequence lengths &
|
||||
cross-attn block-tables fields
|
||||
Encoder attn -> select encoder sequence lengths fields & no block tables
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* is_prompt: True if prefill, False otherwise
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
|
||||
* Appropriate sequence-lengths tensor
|
||||
* Appropriate max sequence-length scalar
|
||||
* Appropriate block tables (or None)
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
# Decoder self-attention
|
||||
# Choose max_seq_len based on whether we are in prompt_run
|
||||
return (self.seq_lens_tensor, self.max_decode_seq_len,
|
||||
self.block_tables)
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Enc/dec cross-attention KVs match encoder sequence length;
|
||||
# cross-attention utilizes special "cross" block tables
|
||||
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
|
||||
self.cross_block_tables)
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
# No block tables associated with encoder attention
|
||||
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
|
||||
None)
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
class VACCMetadataBuilder(AttentionMetadataBuilder[VACCAttentionMetadata]):
|
||||
|
||||
def __init__(self, input_builder: ModelInputForVACCBuilder) -> None:
|
||||
self.chunked_prefill = input_builder.chunked_prefill
|
||||
self.input_builder = input_builder
|
||||
|
||||
def prepare(self):
|
||||
self.input_data = self.input_builder.input_data
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int) -> VACCAttentionMetadata:
|
||||
input_data = self.input_data
|
||||
prefill_seq_lens = seq_lens[0:input_data.num_prefills]
|
||||
prefill_query_lens = query_lens[0:input_data.num_prefills]
|
||||
slot_mapping = torch.tensor(input_data.slot_mapping,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device)
|
||||
|
||||
# For chunked-prefill
|
||||
if self.chunked_prefill and input_data.num_prefill_tokens != 0:
|
||||
prefill_block_tables = make_tensor_with_pad(
|
||||
self.input_data.prefill_block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device,
|
||||
)
|
||||
query_lens_tensor = torch.tensor(prefill_query_lens,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device)
|
||||
kv_lens_tensor = torch.tensor(prefill_seq_lens,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device)
|
||||
query_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device)
|
||||
kv_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device)
|
||||
torch.cumsum(query_lens_tensor,
|
||||
dim=0,
|
||||
dtype=torch.int32,
|
||||
out=query_start_loc[1:])
|
||||
torch.cumsum(kv_lens_tensor,
|
||||
dim=0,
|
||||
dtype=torch.int32,
|
||||
out=kv_start_loc[1:])
|
||||
max_query_len = max(prefill_query_lens)
|
||||
max_kv_len = max(prefill_seq_lens)
|
||||
else:
|
||||
prefill_block_tables = None
|
||||
query_start_loc = None
|
||||
kv_start_loc = None
|
||||
max_query_len = None
|
||||
max_kv_len = None
|
||||
|
||||
# For paged attention
|
||||
if input_data.num_decode_tokens != 0:
|
||||
seq_lens_tensor = torch.tensor(
|
||||
input_data.seq_lens[input_data.num_prefills:],
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device,
|
||||
)
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.input_data.decode_block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device,
|
||||
)
|
||||
# lowest_dim_size = block_tables.size(-1)
|
||||
# if lowest_dim_size < 1024:
|
||||
# padding_amount = 1024 - lowest_dim_size
|
||||
# padding = torch.zeros(*block_tables.size()[:-1], padding_amount, dtype=block_tables.dtype, device=block_tables.device)
|
||||
# block_tables = torch.cat((block_tables, padding), dim=-1)
|
||||
else:
|
||||
block_tables = torch.tensor([])
|
||||
seq_lens_tensor = torch.tensor(
|
||||
input_data.seq_lens[:input_data.num_prefills],
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device,
|
||||
)
|
||||
|
||||
# For multi-modal models
|
||||
placeholder_index_maps = None
|
||||
if len(input_data.multi_modal_inputs_list) != 0:
|
||||
placeholder_index_maps = {
|
||||
modality: placeholder_map.index_map()
|
||||
for modality, placeholder_map in
|
||||
input_data.multi_modal_placeholder_maps.items()
|
||||
}
|
||||
|
||||
attn_metadata = VACCAttentionMetadata(
|
||||
chunked_prefill=self.chunked_prefill,
|
||||
seq_lens=seq_lens, #prefill_seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_kv_len=max_kv_len,
|
||||
query_start_loc=query_start_loc,
|
||||
kv_start_loc=kv_start_loc,
|
||||
max_decode_seq_len=None,
|
||||
num_prefills=input_data.num_prefills,
|
||||
num_prefill_tokens=input_data.num_prefill_tokens,
|
||||
num_decode_tokens=input_data.num_decode_tokens,
|
||||
block_tables=block_tables,
|
||||
prefill_block_tables=prefill_block_tables,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=False,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
class VACCAttentionBackendImpl(AttentionImpl[VACCAttentionMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"Torch SPDA does not support block-sparse attention.")
|
||||
if logits_soft_cap is not None:
|
||||
logger.warning_once("Torch SPDA does not support logits soft cap. "
|
||||
"Outputs may be slightly off.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.sliding_window = sliding_window
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.need_mask = (self.alibi_slopes is not None
|
||||
or self.sliding_window is not None)
|
||||
|
||||
supported_head_sizes = PagedAttention.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by PagedAttention. "
|
||||
f"Supported head sizes are: {supported_head_sizes}.")
|
||||
if kv_cache_dtype != "auto":
|
||||
raise NotImplementedError(
|
||||
"Torch SDPA backend does not support FP8 KV cache. "
|
||||
"Please use xFormers backend instead.")
|
||||
self.attn_type = attn_type
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: VACCAttentionMetadata, # type: ignore
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with torch SDPA and PagedAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
attn_type = self.attn_type
|
||||
if (attn_type == AttentionType.ENCODER
|
||||
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||
raise AttributeError("Encoder attention requires setting "
|
||||
"encoder metadata attributes.")
|
||||
elif (attn_type == AttentionType.ENCODER_DECODER
|
||||
and (not attn_metadata.is_all_cross_attn_metadata_set)):
|
||||
raise AttributeError("Encoder/decoder cross-attention "
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes.")
|
||||
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
if key is not None:
|
||||
assert value is not None
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
else:
|
||||
assert value is None
|
||||
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
|
||||
# KV-cache during decoder-self- or
|
||||
# encoder-decoder-cross-attention, but not
|
||||
# during encoder attention.
|
||||
#
|
||||
# Even if there are no new key/value pairs to cache,
|
||||
# we still need to break out key_cache and value_cache
|
||||
# i.e. for later use by paged attention
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
|
||||
if (key is not None) and (value is not None):
|
||||
if attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Update cross-attention KV cache (prefill-only)
|
||||
# During cross-attention decode, key & value will be None,
|
||||
# preventing this IF-statement branch from running
|
||||
updated_slot_mapping = attn_metadata.cross_slot_mapping
|
||||
else:
|
||||
# Update self-attention KV cache (prefill/decode)
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale, layer._v_scale)
|
||||
|
||||
if attn_type != AttentionType.ENCODER:
|
||||
# Decoder self-attention supports chunked prefill.
|
||||
# Encoder/decoder cross-attention requires no chunked
|
||||
# prefill (100% prefill or 100% decode tokens, no mix)
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
else:
|
||||
# Encoder attention - chunked prefill is not applicable;
|
||||
# derive token-count from query shape & and treat them
|
||||
# as 100% prefill tokens
|
||||
assert attn_metadata.num_encoder_tokens is not None
|
||||
num_prefill_tokens = attn_metadata.num_encoder_tokens
|
||||
num_decode_tokens = 0
|
||||
|
||||
if attn_type == AttentionType.DECODER:
|
||||
# Only enforce this shape-constraint for decoder
|
||||
# self-attention
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
output = torch.empty_like(query)
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
assert attn_metadata.seq_lens is not None
|
||||
if (kv_cache.numel() == 0
|
||||
or prefill_meta.block_tables.numel() == 0):
|
||||
self._run_vacc_forward(
|
||||
output,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
prefill_meta,
|
||||
attn_type=attn_type)
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
assert not self.need_mask
|
||||
import intel_extension_for_pytorch.llm.modules as ipex_modules
|
||||
output = torch.empty_like(query)
|
||||
ipex_modules.PagedAttention.flash_attn_varlen_func(
|
||||
output[:prefill_meta.num_prefill_tokens, :, :],
|
||||
query[:prefill_meta.num_prefill_tokens, :, :],
|
||||
key_cache,
|
||||
value_cache,
|
||||
prefill_meta.query_start_loc,
|
||||
prefill_meta.kv_start_loc,
|
||||
prefill_meta.max_query_len,
|
||||
prefill_meta.max_kv_len,
|
||||
self.scale,
|
||||
True,
|
||||
prefill_meta.prefill_block_tables,
|
||||
self.alibi_slopes,
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
assert attn_type != AttentionType.ENCODER_ONLY, (
|
||||
"Encoder-only models should not have decode metadata.")
|
||||
# Decoding run.
|
||||
# (
|
||||
# seq_lens_arg,
|
||||
# max_seq_len_arg,
|
||||
# block_tables_arg,
|
||||
# ) = decode_meta.get_seq_len_block_table_args(attn_type)
|
||||
|
||||
# Note:
|
||||
# decode attention still use SDPA method
|
||||
# reshape k/v_cache to (num_block_grp, block_grp_size, head, hidden_size)
|
||||
k_cache = key_cache.view(-1, env_blk_grp_size, key_cache.shape[2], key_cache.shape[3])
|
||||
v_cache = value_cache.view(-1, env_blk_grp_size, value_cache.shape[2], value_cache.shape[3])
|
||||
block_per_group = env_blk_grp_size // 16
|
||||
# convert block_tables to 8K group index
|
||||
block_tables = (decode_meta.block_tables // block_per_group).to(torch.int32)
|
||||
attn_outs = []
|
||||
for i in range(decode_meta.seq_lens_tensor.shape[0]):
|
||||
seq_len = decode_meta.seq_lens_tensor[i]
|
||||
k_slices = k_cache[block_tables[i], ...]
|
||||
k = \
|
||||
torch.cat([k_slices[i, ...] for i in range(len(block_tables[i]))], dim=0)[:seq_len]
|
||||
v_slices = v_cache[block_tables[i], ...]
|
||||
v = \
|
||||
torch.cat([v_slices[i, ...] for i in range(len(block_tables[i]))], dim=0)[:seq_len]
|
||||
q = query[i : i + 1, ...]
|
||||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=False,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=self.scale,
|
||||
)
|
||||
attn_outs.append(attn_out)
|
||||
output = torch.cat(attn_outs, dim=0)
|
||||
# '''
|
||||
|
||||
# PagedAttention.forward_decode(
|
||||
# output[attn_metadata.num_prefill_tokens:, :, :],
|
||||
# query[attn_metadata.num_prefill_tokens:, :, :],
|
||||
# key_cache,
|
||||
# value_cache,
|
||||
# block_tables_arg,
|
||||
# seq_lens_arg,
|
||||
# max_seq_len_arg,
|
||||
# self.kv_cache_dtype,
|
||||
# self.num_kv_heads,
|
||||
# self.scale,
|
||||
# self.alibi_slopes,
|
||||
# layer._k_scale,
|
||||
# layer._v_scale,
|
||||
# )
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
def _run_vacc_forward(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: VACCAttentionMetadata,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
):
|
||||
# if self.num_kv_heads != self.num_heads:
|
||||
# key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
|
||||
# value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
|
||||
attn_masks = attn_metadata.get_attn_bias(attn_type)
|
||||
if attn_masks is None:
|
||||
if self.alibi_slopes is not None:
|
||||
attn_masks = _make_alibi_bias(
|
||||
self.alibi_slopes, query.dtype,
|
||||
attn_metadata.seq_lens) # type: ignore
|
||||
elif self.sliding_window is not None:
|
||||
assert attn_metadata.seq_lens is not None
|
||||
attn_masks = _make_sliding_window_bias(
|
||||
attn_metadata.seq_lens, self.sliding_window,
|
||||
query.dtype) # type: ignore
|
||||
else:
|
||||
seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
|
||||
attn_masks = [None] * len(seq_lens)
|
||||
attn_metadata.set_attn_bias(attn_masks, attn_type)
|
||||
|
||||
causal_attn = (attn_type == AttentionType.DECODER)
|
||||
|
||||
seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
|
||||
start_q, start_kv = 0, 0
|
||||
for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
|
||||
attn_masks):
|
||||
end_q = start_q + seq_len_q
|
||||
end_kv = start_kv + seq_len_kv
|
||||
sub_out=torch.vacc.scaled_dot_product_attention(
|
||||
query[start_q:end_q,:, :],
|
||||
key[start_kv:end_kv,:, :],
|
||||
value[start_kv:end_kv,:, :].contiguous(),
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=True, #causal_attn and not self.need_mask,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=self.scale)
|
||||
output[ start_q:end_q,:, :] = sub_out
|
||||
start_q, start_kv = end_q, end_kv
|
||||
return output
|
||||
|
||||
|
||||
def _make_alibi_bias(
|
||||
alibi_slopes: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
seq_lens: List[int],
|
||||
) -> List[torch.Tensor]:
|
||||
attn_biases: List[torch.Tensor] = []
|
||||
for seq_len in seq_lens:
|
||||
bias = torch.arange(seq_len, dtype=dtype)
|
||||
# NOTE(zhuohan): HF uses
|
||||
# `bias = bias[None, :].repeat(seq_len, 1)`
|
||||
# here. We find that both biases give the same results, but
|
||||
# the bias below more accurately follows the original ALiBi
|
||||
# paper.
|
||||
bias = bias[None, :] - bias[:, None]
|
||||
|
||||
num_heads = alibi_slopes.shape[0]
|
||||
bias = bias[None, :].repeat((num_heads, 1, 1))
|
||||
bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
|
||||
inf_mask = torch.empty(
|
||||
(1, seq_len, seq_len),
|
||||
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
|
||||
attn_biases.append((bias + inf_mask).to(dtype))
|
||||
|
||||
return attn_biases
|
||||
|
||||
|
||||
def _make_sliding_window_bias(
|
||||
seq_lens: List[int],
|
||||
window_size: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
) -> List[torch.Tensor]:
|
||||
attn_biases: List[torch.Tensor] = []
|
||||
for seq_len in seq_lens:
|
||||
tensor = torch.full(
|
||||
(1, seq_len, seq_len),
|
||||
dtype=dtype,
|
||||
fill_value=1,
|
||||
)
|
||||
shift = 0
|
||||
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
|
||||
if window_size is not None:
|
||||
mask = torch.triu(mask, diagonal=shift - window_size + 1)
|
||||
mask = torch.log(mask)
|
||||
attn_biases.append(mask.to(dtype))
|
||||
|
||||
return attn_biases
|
||||
847
vllm_vacc/vllm/attention/backends/vacc_mla.py
Normal file
847
vllm_vacc/vllm/attention/backends/vacc_mla.py
Normal file
@@ -0,0 +1,847 @@
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
|
||||
try:
|
||||
from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||
except ImportError:
|
||||
BatchDecodeMlaWithPagedKVCacheWrapper = None
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionState, AttentionType)
|
||||
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonMetadata
|
||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
|
||||
#from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm_vacc.vllm.attention.ops.vacc_paged_attn import VaccPagedAttention as PagedAttention
|
||||
# from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
# import time, os
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm_vacc.vllm.worker.vacc_model_runner import (ModelInputForVACCBuilder,
|
||||
ModelInputForVACCWithSamplingMetadata)
|
||||
|
||||
|
||||
class VACCMLABackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TORCH_VACC"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["VACCMLAImpl"]:
|
||||
return VACCMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||
return VACCMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["VACCMLAMetadataBuilder"]:
|
||||
return VACCMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["VACCMLAState"]:
|
||||
return VACCMLAState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [576]
|
||||
|
||||
|
||||
class VACCMLAState(AttentionState):
|
||||
|
||||
def __init__(self, runner):
|
||||
self.runner = runner
|
||||
self._is_graph_capturing = False
|
||||
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int):
|
||||
self._is_graph_capturing = True
|
||||
|
||||
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
self._graph_seq_lens = torch.ones(max_batch_size,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
self._graph_block_tables = torch.from_numpy(
|
||||
self.runner.graph_block_tables).to(device=self.runner.device)
|
||||
|
||||
self._positions = torch.zeros((max_batch_size, ),
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
|
||||
yield
|
||||
|
||||
self._is_graph_capturing = False
|
||||
del self._graph_slot_mapping
|
||||
del self._graph_seq_lens
|
||||
del self._graph_block_tables
|
||||
del self._positions
|
||||
|
||||
def graph_clone(self, batch_size: int):
|
||||
assert self._is_graph_capturing
|
||||
return self.__class__(self.runner)
|
||||
|
||||
def graph_capture_get_metadata_for_batch(
|
||||
self, batch_size: int, is_encoder_decoder_model: bool = False):
|
||||
assert self._is_graph_capturing
|
||||
|
||||
attn_metadata = self.runner.attn_backend.make_metadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=batch_size,
|
||||
slot_mapping=self._graph_slot_mapping[:batch_size],
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self._graph_seq_lens[:batch_size],
|
||||
# max_query_len=1,
|
||||
# max_decode_query_len=1,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.runner.max_seq_len_to_capture,
|
||||
query_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=self._graph_block_tables[:batch_size],
|
||||
use_cuda_graph=True,
|
||||
input_positions=self._positions[:batch_size],
|
||||
head_dim=self.runner.model_config.get_head_size())
|
||||
|
||||
if is_encoder_decoder_model:
|
||||
raise NotImplementedError(
|
||||
"VACCMLAState does not support encoder/decoder yet")
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def get_graph_input_buffers(self,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
input_buffers = {
|
||||
"slot_mapping": attn_metadata.slot_mapping,
|
||||
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
|
||||
"block_tables": attn_metadata.decode_metadata.block_tables,
|
||||
"input_positions": attn_metadata.decode_metadata.input_positions,
|
||||
}
|
||||
if is_encoder_decoder_model:
|
||||
raise NotImplementedError(
|
||||
"VACCMLAState does not support encoder/decoder yet")
|
||||
|
||||
return input_buffers
|
||||
|
||||
def prepare_graph_input_buffers(self,
|
||||
input_buffers,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
input_positions = attn_metadata.input_positions
|
||||
num_positions = input_positions.shape[0]
|
||||
input_buffers["seq_lens_tensor"].copy_(
|
||||
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
|
||||
input_buffers["block_tables"].copy_(
|
||||
attn_metadata.decode_metadata.block_tables, non_blocking=True)
|
||||
# CUDA graph buffer is padded so only perform a partial copy based on
|
||||
# num_positions
|
||||
input_buffers["input_positions"][:num_positions].copy_(
|
||||
input_positions, non_blocking=True)
|
||||
if is_encoder_decoder_model:
|
||||
raise NotImplementedError(
|
||||
"VACCMLAState does not support encoder/decoder yet")
|
||||
|
||||
def begin_forward(self, model_input):
|
||||
return
|
||||
|
||||
|
||||
@dataclass
|
||||
class VACCMLAMetadata(MLACommonMetadata):
|
||||
"""Metadata for VACCMLAMetadata.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
cuda-graph replayed. If you have values that need to be changed
|
||||
dynamically, it should be stored in tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
# in the kv cache. Each block can contain up to block_size tokens.
|
||||
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
|
||||
# Whether or not if cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
|
||||
use_cuda_graph: bool
|
||||
|
||||
# Maximum query length in the batch.
|
||||
max_query_len: Optional[int] = None
|
||||
input_positions: Optional[torch.Tensor] = None
|
||||
# Max number of query tokens among request in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
_cached_prefill_metadata: Optional["VACCMLAMetadata"] = None
|
||||
_cached_decode_metadata: Optional["VACCMLAMetadata"] = None
|
||||
|
||||
num_prefill_tokens: int
|
||||
|
||||
num_kv_splits: int = 4 # TODO(lucas) add heuristic
|
||||
attn_logits: Optional[torch.Tensor] = None
|
||||
req_idx: Optional[torch.Tensor] = None
|
||||
|
||||
# The dimension of the attention heads
|
||||
head_dim: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
supported_head_sizes = VACCMLABackend.get_supported_head_sizes()
|
||||
if self.head_dim is not None and self.head_dim \
|
||||
not in supported_head_sizes:
|
||||
raise ValueError(
|
||||
f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
f"received {self.head_dim}.")
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["VACCMLAMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1])
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
seq_start_loc = (None if self.seq_start_loc is None else
|
||||
self.seq_start_loc[:self.num_prefills + 1])
|
||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
input_positions = (None if self.input_positions is None else
|
||||
self.input_positions[:self.num_prefill_tokens])
|
||||
|
||||
self._cached_prefill_metadata = VACCMLAMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
input_positions=input_positions,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_prefill_seq_len=None,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_start_loc=seq_start_loc,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
head_dim=self.head_dim)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["VACCMLAMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
return self._cached_decode_metadata
|
||||
assert self.seq_lens_tensor is not None
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
input_positions = (None if self.input_positions is None else
|
||||
self.input_positions[self.num_prefill_tokens:])
|
||||
|
||||
self._cached_decode_metadata = VACCMLAMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=self.seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_decode_query_len=self.max_decode_query_len,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
# Batch may be composed of prefill|decodes, adjust query start
|
||||
# indices to refer to the start of decodes. E.g.
|
||||
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
||||
query_start_loc=(self.query_start_loc[self.num_prefills:] -
|
||||
self.query_start_loc[self.num_prefills])
|
||||
if self.query_start_loc is not None else None,
|
||||
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
||||
if self.seq_start_loc is not None else None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
input_positions=input_positions,
|
||||
head_dim=self.head_dim)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
def advance_step(self,
|
||||
model_input: "ModelInputForVACCWithSamplingMetadata",
|
||||
sampled_token_ids: Optional[torch.Tensor],
|
||||
block_size: int,
|
||||
num_seqs: int,
|
||||
num_queries: int,
|
||||
turn_prefills_into_decodes: bool = False):
|
||||
"""
|
||||
Update metadata in-place to advance one decode step.
|
||||
"""
|
||||
# When using cudagraph, the num_seqs is padded to the next captured
|
||||
# batch sized, but num_queries tracks the actual number of requests in
|
||||
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
||||
if num_seqs != num_queries:
|
||||
assert num_seqs > num_queries
|
||||
assert self.use_cuda_graph
|
||||
|
||||
if turn_prefills_into_decodes:
|
||||
# When Mutli-Step is enabled with Chunked-Prefill, prefills and
|
||||
# decodes are scheduled together. In the first step, all the
|
||||
# prefills turn into decodes. This update reflects that
|
||||
# conversion.
|
||||
assert self.num_decode_tokens + self.num_prefills == num_seqs
|
||||
self.num_decode_tokens += self.num_prefills
|
||||
self.num_prefills = 0
|
||||
# self.num_prefill_tokens = 0
|
||||
# self.max_prefill_seq_len = 0
|
||||
self.max_query_len = 1
|
||||
|
||||
self.slot_mapping = self.slot_mapping[:num_seqs]
|
||||
else:
|
||||
assert self.seq_lens is not None
|
||||
assert self.max_decode_seq_len == max(self.seq_lens)
|
||||
|
||||
assert self.num_prefills == 0
|
||||
assert self.num_prefill_tokens == 0
|
||||
assert self.num_decode_tokens == num_seqs
|
||||
assert self.slot_mapping.shape == (num_seqs, )
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert len(self.seq_lens) == num_seqs
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.seq_lens_tensor.shape == (num_seqs, )
|
||||
# assert self.max_query_len == 1
|
||||
# assert self.max_prefill_seq_len == 0
|
||||
|
||||
assert self.query_start_loc is not None
|
||||
assert self.query_start_loc.shape == (num_queries + 1, )
|
||||
assert self.seq_start_loc is not None
|
||||
assert self.seq_start_loc.shape == (num_seqs + 1, )
|
||||
|
||||
assert self.context_lens_tensor is not None
|
||||
assert self.context_lens_tensor.shape == (num_queries, )
|
||||
|
||||
assert self.block_tables is not None
|
||||
assert self.block_tables.shape[0] == num_seqs
|
||||
|
||||
# Update query lengths. Note that we update only queries and not seqs,
|
||||
# since tensors may be padded due to captured cuda graph batch size
|
||||
for i in range(num_queries):
|
||||
self.seq_lens[i] += 1
|
||||
# self.max_decode_seq_len = None
|
||||
|
||||
ops.advance_step_flashattn(num_seqs=num_seqs,
|
||||
num_queries=num_queries,
|
||||
block_size=block_size,
|
||||
input_tokens=model_input.input_tokens,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
input_positions=model_input.input_positions,
|
||||
seq_lens=self.seq_lens_tensor,
|
||||
slot_mapping=self.slot_mapping,
|
||||
block_tables=self.block_tables)
|
||||
|
||||
|
||||
class VACCMLAMetadataBuilder(AttentionMetadataBuilder[VACCMLAMetadata]):
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForVACCBuilder"):
|
||||
self.chunked_prefill = True
|
||||
if hasattr(input_builder, 'chunked_prefill'):
|
||||
self.chunked_prefill = input_builder.chunked_prefill
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
self.block_size = input_builder.block_size
|
||||
|
||||
def prepare(self):
|
||||
self.slot_mapping: List[int] = []
|
||||
self.prefill_seq_lens: List[int] = []
|
||||
self.context_lens: List[int] = []
|
||||
self.block_tables: List[List[int]] = []
|
||||
self.curr_seq_lens: List[int] = []
|
||||
self.input_positions: List[int] = []
|
||||
self.multimodal_placeholder_maps: Dict[
|
||||
str,
|
||||
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
|
||||
self.num_prefills = 0
|
||||
self.num_prefill_tokens = 0
|
||||
self.num_decode_tokens = 0
|
||||
self.has_prefix_cache_hit = False
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int):
|
||||
"""Build attention metadata with on-device tensors.
|
||||
|
||||
Args:
|
||||
seq_lens: The maybe padded sequence lengths of the input sequences.
|
||||
query_lens: The query lengths of the input sequences.
|
||||
cuda_graph_pad_size: The padding size for cuda graph.
|
||||
-1 if cuda graph is not used.
|
||||
batch_size: The maybe padded batch size.
|
||||
"""
|
||||
|
||||
self.input_data = self.input_builder.input_data
|
||||
|
||||
self.slot_mapping=self.input_data.slot_mapping
|
||||
self.context_lens= self.input_data.context_lens
|
||||
if self.input_data.num_prefill_tokens !=0:
|
||||
|
||||
self.block_tables = self.input_data.prefill_block_tables
|
||||
else:
|
||||
self.block_tables= self.input_data.decode_block_tables
|
||||
self.input_positions= self.input_data.input_positions
|
||||
|
||||
self.prefill_seq_lens = seq_lens[0:self.input_data.num_prefills]
|
||||
|
||||
self.num_prefills = self.input_data.num_prefills
|
||||
self.num_prefill_tokens = self.input_data.num_prefill_tokens
|
||||
self.num_decode_tokens = self.input_data.num_decode_tokens
|
||||
|
||||
device = self.runner.device
|
||||
use_captured_graph = cuda_graph_pad_size != -1
|
||||
|
||||
# max_query_len = max(query_lens)
|
||||
# decode_query_lens = query_lens[self.num_prefills:]
|
||||
# if len(decode_query_lens) > 0:
|
||||
# max_decode_query_len = max(decode_query_lens)
|
||||
# else:
|
||||
# max_decode_query_len = 1
|
||||
# max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||
# max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
query_start_loc = list(accumulate(query_lens, initial=0))
|
||||
seq_start_loc = list(accumulate(seq_lens, initial=0))
|
||||
|
||||
num_seqs = len(seq_lens)
|
||||
if use_captured_graph:
|
||||
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
|
||||
self.block_tables.extend([] * cuda_graph_pad_size)
|
||||
num_decode_tokens = batch_size - self.num_prefill_tokens
|
||||
block_tables = self._get_graph_runner_block_tables(
|
||||
num_seqs, self.block_tables)
|
||||
else:
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
)
|
||||
# assert max_query_len > 0, ("query_lens: {}".format(query_lens))
|
||||
|
||||
assert device is not None
|
||||
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
|
||||
device, self.runner.pin_memory)
|
||||
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
||||
self.runner.pin_memory)
|
||||
input_positions = async_tensor_h2d(self.input_positions, torch.int,
|
||||
device, self.runner.pin_memory)
|
||||
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int,
|
||||
device, self.runner.pin_memory)
|
||||
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
|
||||
device,
|
||||
self.runner.pin_memory)
|
||||
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
|
||||
device, self.runner.pin_memory)
|
||||
placeholder_index_maps = {
|
||||
modality: placeholder_map.index_map()
|
||||
for modality, placeholder_map in
|
||||
self.multimodal_placeholder_maps.items()
|
||||
}
|
||||
return VACCMLAMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=True,
|
||||
input_positions=input_positions,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
# max_query_len=max_query_len,
|
||||
# max_decode_query_len=None,
|
||||
max_prefill_seq_len=None,
|
||||
max_decode_seq_len=None,
|
||||
query_start_loc=query_start_loc_tensor,
|
||||
seq_start_loc=seq_start_loc_tensor,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=use_captured_graph,
|
||||
num_kv_splits=4, # TODO(lucas) add heuristic
|
||||
head_dim=self.runner.model_config.get_head_size(),
|
||||
)
|
||||
|
||||
|
||||
class VACCMLAImpl(MLACommonImpl[VACCMLAMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**kwargs) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **kwargs)
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"VACCMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"VACCMLAImpl")
|
||||
|
||||
def extract_weights(self):
|
||||
weights = {}
|
||||
if hasattr(self, 'W_Q'):
|
||||
weights["W_Q"] = self.W_Q
|
||||
if hasattr(self, 'W_Q_scales'):
|
||||
weights["W_Q_scales"] = self.W_Q_scales
|
||||
if hasattr(self, 'W_QR'):
|
||||
weights['W_QR'] = self.W_QR
|
||||
if hasattr(self, 'W_QR_scales'):
|
||||
weights["W_QR_scales"] = self.W_QR_scales
|
||||
if hasattr(self, 'W_Q_QR'):
|
||||
weights["W_Q_QR"] = self.W_Q_QR
|
||||
if hasattr(self, 'W_Q_QR_scales'):
|
||||
weights["W_Q_QR_scales"] = self.W_Q_QR_scales
|
||||
if hasattr(self, 'W_UK'):
|
||||
weights['W_UK'] = self.W_UK
|
||||
if hasattr(self, 'W_UK_scales'):
|
||||
weights['W_UK_scales'] = self.W_UK_scales
|
||||
if hasattr(self, 'W_Q_UK_scales'):
|
||||
weights['W_Q_UK_scales'] = self.W_Q_UK_scales
|
||||
if hasattr(self, 'W_UV'):
|
||||
weights['W_UV'] = self.W_UV
|
||||
if hasattr(self, 'W_UV_scales'):
|
||||
weights['W_UV_scales'] = self.W_UV_scales
|
||||
if hasattr(self, 'W_UV_O'):
|
||||
weights['W_UV_O'] = self.W_UV_O
|
||||
if hasattr(self, 'W_UV_O_scales'):
|
||||
weights['W_UV_O_scales'] = self.W_UV_O_scales
|
||||
return weights
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: VACCMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(attn_metadata, VACCMLAMetadata)
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0]\
|
||||
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope\
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
|
||||
v = v.contiguous()
|
||||
|
||||
# For MLA the v head dim is smaller than qk head dim so we pad out
|
||||
# v with 0s to match the qk head dim
|
||||
# v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
|
||||
# value=0)
|
||||
# attn_output = torch.vacc.scaled_dot_product_attention(
|
||||
# query=q,
|
||||
# key=k,
|
||||
# value=v_padded,
|
||||
# attn_mask=None,
|
||||
# dropout_p=0,
|
||||
# is_causal=True,
|
||||
# is_train=False,
|
||||
# recompute=False,
|
||||
# flash_attention=True,
|
||||
# sm_scale=self.scale
|
||||
# )
|
||||
|
||||
# attn_output = attn_output\
|
||||
# .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
|
||||
# .reshape(-1, self.num_heads * v.shape[-1])
|
||||
seq_lens = attn_metadata.prefill_metadata.seq_lens
|
||||
if len(seq_lens) == 1:
|
||||
# Vacc supports different head dim of v and qk.
|
||||
attn_output = torch.vacc.scaled_dot_product_attention(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=True,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=self.scale
|
||||
)
|
||||
attn_out = attn_output.view(-1, self.num_heads * v.shape[-1])
|
||||
else:
|
||||
attn_outs = []
|
||||
start = 0
|
||||
for seq in seq_lens:
|
||||
end = start + seq
|
||||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||||
query=q[start:end, :],
|
||||
key=k[start:end, :],
|
||||
value=v[start:end, :],
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=True,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=self.scale
|
||||
)
|
||||
start = end
|
||||
attn_outs.append(attn_out)
|
||||
attn_out = torch.cat(attn_outs, dim=0).view(-1, self.num_heads * v.shape[-1])
|
||||
|
||||
return self.o_proj(attn_out)[0]
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: VACCMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Triton MLA not yet supported")
|
||||
|
||||
decode_meta = attn_metadata.decode_metadata
|
||||
assert decode_meta is not None
|
||||
B = q_nope.shape[0]
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
o = torch.zeros(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
|
||||
# Add a head dim of 1
|
||||
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
# print(f"kv_c_and_k_pe_cache: {kv_c_and_k_pe_cache.shape} ")
|
||||
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
|
||||
|
||||
# Run MQA using paged_attention
|
||||
# o = torch.vacc.paged_attention(
|
||||
# query=q,
|
||||
# key_cache=kv_c_and_k_pe_cache,
|
||||
# value_cache=kv_c_cache,
|
||||
# block_table=decode_meta.block_tables,
|
||||
# seq_len=decode_meta.seq_lens_tensor,
|
||||
# out=o,
|
||||
# sm_scale=self.scale
|
||||
# )
|
||||
|
||||
# Run MQA using spda
|
||||
# t0 = time.time()
|
||||
o = vacc_paged_attention_naive(
|
||||
q,
|
||||
kv_c_and_k_pe_cache,
|
||||
kv_c_cache,
|
||||
block_table = decode_meta.block_tables,
|
||||
# seq_lens = decode_meta.seq_lens_tensor,
|
||||
seq_lens=decode_meta.seq_lens,
|
||||
out = o,
|
||||
sm_scale=self.scale)
|
||||
# print(f'{os.getpid()} paged_atten(seq: {decode_meta.seq_lens}) time: {time.time() - t0}')
|
||||
|
||||
return self._v_up_proj_and_o_proj(o)
|
||||
|
||||
def vacc_paged_attention_naive(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
# seq_lens: torch.Tensor,
|
||||
seq_lens: int,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
sm_scale = -1
|
||||
) -> torch.Tensor:
|
||||
|
||||
# gurantee batch=1 perf
|
||||
if len(seq_lens) == 1:
|
||||
k = key_cache.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[0]]
|
||||
v = value_cache.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[0]]
|
||||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=k,
|
||||
value=v,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=False,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=sm_scale
|
||||
)
|
||||
else:
|
||||
# t0 = time.time()
|
||||
attn_outs = []
|
||||
for i in range(len(seq_lens)):
|
||||
k_slices = key_cache[block_table[i], :, :, :]
|
||||
k = torch.cat([k_slices[i, :, :, :].unsqueeze(1) for i in range(len(block_table[i]))], dim=0)
|
||||
k = k.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[i]]
|
||||
v_slices = value_cache[block_table[i], :, :, :]
|
||||
v = torch.cat([v_slices[i, :, :, :].unsqueeze(1) for i in range(len(block_table[i]))], dim=0)
|
||||
v = v.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[i]]
|
||||
|
||||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||||
query=query[i:i+1,:,:],
|
||||
key=k,
|
||||
value=v,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=False,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=sm_scale
|
||||
)
|
||||
attn_outs.append(attn_out)
|
||||
|
||||
attn_out = torch.cat(attn_outs, dim=0)
|
||||
# print(f'{os.getpid()} call spda(seq: {seq_lens}) time: {time.time() - t0}')
|
||||
return attn_out
|
||||
|
||||
# MLA single op impl
|
||||
def vacc_paged_attention_naive_singleop(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
seq_lens,
|
||||
block_table = None,
|
||||
out: torch.Tensor = None,
|
||||
sm_scale = -1
|
||||
) -> torch.Tensor:
|
||||
k = key_cache.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens]
|
||||
v = value_cache.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens].squeeze(1)
|
||||
pe_cache = k[..., 512:].squeeze(1)
|
||||
print(f'q:{query[..., :512].shape} v:{v.shape} pe_cache:{pe_cache.shape}')
|
||||
q_nope_kv_c = torch.einsum("shc,tc->sht", query[..., :512], v)
|
||||
q_pe_k_pe = torch.einsum("shr,tr->sht", query[..., 512:], pe_cache)
|
||||
scores = (q_nope_kv_c + q_pe_k_pe) * sm_scale
|
||||
scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(query)
|
||||
o = torch.einsum("sht,tc->shc", scores, v)
|
||||
return o
|
||||
Reference in New Issue
Block a user