Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -12,7 +12,7 @@ from vllm.config.vllm import VllmConfig
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.kv_transfer_utils import (
|
||||
maybe_transfer_kv_layer,
|
||||
maybe_transfer_kv_layer
|
||||
)
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||
@@ -40,6 +40,9 @@ from vllm.v1.kv_cache_interface import (
|
||||
KVCacheSpec,
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
from .extra_cache import StaticQuantManager
|
||||
from ixformer.core import config
|
||||
_USE_TORCH_OPS = config.IXFORMER_USE_TORCH_OPS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.attention import MLAAttention
|
||||
@@ -202,6 +205,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
kv_sharing_target_layer_name: str | None = None,
|
||||
attn_backend: type[AttentionBackend] | None = None,
|
||||
head_size_v: int | None = None,
|
||||
extra_cache_para: dict = None,
|
||||
**extra_impl_args,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -258,6 +262,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.hidden_size = head_size * num_heads
|
||||
self.head_size_v = self.head_size if head_size_v is None else head_size_v
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
@@ -326,6 +331,15 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
kv_sharing_target_layer_name,
|
||||
**extra_impl_args,
|
||||
)
|
||||
if extra_cache_para is not None:
|
||||
self.quant_manager = StaticQuantManager(
|
||||
layer_id=extra_cache_para.get("layer_id", None),
|
||||
shape=(self.num_kv_heads, self.head_size_v),
|
||||
dtype=torch.float32,
|
||||
total_layer_num=extra_cache_para.get("total_layer_num", None)
|
||||
)
|
||||
else:
|
||||
self.quant_manager = None
|
||||
self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
|
||||
self.dtype = dtype
|
||||
|
||||
@@ -333,7 +347,10 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
# torch.compile works by registering the attention as one giant
|
||||
# opaque custom op. For other platforms, we directly call them
|
||||
# and let torch.compile handle them.
|
||||
self.use_direct_call = not current_platform.opaque_attention_op()
|
||||
if _USE_TORCH_OPS:
|
||||
self.use_direct_call = False
|
||||
else:
|
||||
self.use_direct_call = True
|
||||
|
||||
self.use_output = self.attn_backend.accept_output_buffer
|
||||
compilation_config = vllm_config.compilation_config
|
||||
@@ -349,14 +366,26 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
compilation_config.static_forward_context,
|
||||
)
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
# use a placeholder kv cache tensor during init, which will be replaced
|
||||
# by bind_kv_cache
|
||||
# this variable will not be accessed if use_direct_call is True
|
||||
self.kv_cache = [
|
||||
torch.tensor([])
|
||||
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
self.is_i8qi8ki8v = envs.VLLM_ATTN_OPT_LEVEL == 1
|
||||
self.is_i8qi8kf16v = envs.VLLM_ATTN_OPT_LEVEL == 2
|
||||
if self.is_i8qi8kf16v:
|
||||
self.kv_cache_scale = [
|
||||
torch.tensor([]) for _ in range(get_current_vllm_config(
|
||||
).parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
elif self.is_i8qi8ki8v:
|
||||
self.kv_cache_scale = [
|
||||
[torch.tensor([]), torch.tensor([])] for _ in range(get_current_vllm_config(
|
||||
).parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
# use a placeholder kv cache tensor during init, which will be replaced
|
||||
# by bind_kv_cache
|
||||
# this variable will not be accessed if use_direct_call is True
|
||||
|
||||
# Initialize KV cache quantization attributes
|
||||
_init_kv_cache_quant(self, quant_config, prefix)
|
||||
@@ -396,6 +425,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
context using
|
||||
`vllm.forward_context.get_forward_context().attn_metadata`.
|
||||
"""
|
||||
optional_args = {}
|
||||
if self.calculate_kv_scales:
|
||||
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
|
||||
output_dtype = query.dtype
|
||||
@@ -412,15 +442,8 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
query, _ = self.query_quant(query, self._q_scale)
|
||||
|
||||
if self.use_output:
|
||||
if output_shape is None:
|
||||
# Handle both 2D [num_tokens, hidden] and
|
||||
# 3D [num_tokens, heads, head_dim] query
|
||||
num_tokens = query.shape[0]
|
||||
output_shape = torch.Size(
|
||||
(num_tokens, self.num_heads * self.head_size_v)
|
||||
)
|
||||
output_shape = output_shape if output_shape is not None else query.shape
|
||||
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
|
||||
hidden_size = output_shape[-1]
|
||||
# Reshape the query, key, and value tensors.
|
||||
# NOTE(woosuk): We do this outside the custom op to minimize the
|
||||
# CPU overheads from the non-CUDA-graph regions.
|
||||
@@ -430,46 +453,50 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
if value is not None:
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size_v)
|
||||
kv_cache_dummy_dep = None
|
||||
if self.use_direct_call:
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if (
|
||||
not self.attn_backend.forward_includes_kv_cache_update
|
||||
and self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
kv_cache_dummy_dep = unified_kv_cache_update(
|
||||
key, value, self.layer_name
|
||||
def direct_forward(layer_name: str, output: torch.Tensor):
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if self.is_i8qi8ki8v or self.is_i8qi8kf16v:
|
||||
optional_args["kv_cache_scale"] = self.kv_cache_scale[forward_context.virtual_engine]
|
||||
output = self.impl.forward(
|
||||
self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self_kv_cache,
|
||||
attn_metadata,
|
||||
output=output,
|
||||
**optional_args
|
||||
)
|
||||
unified_attention_with_output(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
self.layer_name,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
return output
|
||||
return maybe_transfer_kv_layer(direct_forward)(self.layer_name, output)
|
||||
else:
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if (
|
||||
not self.attn_backend.forward_includes_kv_cache_update
|
||||
and self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
|
||||
key, value, self.layer_name
|
||||
)
|
||||
if self.is_i8qi8ki8v:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
kv_cache_scale = self.kv_cache_scale[forward_context.virtual_engine][0]
|
||||
v_cache_scale = self.kv_cache_scale[forward_context.virtual_engine][1]
|
||||
elif self.is_i8qi8kf16v:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
kv_cache_scale = self.kv_cache_scale[forward_context.virtual_engine]
|
||||
v_cache_scale = None
|
||||
else:
|
||||
kv_cache_scale = None
|
||||
v_cache_scale = None
|
||||
torch.ops.vllm.unified_attention_with_output(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
self.layer_name,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
kv_cache_scale,
|
||||
v_cache_scale
|
||||
)
|
||||
return output.view(-1, hidden_size)
|
||||
return output.view(-1, self.hidden_size)
|
||||
else:
|
||||
assert self.attn_backend.forward_includes_kv_cache_update, (
|
||||
"Split KV cache update not supported when output tensor not provided."
|
||||
@@ -521,6 +548,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
# Should not be called for enc-dec or encoder-only attention.
|
||||
assert self.attn_type == AttentionType.DECODER
|
||||
# TODO : kernel unsupport kvcache for sliding_window, use FullAttentionSpec replace
|
||||
if self.sliding_window is not None:
|
||||
assert not vllm_config.model_config.use_mla, (
|
||||
"MLA is not supported for slidingwindow"
|
||||
@@ -689,6 +717,8 @@ def unified_attention_with_output(
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
kv_cache_scale: torch.Tensor | None = None,
|
||||
v_cache_scale: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
kv_cache_dummy_dep: torch.Tensor | None = None,
|
||||
@@ -696,9 +726,7 @@ def unified_attention_with_output(
|
||||
# kv_cache_dummy_dep is not used but accepting it creates a data dependency
|
||||
# that ensures torch.compile preserves ordering between KV cache update and
|
||||
# attention forward.
|
||||
del kv_cache_dummy_dep
|
||||
attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
|
||||
|
||||
self.impl.forward(
|
||||
self,
|
||||
query,
|
||||
@@ -707,6 +735,7 @@ def unified_attention_with_output(
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output=output,
|
||||
kv_cache_scale = [kv_cache_scale, v_cache_scale] if envs.VLLM_ATTN_OPT_LEVEL==1 else kv_cache_scale,
|
||||
output_scale=output_scale,
|
||||
output_block_scale=output_block_scale,
|
||||
)
|
||||
@@ -718,6 +747,8 @@ def unified_attention_with_output_fake(
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
kv_cache_scale: torch.Tensor | None = None,
|
||||
v_cache_scale: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
kv_cache_dummy_dep: torch.Tensor | None = None,
|
||||
|
||||
131
vllm/model_executor/layers/attention/extra_cache.py
Normal file
131
vllm/model_executor/layers/attention/extra_cache.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
|
||||
import torch
|
||||
from filelock import FileLock
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class StaticQuantManager:
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
shape: tuple,
|
||||
dtype: torch.dtype,
|
||||
total_layer_num: int,
|
||||
device: str = None,
|
||||
tp_size: int = None,
|
||||
tp_rank: int = None,
|
||||
file_save_path: str = None,
|
||||
save_step: int = 100,
|
||||
info_step: int = 100,
|
||||
):
|
||||
# update parament
|
||||
if tp_size is None:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if tp_rank is None:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
if file_save_path is None:
|
||||
file_save_path = envs.VLLM_ATTN_STATIC_QUANT_SCALE_FILE_PATH
|
||||
if device is None:
|
||||
device = "cuda"
|
||||
|
||||
# check parament
|
||||
if file_save_path in [None, ""]:
|
||||
self.disable = True
|
||||
return
|
||||
|
||||
para_dir = os.path.dirname(file_save_path)
|
||||
assert os.path.exists(para_dir), (
|
||||
f"StaticQuantManager workdir {para_dir} not exist!"
|
||||
)
|
||||
self.disable = os.path.exists(file_save_path)
|
||||
if self.disable:
|
||||
return
|
||||
|
||||
assert layer_id is not None
|
||||
assert total_layer_num is not None
|
||||
|
||||
world_rank = torch.distributed.get_rank()
|
||||
work_dir = os.path.join(para_dir, "StaticQuantManagerWorkdir")
|
||||
self.operator = world_rank == 0 and layer_id == 0
|
||||
if not os.path.exists(work_dir):
|
||||
if self.operator:
|
||||
logger.debug(f"StaticQuantManager Creat {work_dir}!")
|
||||
os.mkdir(work_dir)
|
||||
self.file_save_path = file_save_path
|
||||
self.work_dir = work_dir
|
||||
|
||||
self.tp_size = tp_size
|
||||
self.tp_rank = tp_rank
|
||||
self.world_rank = world_rank
|
||||
self.layer_id = layer_id
|
||||
self.total_layer_num = total_layer_num
|
||||
self.save_step = save_step
|
||||
self.info_step = info_step
|
||||
self.update_count = 0
|
||||
self.save_flag = False
|
||||
self.scales = torch.zeros(shape, dtype=dtype, device=device)
|
||||
logger.debug(
|
||||
f"StaticQuantManager info: world_rank:{self.world_rank} tp_rank:{self.tp_rank} layer_id:{self.layer_id} scale shape:{shape} self.scales:{self.scales.device}"
|
||||
)
|
||||
|
||||
def check_enable(self):
|
||||
return not self.disable
|
||||
|
||||
def update_data(self, data):
|
||||
if self.disable:
|
||||
return
|
||||
|
||||
self.scales = torch.max(data, self.scales)
|
||||
|
||||
# save file
|
||||
self.update_count += 1
|
||||
if self.update_count % self.info_step == 0 and self.operator:
|
||||
logger.info(f"StaticQuantManager run update_data {self.update_count} step")
|
||||
|
||||
if self.update_count % self.save_step == 0:
|
||||
# step1: save to disk
|
||||
save_file_path = os.path.join(
|
||||
self.work_dir, f"{self.layer_id}_{self.tp_rank}.pt"
|
||||
)
|
||||
lock_file_path = os.path.join(
|
||||
self.work_dir, f"{self.layer_id}_{self.tp_rank}.lock"
|
||||
)
|
||||
lock = FileLock(lock_file_path)
|
||||
cpu_data = self.scales.cpu()
|
||||
with lock:
|
||||
torch.save(cpu_data, save_file_path)
|
||||
|
||||
# step2: merge and save
|
||||
if self.save_flag and self.operator:
|
||||
save_dict = {}
|
||||
for idx in range(self.total_layer_num):
|
||||
tp_datas = []
|
||||
for tp_rank in range(self.tp_size):
|
||||
load_file = os.path.join(self.work_dir, f"{idx}_{tp_rank}.pt")
|
||||
lock_file_path = os.path.join(
|
||||
self.work_dir, f"{idx}_{tp_rank}.lock"
|
||||
)
|
||||
lock = FileLock(lock_file_path)
|
||||
with lock:
|
||||
cur_data = torch.load(load_file)
|
||||
tp_datas.append(cur_data)
|
||||
|
||||
layer_data = torch.concat(tp_datas)
|
||||
save_dict[f"layer_{idx}"] = layer_data
|
||||
|
||||
torch.save(save_dict, self.file_save_path)
|
||||
logger.info(
|
||||
f"StaticQuantManager save to {self.file_save_path} with {self.update_count} step"
|
||||
)
|
||||
self.save_flag = True
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,21 +2,94 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.attention.ops.vit_attn_wrappers import (
|
||||
vit_flash_attn_wrapper,
|
||||
vit_flashinfer_wrapper,
|
||||
vit_torch_sdpa_wrapper,
|
||||
vit_triton_attn_wrapper,
|
||||
)
|
||||
import ixformer.contrib.vllm_flash_attn as ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Batch buckets for cuDNN graph caching.
|
||||
# Graphs use batch size and max sequence length as cache key.
|
||||
# This avoids creating a new graph for each unique set of
|
||||
# batch size and max sequence length at runtime.
|
||||
# From the cuDNN team's performance measurements, there
|
||||
# is no significant kernel performance difference between padding
|
||||
# to a smaller batch size/seq length and padding to larger
|
||||
# ones. The bucketing here is solely used to avoid memory
|
||||
# operation overhead, which won't be needed if we have CUDA
|
||||
# graph support in the future.
|
||||
# TODO: Remove buckets after issue #34763
|
||||
# (cuda graph support) is addressed.
|
||||
FLASHINFER_BATCH_BUCKETS = [8, 16, 32, 64]
|
||||
FLASHINFER_MAX_SEQLEN_BUCKETS = [
|
||||
1 * 1024,
|
||||
2 * 1024,
|
||||
4 * 1024,
|
||||
8 * 1024,
|
||||
16 * 1024,
|
||||
32 * 1024,
|
||||
64 * 1024,
|
||||
128 * 1024,
|
||||
]
|
||||
|
||||
# Workspace buffer for FlashInfer CuDNN backend
|
||||
FLASHINFER_CUDNN_WORKSPACE_SIZE_BYTES = 128 * 1024 * 1024
|
||||
_flashinfer_workspace_buffer: torch.Tensor | None = None
|
||||
|
||||
|
||||
def _get_flashinfer_workspace_buffer() -> torch.Tensor:
|
||||
global _flashinfer_workspace_buffer
|
||||
if _flashinfer_workspace_buffer is None:
|
||||
_flashinfer_workspace_buffer = torch.zeros(
|
||||
FLASHINFER_CUDNN_WORKSPACE_SIZE_BYTES,
|
||||
dtype=torch.uint8,
|
||||
device="cuda",
|
||||
)
|
||||
return _flashinfer_workspace_buffer
|
||||
|
||||
|
||||
def add_padding_to_seqlens(
|
||||
seq: np.ndarray,
|
||||
batch_size: int,
|
||||
padding_value: int,
|
||||
) -> np.ndarray:
|
||||
batch_size_padded = next(
|
||||
(b for b in FLASHINFER_BATCH_BUCKETS if b >= batch_size),
|
||||
round_up(batch_size, FLASHINFER_BATCH_BUCKETS[0]),
|
||||
)
|
||||
if batch_size_padded == batch_size:
|
||||
return seq
|
||||
return np.concatenate(
|
||||
[
|
||||
seq,
|
||||
np.full((batch_size_padded - batch_size,), padding_value, dtype=seq.dtype),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def bucket_flashinfer_max_seqlen(
|
||||
real_max_seqlen: int,
|
||||
) -> int:
|
||||
if real_max_seqlen <= 0:
|
||||
return FLASHINFER_MAX_SEQLEN_BUCKETS[0]
|
||||
return next(
|
||||
(s for s in FLASHINFER_MAX_SEQLEN_BUCKETS if s >= real_max_seqlen),
|
||||
round_up(real_max_seqlen, FLASHINFER_MAX_SEQLEN_BUCKETS[-1]),
|
||||
)
|
||||
|
||||
|
||||
# --8<-- [start:mm_encoder_attn]
|
||||
@CustomOp.register("mm_encoder_attn")
|
||||
@@ -24,6 +97,67 @@ class MMEncoderAttention(CustomOp):
|
||||
"""Multi-headed attention without any cache, used for multimodal encoder."""
|
||||
|
||||
# --8<-- [end:mm_encoder_attn]
|
||||
@classmethod
|
||||
def compute_max_seqlen(
|
||||
cls,
|
||||
attn_backend: AttentionBackendEnum,
|
||||
cu_seqlens: np.ndarray,
|
||||
) -> int:
|
||||
max_seqlen = 0
|
||||
if (
|
||||
attn_backend
|
||||
in (
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
AttentionBackendEnum.FLASHINFER,
|
||||
)
|
||||
and len(cu_seqlens) >= 2
|
||||
):
|
||||
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
|
||||
if attn_backend == AttentionBackendEnum.FLASHINFER:
|
||||
max_seqlen = bucket_flashinfer_max_seqlen(max_seqlen)
|
||||
return max_seqlen
|
||||
|
||||
@classmethod
|
||||
def maybe_compute_sequence_lengths(
|
||||
cls,
|
||||
attn_backend: AttentionBackendEnum,
|
||||
cu_seqlens: np.ndarray,
|
||||
) -> np.ndarray | None:
|
||||
if attn_backend != AttentionBackendEnum.FLASHINFER:
|
||||
return None
|
||||
sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
sequence_lengths = add_padding_to_seqlens(
|
||||
sequence_lengths, len(sequence_lengths), 0
|
||||
)
|
||||
return sequence_lengths
|
||||
|
||||
@classmethod
|
||||
def maybe_recompute_cu_seqlens(
|
||||
cls,
|
||||
attn_backend: AttentionBackendEnum,
|
||||
cu_seqlens: np.ndarray,
|
||||
hidden_size: int,
|
||||
tp_size: int,
|
||||
) -> np.ndarray:
|
||||
if attn_backend != AttentionBackendEnum.FLASHINFER:
|
||||
return cu_seqlens
|
||||
|
||||
batch_size = len(cu_seqlens) - 1
|
||||
scale = hidden_size // tp_size
|
||||
cu_seqlens = cu_seqlens * scale
|
||||
|
||||
cu_seqlens_qko = cu_seqlens
|
||||
cu_seqlens_v = cu_seqlens * 3
|
||||
|
||||
cu_seqlens_qko = add_padding_to_seqlens(
|
||||
cu_seqlens_qko, batch_size, cu_seqlens_qko[-1]
|
||||
)
|
||||
cu_seqlens_v = add_padding_to_seqlens(
|
||||
cu_seqlens_v, batch_size, cu_seqlens_v[-1]
|
||||
)
|
||||
return np.concatenate([cu_seqlens_qko, cu_seqlens_v])
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -46,10 +180,9 @@ class MMEncoderAttention(CustomOp):
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = scale
|
||||
self.scale = 1.0 / (head_size**0.5) if scale is None else scale
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
self.layer_name = prefix
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0, (
|
||||
f"num_heads ({self.num_heads}) is not "
|
||||
f"divisible by num_kv_heads ({self.num_kv_heads})"
|
||||
@@ -72,9 +205,14 @@ class MMEncoderAttention(CustomOp):
|
||||
}
|
||||
|
||||
self._fa_version = (
|
||||
get_flash_attn_version() if self.is_flash_attn_backend else None
|
||||
get_flash_attn_version(head_size=head_size)
|
||||
if self.is_flash_attn_backend
|
||||
else None
|
||||
)
|
||||
|
||||
if self.attn_backend == AttentionBackendEnum.FLASHINFER:
|
||||
_get_flashinfer_workspace_buffer()
|
||||
|
||||
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
|
||||
|
||||
@classmethod
|
||||
@@ -148,23 +286,27 @@ class MMEncoderAttention(CustomOp):
|
||||
bsz, q_len = query.size()[:2]
|
||||
kv_len = key.size(1)
|
||||
is_reshaped = query.dim() != 4
|
||||
query = query.view(bsz * q_len, self.num_heads, self.head_size)
|
||||
key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
|
||||
value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)
|
||||
|
||||
query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
|
||||
|
||||
output = vit_flash_attn_wrapper(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
batch_size=bsz,
|
||||
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
|
||||
fa_version=self._fa_version,
|
||||
scale=self.scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
cu_q = torch.tensor([0,] + [q_len for _ in range(bsz)], device=query.device, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32)
|
||||
cu_kv = torch.tensor([0,] + [kv_len for _ in range(bsz)], device=query.device, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32)
|
||||
out = ops.flash_attn_varlen_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
cu_q,
|
||||
cu_kv,
|
||||
q_len,
|
||||
kv_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=False,
|
||||
)
|
||||
out = out.view(bsz, q_len, self.num_heads, self.head_size)
|
||||
if is_reshaped:
|
||||
output = output.reshape(bsz, q_len, -1)
|
||||
return output
|
||||
out = out.reshape(bsz, q_len, -1)
|
||||
return out
|
||||
|
||||
def _forward_triton(
|
||||
self,
|
||||
@@ -201,6 +343,27 @@ class MMEncoderAttention(CustomOp):
|
||||
output = output.reshape(bsz, q_len, -1)
|
||||
return output
|
||||
|
||||
def _forward_flashinfer(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None,
|
||||
sequence_lengths: torch.Tensor
|
||||
| None = None, # Only used for FlashInfer CuDNN backend
|
||||
) -> torch.Tensor:
|
||||
return vit_flashinfer_wrapper(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
scale=self.scale,
|
||||
workspace_buffer=_get_flashinfer_workspace_buffer(),
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
sequence_lengths=sequence_lengths,
|
||||
)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@@ -208,6 +371,8 @@ class MMEncoderAttention(CustomOp):
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
sequence_lengths: torch.Tensor
|
||||
| None = None, # Only used for FlashInfer CuDNN backend
|
||||
) -> torch.Tensor:
|
||||
return self._forward_sdpa(query, key, value, cu_seqlens)
|
||||
|
||||
@@ -218,11 +383,17 @@ class MMEncoderAttention(CustomOp):
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
sequence_lengths: torch.Tensor
|
||||
| None = None, # Only used for FlashInfer CuDNN backend
|
||||
) -> torch.Tensor:
|
||||
if self.is_flash_attn_backend:
|
||||
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
|
||||
elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
|
||||
return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
|
||||
elif self.attn_backend == AttentionBackendEnum.FLASHINFER:
|
||||
return self._forward_flashinfer(
|
||||
query, key, value, cu_seqlens, max_seqlen, sequence_lengths
|
||||
)
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
return self._forward_sdpa(query, key, value, cu_seqlens)
|
||||
else:
|
||||
@@ -238,6 +409,8 @@ class MMEncoderAttention(CustomOp):
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
sequence_lengths: torch.Tensor
|
||||
| None = None, # Only used for FlashInfer CuDNN backend
|
||||
) -> torch.Tensor:
|
||||
return self._forward_sdpa(query, key, value, cu_seqlens)
|
||||
|
||||
@@ -248,6 +421,8 @@ class MMEncoderAttention(CustomOp):
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
sequence_lengths: torch.Tensor
|
||||
| None = None, # Only used for FlashInfer CuDNN backend
|
||||
) -> torch.Tensor:
|
||||
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
|
||||
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
|
||||
|
||||
Reference in New Issue
Block a user