[Fix] Address remaining issues of supporting MiniCPMV (#2977)
@@ -166,6 +166,12 @@ def _fwd_kernel(
def context_attention_fwd(
    q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
):
    """
    q, k, v: [b * s, head, head_dim]
    b_start_loc: [b]
    b_seq_len: [b]
    out: [b * s, head, head_dim]
    """
    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
        BLOCK = 128
    else:
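For reference, a minimal sketch of how this entry point can be driven. The tensor shapes follow the docstring above; the sequence lengths, head count, and head_dim are made-up values, and the import of `context_attention_fwd` from the module patched here is assumed.

```python
import torch

# Two packed sequences of lengths 3 and 5; 8 heads, head_dim 64 (illustrative).
b_seq_len = torch.tensor([3, 5], device="cuda", dtype=torch.int32)
b_start_loc = torch.tensor([0, 3], device="cuda", dtype=torch.int32)
total_tokens = int(b_seq_len.sum())

q = torch.randn(total_tokens, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
o = torch.empty_like(q)

# is_causal=False requests full (bidirectional) attention within each sequence,
# which is what the ViT attention path below relies on.
context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len,
                      max_input_len=int(b_seq_len.max()), is_causal=False)
```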
@@ -4,6 +4,7 @@ from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from sglang.srt.distributed import parallel_state
@@ -63,7 +64,20 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.T
class VisionAttention(nn.Module):
-    """Multi-headed attention without any cache, mostly used for ViT."""
    r"""
    Multi-headed attention without any cache, mostly used for ViT.

    Args:
        use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
        use_context_forward (bool, defaults to True):
            If ``True``, a flash_attn-style context attention is applied;
            otherwise, full-sequence attention is applied.
        use_full_precision_softmax (bool, defaults to False):
            If ``True``, the softmax is computed in full precision;
            otherwise, it is computed in half precision.
    """

    def __init__(
        self,
@@ -72,25 +86,39 @@ class VisionAttention(nn.Module):
        projection_size: int,
        use_qkv_parallel: bool,
        quant_config: Optional[QuantizationConfig] = None,
        dropout: float = 0.0,
        use_context_forward: bool = True,
        use_full_precision_softmax: bool = False,
        flatten_batch: bool = False,
        prefix: str = "",
    ):
        super().__init__()
        self.use_context_forward = use_context_forward
        world_size = parallel_state.get_tensor_model_parallel_world_size()

        self.dropout = dropout
        self.head_size = embed_dim // num_heads
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, world_size
        )
        # self.tp_size = get_tensor_model_parallel_world_size()
        # num_heads = self.num_heads_per_partition

        if self.use_context_forward:
            self.qkv_backend = VisionTritonAttention()
        else:
            self.qkv_backend = VisionSdpaAttention(
                head_size=self.head_size,
                dropout=dropout,
                flatten_batch=flatten_batch,
                use_full_precision_softmax=use_full_precision_softmax,
            )

        self.use_qkv_parallel = use_qkv_parallel
        if use_qkv_parallel:
            self.head_dim = embed_dim // num_heads
            self.qkv_proj = QKVParallelLinear(
                hidden_size=embed_dim,
-                head_size=self.head_dim,
                head_size=self.head_size,
                total_num_heads=num_heads,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv_proj",
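A hedged construction sketch for the refactored class. The leading `embed_dim` / `num_heads` parameters are not shown in this hunk and are assumed from how the `__init__` body uses them; the sizes are made-up ViT-like values, and the parallel linear layers require an initialized sglang tensor-parallel group.

```python
# Illustrative only: parameter values are invented, not from any real model config.
attn = VisionAttention(
    embed_dim=1152,
    num_heads=16,
    projection_size=1152,
    use_qkv_parallel=True,
    use_context_forward=True,          # pick the Triton context-attention backend
    use_full_precision_softmax=False,  # only consulted by the SDPA backend
    flatten_batch=False,
)
```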
@@ -114,12 +142,15 @@ class VisionAttention(nn.Module):
        x: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        rotary_pos_emb: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            x: [b, s, embed_dim]
            cu_seqlens: [b]
        Returns:
            [s, b, num_heads * head]
        """
        Input shape: [b, s, embed_dim]
        Output shape: [s, b, num_heads * head_size]
        """

        bsz, s, _ = x.shape
        if self.use_qkv_parallel:
            # [b, s, embed_dim] --> [b, s, embed_dim]
@@ -136,19 +167,19 @@ class VisionAttention(nn.Module):
        else:
            # [b, s, embed_dim] --> [s, b, embed_dim]
            x = rearrange(x, "b s ... -> s b ...")
-            # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
            # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
            qkv, _ = self.qkv_proj(x)
-            # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
            new_x_shape = qkv.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            qkv = qkv.view(*new_x_shape)

-            # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
            # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)

-            # [s, b, head, head_dim] --> [b, s, head, head_dim]
            # [s, b, head, head_size] --> [b, s, head, head_size]
            q, k, v = [
                rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
            ]
@@ -160,45 +191,217 @@ class VisionAttention(nn.Module):
        if self.use_qkv_parallel:
            pass
        else:
-            # [b, s, head, head_dim] --> [b * s, head, head_dim]
            # [b, s, head, head_size] --> [b * s, head, head_size]
            q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]

        # [b * s, num_heads, head_size]
        output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask)

        if self.use_qkv_parallel:
            # [b * s, h, head_size] --> [b, s, h * head_size]
            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)

            # [b, s, h * head_size] --> [b, s, h * head_size]
            output, _ = self.proj(output)
        else:
            # [b * s, h, head_size] --> [s, b, h * head_size]
            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=bsz, s=s
            ).contiguous()

            # [s, b, h * head_size] --> [s, b, h * head_size]
            output, _ = self.proj(context_layer)

            # [s, b, h * head_size] --> [b, s, h * head_size]
            output = output.view(bsz, s, -1)

        return output
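Continuing the construction sketch above, a hedged example of driving the refactored forward path. It assumes the `attn` instance from the earlier sketch, a CUDA device, and an initialized tensor-parallel group; the patch count is invented, and `cu_seqlens` is given as cumulative boundaries with a leading 0, matching how the backends consume it.

```python
import torch

# One image with 64 patches (illustrative), embed_dim matching the sketch above.
x = torch.randn(1, 64, 1152, device="cuda", dtype=torch.float16)   # [b, s, embed_dim]
cu_seqlens = torch.tensor([0, 64], device="cuda", dtype=torch.int32)

# `attn` is the VisionAttention instance from the construction sketch above.
out = attn(x, cu_seqlens=cu_seqlens)
# On this path the result is reshaped back to a [b, s, ...] layout before returning.
```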
class VisionSdpaAttention(nn.Module):
    r"""
    Scaled Dot Product Attention inner product

    """

    # TODO: Should the cache be released after use?
    _mask_cache = {}

    def __init__(
        self,
        head_size: int,
        dropout: float = 0.0,
        flatten_batch: bool = False,
        use_full_precision_softmax: bool = False,
    ):
        super().__init__()
        self.head_size = head_size
        self.flatten_batch = flatten_batch
        self.use_full_precision_softmax = use_full_precision_softmax
        self.dropout = dropout
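One design detail worth noting: `_mask_cache` is a class attribute, so every `VisionSdpaAttention` instance (for example, one per ViT layer) shares the same mask cache instead of rebuilding identical masks per layer. A tiny illustration of that sharing, with a made-up `head_size`:

```python
a = VisionSdpaAttention(head_size=72)
b = VisionSdpaAttention(head_size=72)

# Both instances resolve `_mask_cache` to the same class-level dict, so a mask
# cached while running layer `a` is reused when layer `b` asks for the same
# (s, bsz, flatten_batch, cu_seqlens) key.
assert a._mask_cache is b._mask_cache
```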
    def generate_patch_attention_mask(
        self,
        s: int,
        bsz: int,
        device,
        cu_seqlens: Optional[torch.Tensor],
        flatten_batch: bool = False,
        dtype=torch.bfloat16,
    ) -> torch.Tensor:
        r"""
        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.

        When `flatten_batch` is True:
            - All sequences in the batch are flattened into a single dimension
            - `s` represents the total number of tokens across all sequences in the batch
            - Returns a unified mask of shape `(1, 1, s, s)`

        When `flatten_batch` is False:
            - Each sequence has its own attention mask
            - `s` represents the maximum sequence length in the batch
            - Returns separate masks of shape `(b, 1, s, s)`

        Args:
            flatten_batch (bool):
                If True, treats all sequences in the batch as a single flattened sequence
                If False, generates separate masks for each sequence

        Returns:
            Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
        """

        cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))

        if cache_key in VisionSdpaAttention._mask_cache:
            cached_mask = VisionSdpaAttention._mask_cache[cache_key]
            # print(f"cache hit for key: {cache_key}")
            return cached_mask.to(device=device, dtype=dtype)

        if cu_seqlens is None:
            raise ValueError("Internal Error: cu_seqlens cannot be None")

        if flatten_batch:
            mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
            for i in range(1, len(cu_seqlens)):
                start = cu_seqlens[i - 1]
                end = cu_seqlens[i]
                mask[
                    ...,
                    start:end,
                    start:end,
                ] = True
        else:
            # [1, 1, 1, s]
            row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
            # [1, 1, s, 1]
            col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
            # [b, 1, 1, 1]
            seq_lens = (
                (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
            )

            mask = (row_indices < seq_lens) & (col_indices < seq_lens)

        # Convert the boolean "allowed" mask to additive form (True -> 0, False -> -inf)
        mask = (~mask).to(dtype) * torch.finfo(dtype).min

        VisionSdpaAttention._mask_cache[cache_key] = mask

        return mask
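To make the two modes concrete, here is a small self-contained sketch of the mask produced for two sequences of lengths 2 and 3 (all values invented for illustration; assumes the class is importable from this module):

```python
import torch

sdpa = VisionSdpaAttention(head_size=64)  # head_size does not affect the mask itself
cu_seqlens = torch.tensor([0, 2, 5], dtype=torch.int32)

# flatten_batch=True: one block-diagonal mask over the 5 packed tokens; tokens 0-1
# and tokens 2-4 can only attend within their own block. It comes back as
# [1, 5, 5], which broadcasts like [1, 1, 5, 5] against [b, h, s, s] scores.
flat = sdpa.generate_patch_attention_mask(
    s=5, bsz=1, device="cpu", cu_seqlens=cu_seqlens,
    flatten_batch=True, dtype=torch.float32,
)

# flatten_batch=False: a [2, 1, 5, 5] mask where s=5 is the padded length and each
# sequence only attends to its first 2 (resp. 3) positions.
padded = sdpa.generate_patch_attention_mask(
    s=5, bsz=2, device="cpu", cu_seqlens=cu_seqlens,
    flatten_batch=False, dtype=torch.float32,
)

print(flat.shape, padded.shape)  # torch.Size([1, 5, 5]) torch.Size([2, 1, 5, 5])
```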
    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        bsz: int,
        cu_seqlens: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            cu_seqlens: [b]
        Returns:
            [b * s, h, head_size]
        """

        s = q.shape[0] // bsz

        # [b, 1, s, s]
        if attention_mask is None:
            attention_mask = self.generate_patch_attention_mask(
                s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
            )
        q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
        # [b, 1, s]
        if self.use_full_precision_softmax:
            scale = self.head_size**-0.5
            k_transposed = rearrange(k, "b h s d -> b h d s")
            attn_weights = torch.matmul(q, k_transposed) * scale
            del k, k_transposed
            attn_weights = attn_weights + attention_mask
            del attention_mask
            # full-precision
            attn_weights = nn.functional.softmax(
                attn_weights, dim=-1, dtype=torch.float32
            ).to(q.dtype)
            attn_weights = nn.functional.dropout(
                attn_weights, p=self.dropout, training=False
            )
            output = torch.matmul(attn_weights, v)
            del attn_weights, v
        else:
            # SDPA
            # [b, h, s, head_size]
            output = F.scaled_dot_product_attention(
                q, k, v, attention_mask, dropout_p=self.dropout
            )

        # [b, h, s, head_size] --> [b * s, h, head_size]
        output = rearrange(output, "b h s d -> (b s) h d")

        return output
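The two branches are meant to compute the same attention and differ only in where the softmax is evaluated. A CPU-only sanity sketch with dropout left at 0 (all sizes and lengths invented; assumes the class is importable):

```python
import torch

torch.manual_seed(0)
bsz, s, heads, head_size = 2, 5, 4, 32
q = torch.randn(bsz * s, heads, head_size)
k = torch.randn_like(q)
v = torch.randn_like(q)
# Padded layout: two sequences of lengths 3 and 5, max length s=5.
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)

sdpa_half = VisionSdpaAttention(head_size=head_size, use_full_precision_softmax=False)
sdpa_full = VisionSdpaAttention(head_size=head_size, use_full_precision_softmax=True)

out_a = sdpa_half(q, k, v, bsz, cu_seqlens=cu_seqlens)
out_b = sdpa_full(q, k, v, bsz, cu_seqlens=cu_seqlens)

# In fp32 the only difference is the softmax precision, so the gap should be tiny.
print((out_a - out_b).abs().max())
```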
class VisionTritonAttention(nn.Module):
    """
    Triton-implemented attention without a causal mask
    """

    def __init__(
        self,
    ):
        super().__init__()

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        _bsz: int,
        cu_seqlens: Optional[torch.Tensor],
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            cu_seqlens: [b]
        Returns:
            [b * s, h, head_size]
        """

        # [b * s, head, head_size]
        output = torch.empty_like(q)

-        seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
        max_seqlen = seq_lens.max().item()

        context_attention_fwd(
            q,
            k,
            v,
            output,
            cu_seqlens.cuda(),
-            seq_lens,
            seq_lens.cuda(),
            max_seqlen,
            is_causal=False,
        )
-        if self.use_qkv_parallel:
-
-            # [b * s, head, head_dim] --> [b, s, head * head_dim]
-            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
-
-            # [b, s, head, head_dim] --> [b, s, head, head_dim]
-            output, _ = self.proj(output)
-        else:
-            # [b * s, head, head_dim] --> [b, s, head, head_dim]
-            context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-
-            # [s, b, num_heads * head_size]
-            context_layer = rearrange(
-                context_layer, "b s h d -> s b (h d)"
-            ).contiguous()
-
-            # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size]
-            output, _ = self.proj(context_layer)
-
-            output = output.view(bsz, s, -1)

        return output
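For completeness, a small sketch of exercising the new Triton backend on its own. The patch counts and head sizes are invented, CUDA is required because the backend launches a Triton kernel, and the class is assumed to be importable from this module.

```python
import torch

# Packed ViT tokens: two images with 9 and 16 patches (illustrative), 8 heads, head_size 64.
cu_seqlens = torch.tensor([0, 9, 25], device="cuda", dtype=torch.int32)
q = torch.randn(25, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

backend = VisionTritonAttention()
# _bsz is unused by this backend; cu_seqlens supplies the per-image boundaries.
out = backend(q, k, v, 2, cu_seqlens)  # [25, 8, 64], same packed layout as q
```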