[Feature] Add FlashAttention3 as a backend for VisionAttention (#5764)
Co-authored-by: othame <chenzhu_912@zju.edu.cn> Co-authored-by: Mick <mickjagger19@icloud.com> Co-authored-by: Yi Zhang <1109276519@qq.com>
This commit is contained in:
@@ -166,10 +166,11 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
|||||||
|
|
||||||
## Kernel backend
|
## Kernel backend
|
||||||
|
|
||||||
| Arguments | Description | Defaults |
|
| Arguments | Description | Defaults |
|
||||||
|----------|-------------|---------|
|
|------------------------|-------------|---------|
|
||||||
| `attention_backend` | This argument specifies the backend for attention computation and KV cache management, which can be `fa3`, `flashinfer`, `triton`, `flashmla`, `cutlass_mla`, or `torch_native`. When deploying DeepSeek models, use this argument to specify the MLA backend. | None |
|
| `attention_backend` | This argument specifies the backend for attention computation and KV cache management, which can be `fa3`, `flashinfer`, `triton`, `cutlass_mla`, or `torch_native`. When deploying DeepSeek models, use this argument to specify the MLA backend. | None |
|
||||||
| `sampling_backend` | Specifies the backend used for sampling. | None |
|
| `sampling_backend` | Specifies the backend used for sampling. | None |
|
||||||
|
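| `mm_attention_backend` | Sets the attention backend used by multimodal (vision) attention layers, which can be `sdpa`, `fa3`, or `triton_attn`. When unset, the backend requested by the model implementation is used, falling back to `sdpa`. | None |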
|
||||||
|
|
||||||
## Constrained Decoding
|
## Constrained Decoding
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from functools import lru_cache
|
import math
|
||||||
|
from functools import lru_cache, wraps
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -8,6 +9,13 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
|
from sglang.srt.utils import is_cuda
|
||||||
|
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
|
if _is_cuda:
|
||||||
|
from sgl_kernel.flash_attn import flash_attn_varlen_func
|
||||||
|
|
||||||
from sglang.srt.distributed import parallel_state
|
from sglang.srt.distributed import parallel_state
|
||||||
from sglang.srt.distributed import utils as dist_utils
|
from sglang.srt.distributed import utils as dist_utils
|
||||||
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
||||||
@@ -19,166 +27,31 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization import QuantizationConfig
|
from sglang.srt.layers.quantization import QuantizationConfig
|
||||||
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, rotate_half
|
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
|
||||||
from sglang.srt.utils import add_prefix
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
from sglang.srt.utils import add_prefix, logger
|
||||||
|
|
||||||
|
ROTARY_EMBED_CLASSES = {
|
||||||
|
"normal": apply_rotary_pos_emb,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class VisionAttention(nn.Module):
|
def execute_once(func):
|
||||||
r"""
|
has_run = None
|
||||||
Multi-headed attention without any cache, mostly used for ViT.
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
nonlocal has_run
|
||||||
|
if not has_run:
|
||||||
|
func(*args, **kwargs)
|
||||||
|
has_run = True
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
Args:
|
@execute_once
|
||||||
use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
|
def info_once(message: str):
|
||||||
use_context_forward (bool, default to True):
|
logger.info(message)
|
||||||
if ``True``, a flash_attn style attention will be applied
|
|
||||||
Otherwise, a full-sequence attention will be applied.
|
|
||||||
softmax_in_single_precision (bool, default to False):
|
|
||||||
if ``True``, the softmax will be performed in single-precision
|
|
||||||
Otherwise, it will be performed in half-precision
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
embed_dim: int,
|
|
||||||
num_heads: int,
|
|
||||||
projection_size: int,
|
|
||||||
use_qkv_parallel: bool,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
dropout: float = 0.0,
|
|
||||||
use_context_forward: bool = True,
|
|
||||||
softmax_in_single_precision: bool = False,
|
|
||||||
flatten_batch: bool = False,
|
|
||||||
prefix: str = "",
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.use_context_forward = use_context_forward
|
|
||||||
world_size = parallel_state.get_tensor_model_parallel_world_size()
|
|
||||||
self.dropout = dropout
|
|
||||||
self.head_size = embed_dim // num_heads
|
|
||||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
|
||||||
projection_size, num_heads
|
|
||||||
)
|
|
||||||
self.num_attention_heads_per_partition = dist_utils.divide(
|
|
||||||
num_heads, world_size
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.use_context_forward:
|
|
||||||
self.qkv_backend = VisionTritonAttention()
|
|
||||||
else:
|
|
||||||
self.qkv_backend = VisionSdpaAttention(
|
|
||||||
head_size=self.head_size,
|
|
||||||
dropout=dropout,
|
|
||||||
flatten_batch=flatten_batch,
|
|
||||||
softmax_in_single_precision=softmax_in_single_precision,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.use_qkv_parallel = use_qkv_parallel
|
|
||||||
if use_qkv_parallel:
|
|
||||||
self.qkv_proj = QKVParallelLinear(
|
|
||||||
hidden_size=embed_dim,
|
|
||||||
head_size=self.head_size,
|
|
||||||
total_num_heads=num_heads,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=add_prefix("qkv_proj", prefix),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.qkv_proj = ColumnParallelLinear(
|
|
||||||
input_size=embed_dim,
|
|
||||||
output_size=3 * projection_size,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=add_prefix("qkv_proj", prefix),
|
|
||||||
)
|
|
||||||
self.proj = RowParallelLinear(
|
|
||||||
input_size=embed_dim,
|
|
||||||
output_size=embed_dim,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=add_prefix("proj", prefix),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
cu_seqlens: Optional[torch.Tensor] = None,
|
|
||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
r"""
|
|
||||||
Args:
|
|
||||||
x: [b, s, embed_dim]
|
|
||||||
cu_seqlens: [b]
|
|
||||||
Returns:
|
|
||||||
[s, b, head * head_size]
|
|
||||||
"""
|
|
||||||
bsz, s, _ = x.shape
|
|
||||||
head = self.num_attention_heads_per_partition
|
|
||||||
if self.use_qkv_parallel:
|
|
||||||
# [b, s, embed_dim] --> [b, s, embed_dim]
|
|
||||||
qkv, _ = self.qkv_proj(x)
|
|
||||||
q, k, v = qkv.chunk(3, dim=-1)
|
|
||||||
|
|
||||||
# [b, s, embed_dim] --> [b * s, head, head_size]
|
|
||||||
q, k, v = [x.reshape(bsz * s, head, -1).contiguous() for x in (q, k, v)]
|
|
||||||
else:
|
|
||||||
# [b, s, embed_dim] --> [s, b, embed_dim]
|
|
||||||
x = rearrange(x, "b s ... -> s b ...")
|
|
||||||
# [s, b, embed_dim] --> [s, b, head * 3 * head_size]
|
|
||||||
qkv, _ = self.qkv_proj(x)
|
|
||||||
# [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
|
|
||||||
new_x_shape = qkv.size()[:-1] + (
|
|
||||||
head,
|
|
||||||
3 * self.hidden_size_per_attention_head,
|
|
||||||
)
|
|
||||||
qkv = qkv.view(*new_x_shape)
|
|
||||||
|
|
||||||
# [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
|
|
||||||
q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
|
|
||||||
|
|
||||||
# [s, b, head, head_size] --> [b, s, head, head_size]
|
|
||||||
q, k, v = [
|
|
||||||
rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
|
|
||||||
]
|
|
||||||
|
|
||||||
if position_embeddings is not None:
|
|
||||||
cos, sin = position_embeddings
|
|
||||||
original_shape = q.shape
|
|
||||||
# [total_tokens, head, head_size]
|
|
||||||
q = q.view(-1, head, self.head_size)
|
|
||||||
k = k.view(-1, head, self.head_size)
|
|
||||||
|
|
||||||
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
|
||||||
|
|
||||||
q = q.view(original_shape)
|
|
||||||
k = k.view(original_shape)
|
|
||||||
|
|
||||||
if self.use_qkv_parallel:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
# [b, s, head, head_size] --> [b * s, head, head_size]
|
|
||||||
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
|
|
||||||
|
|
||||||
output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask)
|
|
||||||
|
|
||||||
if self.use_qkv_parallel:
|
|
||||||
# [b * s, h, head_size] --> [b, s, h * head_size]
|
|
||||||
output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
|
|
||||||
|
|
||||||
# [b, s, h * head_size] --> [b, s, h * head_size]
|
|
||||||
output, _ = self.proj(output)
|
|
||||||
else:
|
|
||||||
# [b * s, h, head_size] --> [s, b, h * head_size]
|
|
||||||
context_layer = rearrange(
|
|
||||||
output, "(b s) h d -> s b (h d)", b=bsz, s=s
|
|
||||||
).contiguous()
|
|
||||||
|
|
||||||
# [s, b, h * head_size] --> [s, b, h * head_size]
|
|
||||||
output, _ = self.proj(context_layer)
|
|
||||||
|
|
||||||
# [s, b, h * head_size] --> [b, s, h * head_size]
|
|
||||||
output = output.view(bsz, s, -1)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class VisionSdpaAttention(nn.Module):
|
class VisionSdpaAttention(nn.Module):
|
||||||
@@ -189,16 +62,22 @@ class VisionSdpaAttention(nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
head_size: int,
|
head_dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
dropout: float = 0.0,
|
dropout: float = 0.0,
|
||||||
flatten_batch: bool = False,
|
flatten_batch: bool = False,
|
||||||
softmax_in_single_precision: bool = False,
|
softmax_in_single_precision: bool = False,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.head_size = head_size
|
self.head_size = head_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_kv_heads = num_kv_heads
|
||||||
self.flatten_batch = flatten_batch
|
self.flatten_batch = flatten_batch
|
||||||
self.softmax_in_single_precision = softmax_in_single_precision
|
self.softmax_in_single_precision = softmax_in_single_precision
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
|
self.scale = 1.0 / math.sqrt(self.head_size)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@lru_cache(maxsize=128)
|
@lru_cache(maxsize=128)
|
||||||
@@ -212,7 +91,7 @@ class VisionSdpaAttention(nn.Module):
|
|||||||
flatten_batch: whether to flatten batch dimension
|
flatten_batch: whether to flatten batch dimension
|
||||||
cu_seqlens: tuple of cumulative sequence lengths
|
cu_seqlens: tuple of cumulative sequence lengths
|
||||||
Returns:
|
Returns:
|
||||||
attention mask tensor
|
attention mask tensor of shape [b, 1, s, s] or [1, s, s]
|
||||||
"""
|
"""
|
||||||
if flatten_batch:
|
if flatten_batch:
|
||||||
mask = torch.zeros([1, s, s], dtype=torch.bool)
|
mask = torch.zeros([1, s, s], dtype=torch.bool)
|
||||||
@@ -241,7 +120,7 @@ class VisionSdpaAttention(nn.Module):
|
|||||||
flatten_batch: bool = False,
|
flatten_batch: bool = False,
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
r"""
|
r"""
|
||||||
Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
|
Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, s, s)`.
|
||||||
Args:
|
Args:
|
||||||
s: sequence length
|
s: sequence length
|
||||||
cu_seqlens: cumulative sequence lengths tensor. If not, returns an empty mask
|
cu_seqlens: cumulative sequence lengths tensor. If not, returns an empty mask
|
||||||
@@ -264,6 +143,7 @@ class VisionSdpaAttention(nn.Module):
|
|||||||
bsz: int,
|
bsz: int,
|
||||||
cu_seqlens: Optional[torch.Tensor] = None,
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -274,6 +154,8 @@ class VisionSdpaAttention(nn.Module):
|
|||||||
if self.flatten_batch:
|
if self.flatten_batch:
|
||||||
assert bsz == 1, "flatten_batch is True, bsz must be 1"
|
assert bsz == 1, "flatten_batch is True, bsz must be 1"
|
||||||
|
|
||||||
|
assert q.dim() == 3, q.shape
|
||||||
|
|
||||||
s = q.shape[0] // bsz
|
s = q.shape[0] // bsz
|
||||||
|
|
||||||
# [b, 1, s, s]
|
# [b, 1, s, s]
|
||||||
@@ -291,10 +173,10 @@ class VisionSdpaAttention(nn.Module):
|
|||||||
q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
|
q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
|
||||||
|
|
||||||
if self.softmax_in_single_precision:
|
if self.softmax_in_single_precision:
|
||||||
scale = self.head_size**-0.5
|
k = rearrange(k, "b h s d -> b h d s")
|
||||||
k_transposed = rearrange(k, "b h s d -> b h d s")
|
attn_weights = torch.matmul(q, k) * self.scale
|
||||||
attn_weights = torch.matmul(q, k_transposed) * scale
|
del k
|
||||||
del k, k_transposed
|
# masking
|
||||||
attention_mask = (~attention_mask) * torch.finfo(q.dtype).min
|
attention_mask = (~attention_mask) * torch.finfo(q.dtype).min
|
||||||
attn_weights = attn_weights + attention_mask
|
attn_weights = attn_weights + attention_mask
|
||||||
del attention_mask
|
del attention_mask
|
||||||
@@ -332,6 +214,7 @@ class VisionTritonAttention(nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -340,8 +223,8 @@ class VisionTritonAttention(nn.Module):
|
|||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
_bsz: int,
|
|
||||||
cu_seqlens: Optional[torch.Tensor],
|
cu_seqlens: Optional[torch.Tensor],
|
||||||
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -366,3 +249,247 @@ class VisionTritonAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class VisionFlash3Attention(nn.Module):
    r"""
    Variable-length FlashAttention-3 backend for cache-free vision attention.

    The constructor accepts and ignores arbitrary keyword arguments so that
    every backend can be built with the same call.

    Raises:
        RuntimeError: if the current platform is not CUDA.
    """

    def __init__(
        self,
        **kwargs,
    ):
        # Fail fast: the FA3 kernel is only imported on CUDA builds.
        if not _is_cuda:
            raise RuntimeError("VisionFlash3Attention is only available for cuda")
        super().__init__()

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            q, k, v: [b * s, h, head_size]
            cu_seqlens: cumulative sequence lengths. If ``None`` and ``bsz``
                is passed as a keyword argument, the flattened tokens are
                treated as ``bsz`` sequences of equal length.
            attention_mask: accepted for interface parity with the other
                backends; it is not forwarded to the kernel.
        Returns:
             [b * s, h, head_size]
        Raises:
            ValueError: if neither ``cu_seqlens`` nor ``bsz`` is provided.
        """
        if cu_seqlens is None:
            bsz = kwargs.get("bsz")
            if bsz is None:
                raise ValueError(
                    "VisionFlash3Attention requires `cu_seqlens` (or `bsz` to "
                    "derive uniform sequence boundaries)"
                )
            # q is laid out as "(b s) h d", so every sequence owns a
            # contiguous run of `seq_len` tokens.
            seq_len = q.shape[0] // bsz
            cu_seqlens = (
                torch.arange(bsz + 1, dtype=torch.int32, device=q.device) * seq_len
            )
        else:
            # Keep the offsets on the same device as the activations instead
            # of whatever the current default CUDA device happens to be.
            cu_seqlens = cu_seqlens.to(device=q.device, dtype=torch.int32)

        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
        # The kernel takes the maximum length as a Python scalar.
        max_seqlen = seq_lens.max().item()

        # Self-attention: queries and keys share the same sequence boundaries.
        output = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
        )

        return output
|
||||||
|
|
||||||
|
|
||||||
|
# Registry mapping a backend name (model-requested or set through the
# `mm_attention_backend` server argument) to its implementation class.
QKV_BACKEND_IMPL = {
    "triton_attn": VisionTritonAttention,
    "sdpa": VisionSdpaAttention,
    "fa3": VisionFlash3Attention,
}


class VisionAttention(nn.Module):
    r"""
        Multi-headed attention without any cache, mostly used for multimodal transformers.

    The attention kernel is chosen by name from ``QKV_BACKEND_IMPL``. A
    server-wide ``mm_attention_backend`` setting takes precedence over the
    ``qkv_backend`` requested by the model; if neither is given, ``"sdpa"``
    is used.

    Args:
        use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
        qkv_backend (str, optional): name of the attention backend requested
            by the model (a key of ``QKV_BACKEND_IMPL``).
        softmax_in_single_precision (bool, default to False):
            if ``True``, the softmax will be performed in single-precision
            Otherwise, it will be performed in half-precision
        proj_bias (bool, default to True): whether the output projection
            has a bias term.

    Raises:
        ValueError: if the resolved backend name is not registered.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        use_qkv_parallel: bool,
        qkv_backend: Optional[str] = None,
        quant_config: Optional[QuantizationConfig] = None,
        dropout: float = 0.0,
        softmax_in_single_precision: bool = False,
        flatten_batch: bool = False,
        prefix: str = "",
        proj_bias: bool = True,
        **kwargs,
    ):
        super().__init__()
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.dropout = dropout
        self.head_size = embed_dim // num_heads
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, world_size
        )
        # KV heads currently mirror the query heads (no grouped-query attention).
        self.num_attention_kv_heads_per_partition = dist_utils.divide(
            num_heads, world_size
        )

        self.q_size = self.num_attention_heads_per_partition * self.head_size
        self.kv_size = self.num_attention_kv_heads_per_partition * self.head_size

        # `.get` keeps construction working when the server args were never
        # populated with this key (treated the same as "not set").
        server_backend = global_server_args_dict.get("mm_attention_backend")
        if server_backend is None:
            if qkv_backend is None:
                qkv_backend = "sdpa"
            info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
        else:
            # The server-wide choice overrides the model's own request.
            qkv_backend = server_backend

        info_once(f"Using {qkv_backend} as multimodal attention backend.")

        if qkv_backend not in QKV_BACKEND_IMPL:
            raise ValueError(
                f"Unknown multimodal attention backend {qkv_backend!r}; "
                f"expected one of {sorted(QKV_BACKEND_IMPL)}"
            )

        # Every backend accepts **kwargs, so the full option set can be
        # passed regardless of which options a given backend consumes.
        self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
            head_dim=self.head_size,
            num_heads=self.num_attention_heads_per_partition,
            num_kv_heads=self.num_attention_kv_heads_per_partition,
            dropout=dropout,
            flatten_batch=flatten_batch,
            softmax_in_single_precision=softmax_in_single_precision,
        )

        self.use_qkv_parallel = use_qkv_parallel
        if use_qkv_parallel:
            self.qkv_proj = QKVParallelLinear(
                hidden_size=embed_dim,
                head_size=self.head_size,
                total_num_heads=num_heads,
                total_num_kv_heads=num_heads,
                quant_config=quant_config,
                prefix=add_prefix("qkv_proj", prefix),
            )
        else:
            self.qkv_proj = ColumnParallelLinear(
                input_size=embed_dim,
                output_size=3 * projection_size,
                quant_config=quant_config,
                prefix=add_prefix("qkv_proj", prefix),
            )
        self.proj = RowParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=proj_bias,
            quant_config=quant_config,
            prefix=add_prefix("proj", prefix),
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            x: [b, s, embed_dim], or [s, embed_dim] (treated as b == 1)
            cu_seqlens: [b]
            position_embeddings: optional ``(cos, sin)`` pair applied to q and k
            attention_mask: optional mask forwarded to the backend
        Returns:
             [b, s, embed_dim]
        """
        if x.dim() == 2:
            x = x.unsqueeze(0)
        assert x.dim() == 3, x.shape
        bsz, s, _ = x.shape
        head = self.num_attention_heads_per_partition
        kv_head = self.num_attention_kv_heads_per_partition
        if self.use_qkv_parallel:
            # [b, s, embed_dim] --> [b, s, embed_dim]
            qkv, _ = self.qkv_proj(x)

            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

            # [b, s, embed_dim] --> [b * s, head, head_size]
            q = q.reshape(bsz * s, head, -1).contiguous()
            k = k.reshape(bsz * s, kv_head, -1).contiguous()
            v = v.reshape(bsz * s, kv_head, -1).contiguous()
        else:
            # [b, s, embed_dim] --> [s, b, embed_dim]
            x = rearrange(x, "b s ... -> s b ...")
            # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
            qkv, _ = self.qkv_proj(x)

            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
            new_x_shape = qkv.size()[:-1] + (
                head,
                3 * self.hidden_size_per_attention_head,
            )
            qkv = qkv.view(*new_x_shape)

            # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
            # [s, b, head, head_size] --> [b, s, head, head_size]
            q, k, v = [
                rearrange(t, "s b ... -> b s ...").contiguous() for t in (q, k, v)
            ]

        if position_embeddings is not None:
            cos, sin = position_embeddings
            # q and k may have different head counts, so remember both shapes.
            q_shape, k_shape = q.shape, k.shape
            # [total_tokens, head, head_size]
            q = q.view(-1, head, self.head_size)
            k = k.view(-1, kv_head, self.head_size)

            q, k = apply_rotary_pos_emb(q, k, cos, sin)

            q = q.view(q_shape)
            k = k.view(k_shape)

        # The non-QKV-parallel path still carries a separate batch axis here.
        if q.dim() == 4:
            # [b, s, head, head_size] --> [b * s, head, head_size]
            q = rearrange(q, "b s ... -> (b s) ...")
        if k.dim() == 4:
            # [b, s, head, head_size] --> [b * s, head, head_size]
            k = rearrange(k, "b s ... -> (b s) ...")
        if v.dim() == 4:
            # [b, s, head, head_size] --> [b * s, head, head_size]
            v = rearrange(v, "b s ... -> (b s) ...")

        assert q.dim() == 3, q.dim()
        assert k.dim() == 3, k.dim()
        assert v.dim() == 3, v.dim()

        output = self.qkv_backend.forward(
            q=q,
            k=k,
            v=v,
            bsz=bsz,
            cu_seqlens=cu_seqlens,
            attention_mask=attention_mask,
        )

        assert output.dim() == 3, output.shape

        if self.use_qkv_parallel:
            # [b * s, h, head_size] --> [b, s, h * head_size]
            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)

            # [b, s, h * head_size] --> [b, s, h * head_size]
            output, _ = self.proj(output)
        else:
            # [b * s, h, head_size] --> [s, b, h * head_size]
            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=bsz, s=s
            ).contiguous()

            # [s, b, h * head_size] --> [s, b, h * head_size]
            output, _ = self.proj(context_layer)

            # [s, b, h * head_size] --> [b, s, h * head_size]
            # A real transpose is required: a plain `.view(bsz, s, -1)` would
            # only reinterpret memory and mix tokens across the batch when
            # bsz > 1 (the two are identical for bsz == 1).
            output = rearrange(output, "s b ... -> b s ...").contiguous()

        return output
|
||||||
|
|||||||
@@ -173,6 +173,7 @@ class ModelRunner:
|
|||||||
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
|
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
|
||||||
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
|
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
|
||||||
"use_mla_backend": self.use_mla_backend,
|
"use_mla_backend": self.use_mla_backend,
|
||||||
|
"mm_attention_backend": server_args.mm_attention_backend,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -151,20 +151,20 @@ class CLIPEncoderLayer(nn.Module):
|
|||||||
self.layer_norm1 = norm_layer(config.hidden_size)
|
self.layer_norm1 = norm_layer(config.hidden_size)
|
||||||
self.layer_norm2 = norm_layer(config.hidden_size)
|
self.layer_norm2 = norm_layer(config.hidden_size)
|
||||||
if attn_implementation == "sdpa":
|
if attn_implementation == "sdpa":
|
||||||
use_context_forward = False
|
qkv_backend = "sdpa"
|
||||||
softmax_in_single_precision = False
|
softmax_in_single_precision = False
|
||||||
elif attn_implementation == "flash_attention_2":
|
elif attn_implementation == "flash_attention_2":
|
||||||
|
qkv_backend = "triton_attn"
|
||||||
softmax_in_single_precision = False
|
softmax_in_single_precision = False
|
||||||
use_context_forward = True
|
|
||||||
elif attn_implementation == "eager":
|
elif attn_implementation == "eager":
|
||||||
|
qkv_backend = "sdpa"
|
||||||
softmax_in_single_precision = True
|
softmax_in_single_precision = True
|
||||||
use_context_forward = False
|
|
||||||
self.self_attn = VisionAttention(
|
self.self_attn = VisionAttention(
|
||||||
embed_dim=config.hidden_size,
|
embed_dim=config.hidden_size,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
projection_size=config.hidden_size,
|
projection_size=config.hidden_size,
|
||||||
use_qkv_parallel=True,
|
use_qkv_parallel=True,
|
||||||
use_context_forward=use_context_forward,
|
qkv_backend=qkv_backend,
|
||||||
softmax_in_single_precision=softmax_in_single_precision,
|
softmax_in_single_precision=softmax_in_single_precision,
|
||||||
flatten_batch=True,
|
flatten_batch=True,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
|||||||
@@ -532,7 +532,7 @@ class VisionTransformerBlock(nn.Module):
|
|||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
projection_size=dim,
|
projection_size=dim,
|
||||||
use_qkv_parallel=True,
|
use_qkv_parallel=True,
|
||||||
use_context_forward=False,
|
qkv_backend="sdpa",
|
||||||
softmax_in_single_precision=False,
|
softmax_in_single_precision=False,
|
||||||
dropout=attn_drop,
|
dropout=attn_drop,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -281,7 +281,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
|||||||
pixel_values = torch.stack(
|
pixel_values = torch.stack(
|
||||||
flatten_nested_list([item.pixel_values for item in items]), dim=0
|
flatten_nested_list([item.pixel_values for item in items]), dim=0
|
||||||
)
|
)
|
||||||
pixel_values = pixel_values.to("cuda")
|
pixel_values = pixel_values.to(device=self.vision_tower.device)
|
||||||
pixel_values = pixel_values.to(dtype=self.language_model.dtype())
|
pixel_values = pixel_values.to(dtype=self.language_model.dtype())
|
||||||
|
|
||||||
vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
|
vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ class Idefics2EncoderLayer(nn.Module):
|
|||||||
use_qkv_parallel=True,
|
use_qkv_parallel=True,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
dropout=config.attention_dropout,
|
dropout=config.attention_dropout,
|
||||||
use_context_forward=False,
|
qkv_backend="sdpa",
|
||||||
softmax_in_single_precision=True,
|
softmax_in_single_precision=True,
|
||||||
flatten_batch=False,
|
flatten_batch=False,
|
||||||
prefix=add_prefix("self_attn", prefix),
|
prefix=add_prefix("self_attn", prefix),
|
||||||
|
|||||||
@@ -203,7 +203,7 @@ class MllamaVisionEncoderLayer(nn.Module):
|
|||||||
use_qkv_parallel=True,
|
use_qkv_parallel=True,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
dropout=0.0,
|
dropout=0.0,
|
||||||
use_context_forward=False,
|
qkv_backend="sdpa",
|
||||||
softmax_in_single_precision=False,
|
softmax_in_single_precision=False,
|
||||||
flatten_batch=False,
|
flatten_batch=False,
|
||||||
prefix=add_prefix("self_attn", prefix),
|
prefix=add_prefix("self_attn", prefix),
|
||||||
|
|||||||
@@ -125,16 +125,20 @@ class Qwen2_5_VisionBlock(nn.Module):
|
|||||||
self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
|
self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
|
||||||
self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
|
self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
|
||||||
if attn_implementation == "sdpa":
|
if attn_implementation == "sdpa":
|
||||||
use_context_forward = False
|
|
||||||
softmax_in_single_precision = False
|
softmax_in_single_precision = False
|
||||||
|
qkv_backend = "sdpa"
|
||||||
flatten_batch = True
|
flatten_batch = True
|
||||||
elif attn_implementation == "flash_attention_2":
|
elif attn_implementation == "flash_attention_2":
|
||||||
softmax_in_single_precision = False
|
softmax_in_single_precision = False
|
||||||
use_context_forward = True
|
qkv_backend = "triton_attn"
|
||||||
flatten_batch = True
|
flatten_batch = True
|
||||||
elif attn_implementation == "eager":
|
elif attn_implementation == "eager":
|
||||||
softmax_in_single_precision = True
|
softmax_in_single_precision = True
|
||||||
use_context_forward = False
|
qkv_backend = "sdpa"
|
||||||
|
flatten_batch = True
|
||||||
|
elif attn_implementation == "flash_attention_3":
|
||||||
|
softmax_in_single_precision = False
|
||||||
|
qkv_backend = "fa3"
|
||||||
flatten_batch = True
|
flatten_batch = True
|
||||||
|
|
||||||
self.attn = VisionAttention(
|
self.attn = VisionAttention(
|
||||||
@@ -142,7 +146,7 @@ class Qwen2_5_VisionBlock(nn.Module):
|
|||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
projection_size=dim,
|
projection_size=dim,
|
||||||
use_qkv_parallel=True,
|
use_qkv_parallel=True,
|
||||||
use_context_forward=use_context_forward,
|
qkv_backend=qkv_backend,
|
||||||
softmax_in_single_precision=softmax_in_single_precision,
|
softmax_in_single_precision=softmax_in_single_precision,
|
||||||
flatten_batch=flatten_batch,
|
flatten_batch=flatten_batch,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
|||||||
@@ -139,21 +139,21 @@ class Qwen2VisionBlock(nn.Module):
|
|||||||
self.norm2 = norm_layer(dim)
|
self.norm2 = norm_layer(dim)
|
||||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
if attn_implementation == "sdpa":
|
if attn_implementation == "sdpa":
|
||||||
use_context_forward = False
|
qkv_backend = "sdpa"
|
||||||
softmax_in_single_precision = False
|
softmax_in_single_precision = False
|
||||||
elif attn_implementation == "flash_attention_2":
|
elif attn_implementation == "flash_attention_2":
|
||||||
|
qkv_backend = "triton_attn"
|
||||||
softmax_in_single_precision = False
|
softmax_in_single_precision = False
|
||||||
use_context_forward = True
|
|
||||||
elif attn_implementation == "eager":
|
elif attn_implementation == "eager":
|
||||||
|
qkv_backend = "sdpa"
|
||||||
softmax_in_single_precision = True
|
softmax_in_single_precision = True
|
||||||
use_context_forward = False
|
|
||||||
|
|
||||||
self.attn = VisionAttention(
|
self.attn = VisionAttention(
|
||||||
embed_dim=dim,
|
embed_dim=dim,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
projection_size=dim,
|
projection_size=dim,
|
||||||
use_qkv_parallel=True,
|
use_qkv_parallel=True,
|
||||||
use_context_forward=use_context_forward,
|
qkv_backend=qkv_backend,
|
||||||
softmax_in_single_precision=softmax_in_single_precision,
|
softmax_in_single_precision=softmax_in_single_precision,
|
||||||
flatten_batch=True,
|
flatten_batch=True,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
|||||||
@@ -187,6 +187,7 @@ class ServerArgs:
|
|||||||
n_share_experts_fusion: int = 0
|
n_share_experts_fusion: int = 0
|
||||||
disable_chunked_prefix_cache: bool = False
|
disable_chunked_prefix_cache: bool = False
|
||||||
disable_fast_image_processor: bool = False
|
disable_fast_image_processor: bool = False
|
||||||
|
mm_attention_backend: Optional[str] = None
|
||||||
|
|
||||||
# Debug tensor dumps
|
# Debug tensor dumps
|
||||||
debug_tensor_dump_output_folder: Optional[str] = None
|
debug_tensor_dump_output_folder: Optional[str] = None
|
||||||
@@ -1265,6 +1266,14 @@ class ServerArgs:
|
|||||||
help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
|
help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--mm-attention-backend",
|
||||||
|
type=str,
|
||||||
|
choices=["sdpa", "fa3", "triton_attn"],
|
||||||
|
default=ServerArgs.mm_attention_backend,
|
||||||
|
help="Set multimodal attention backend.",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
args.tp_size = args.tensor_parallel_size
|
args.tp_size = args.tensor_parallel_size
|
||||||
|
|||||||
Reference in New Issue
Block a user