# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import importlib
from functools import lru_cache

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84
def fp8_mqa_logits_torch(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: Weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    kv, scale = kv
    seq_len_kv = kv.shape[0]
    k = kv.to(torch.bfloat16)
    q = q.to(torch.bfloat16)

    mask_lo = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
    )
    mask_hi = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
    )
    mask = mask_lo & mask_hi

    score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    return logits
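

# Hedged usage sketch for the reference kernel above. Shapes, values, and the
# helper name `_example_fp8_mqa_logits_torch` are illustrative assumptions,
# not part of vLLM's API; the helper is never called at import time.
def _example_fp8_mqa_logits_torch() -> None:
    M, N, H, D = 4, 16, 2, 64
    q = torch.randn(M, H, D, device="cuda").to(torch.float8_e4m3fn)
    k_fp8 = torch.randn(N, D, device="cuda").to(torch.float8_e4m3fn)
    k_scales = torch.ones(N, device="cuda", dtype=torch.float32)
    weights = torch.ones(M, H, device="cuda", dtype=torch.float32)
    # Query position i may attend to keys in [cu_seqlen_ks[i], cu_seqlen_ke[i]).
    ks = torch.zeros(M, dtype=torch.int32, device="cuda")
    ke = torch.full((M,), N, dtype=torch.int32, device="cuda")
    logits = fp8_mqa_logits_torch(q, (k_fp8, k_scales), weights, ks, ke)
    assert logits.shape == (M, N)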


def rocm_fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: Weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """

    # TODO(ganyi): Temporary workaround; remove the module check and the
    # reference path once aiter merges this kernel into main.
    @lru_cache
    def has_mqa_logits_module():
        return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None

    if rocm_aiter_ops.is_enabled() and has_mqa_logits_module():
        from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits

        kv, scale = kv
        return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
    else:
        return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
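

# Hedged dispatch sketch (shapes and the helper name are illustrative
# assumptions): per-position [ks, ke) windows can encode a causal mask, and
# `rocm_fp8_mqa_logits` falls back to the torch reference when the aiter
# Triton kernel is unavailable.
def _example_rocm_fp8_mqa_logits() -> None:
    M, H, D = 8, 2, 64
    q = torch.randn(M, H, D, device="cuda").to(torch.float8_e4m3fn)
    k_fp8 = torch.randn(M, D, device="cuda").to(torch.float8_e4m3fn)
    k_scales = torch.ones(M, device="cuda", dtype=torch.float32)
    weights = torch.ones(M, H, device="cuda", dtype=torch.float32)
    ks = torch.zeros(M, dtype=torch.int32, device="cuda")
    # Query i sees keys [0, i]; the exclusive end index is therefore i + 1.
    ke = torch.arange(1, M + 1, dtype=torch.int32, device="cuda")
    logits = rocm_fp8_mqa_logits(q, (k_fp8, k_scales), weights, ks, ke)
    assert logits.shape == (M, M)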


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
def fp8_paged_mqa_logits_torch(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Torch reference for FP8 MQA logits over a paged KV-cache.

    Args:
        q: Query tensor of shape [B, next_n, H, D].
        kv_cache: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block, pos) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context
            length for each batch element.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps
            logical block indices to physical blocks in the paged cache.
        max_model_len: Maximum sequence length used to size the logits output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """
    from vllm.utils.math_utils import cdiv

    fp8_dtype = current_platform.fp8_dtype()
    batch_size, next_n, _, dim = q.size()
    # Split the packed layout: D FP8 bytes followed by 4 scale bytes.
    kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
    scale = scale.contiguous().view(torch.float)
    q = q.float()
    kv_cache = kv_cache.view(fp8_dtype).float() * scale
    num_block, block_size, _, dim = kv_cache.size()
    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float("-inf"),
        device=q.device,
        dtype=torch.float32,
    )
    context_lens = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens[i]
        q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
        weight_slice = (
            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
        )
        for block_rk in range(cdiv(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            qx, kx = q[i], kv_cache[block_idx]
            k_offsets = torch.arange(
                block_rk * block_size, (block_rk + 1) * block_size, device="cuda"
            )
            mask = (k_offsets[None, :] < context_len) & (
                k_offsets[None, :] <= q_offsets[:, None]
            )
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                    logits.dtype
                ),
                float("-inf"),
            )
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[
                i * next_n : (i + 1) * next_n,
                block_rk * block_size : (block_rk + 1) * block_size,
            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
    return logits
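

# Hedged packing sketch (all shapes and the helper name are illustrative
# assumptions): build the [num_blocks, block_size, 1, D+4] uint8 cache by
# concatenating D FP8 bytes with the 4 bytes of a float32 dequant scale,
# then score against it with the torch reference above.
def _example_fp8_paged_mqa_logits_torch() -> None:
    fp8_dtype = current_platform.fp8_dtype()
    batch_size, next_n, heads, dim = 2, 1, 4, 64
    block_size, num_blocks, max_model_len = 16, 8, 128
    k = torch.randn(num_blocks, block_size, 1, dim, device="cuda")
    scales = torch.ones(num_blocks, block_size, 1, 1, device="cuda")
    kv_cache = torch.cat(
        [k.to(fp8_dtype).view(torch.uint8), scales.view(torch.uint8)], dim=-1
    )
    q = torch.randn(batch_size, next_n, heads, dim, device="cuda").to(fp8_dtype)
    weights = torch.ones(
        batch_size * next_n, heads, device="cuda", dtype=torch.float32
    )
    context_lens = torch.full((batch_size,), 32, dtype=torch.int32, device="cuda")
    # Give each batch element its own run of physical blocks.
    block_tables = torch.arange(
        batch_size * 4, dtype=torch.int32, device="cuda"
    ).view(batch_size, 4)
    logits = fp8_paged_mqa_logits_torch(
        q, kv_cache, weights, context_lens, block_tables, max_model_len
    )
    assert logits.shape == (batch_size * next_n, max_model_len)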


def rocm_fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute FP8 MQA logits using a paged KV-cache.

    Args:
        q_fp8: Query tensor of shape [B, next_n, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block, pos) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context
            length for each batch element.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps
            logical block indices to physical blocks in the paged cache.
        schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
            used to distribute work across SMs.
        max_model_len: Maximum sequence length used to size the logits output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """

    if rocm_aiter_ops.is_enabled():
        from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits_stage1

        batch_size, next_n, heads, _ = q_fp8.shape
        out_qk = torch.full(
            (heads, batch_size * next_n, max_model_len),
            float("-inf"),
            device="cuda",
            dtype=torch.float32,
        )
        deepgemm_fp8_paged_mqa_logits_stage1(
            q_fp8,
            kv_cache_fp8,
            weights,
            out_qk,
            context_lens,
            block_tables,
            max_model_len,
        )
        # Stage 1 writes per-head logits; reduce over heads for the output.
        return out_qk.sum(dim=0)
    else:
        return fp8_paged_mqa_logits_torch(
            q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len
        )
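

# Hedged end-to-end sketch (hypothetical shapes and helper name). Neither
# branch above reads `schedule_metadata`, so a placeholder tensor stands in
# for the value DeepGEMM would supply via `get_paged_mqa_logits_metadata`.
def _example_rocm_fp8_paged_mqa_logits() -> None:
    fp8_dtype = current_platform.fp8_dtype()
    batch_size, next_n, heads, dim = 2, 1, 4, 64
    block_size, num_blocks, max_model_len = 16, 8, 128
    k = torch.randn(num_blocks, block_size, 1, dim, device="cuda")
    scales = torch.ones(num_blocks, block_size, 1, 1, device="cuda")
    kv_cache_fp8 = torch.cat(
        [k.to(fp8_dtype).view(torch.uint8), scales.view(torch.uint8)], dim=-1
    )
    q_fp8 = torch.randn(batch_size, next_n, heads, dim, device="cuda").to(fp8_dtype)
    weights = torch.ones(
        batch_size * next_n, heads, device="cuda", dtype=torch.float32
    )
    context_lens = torch.full((batch_size,), 32, dtype=torch.int32, device="cuda")
    block_tables = torch.arange(
        batch_size * 4, dtype=torch.int32, device="cuda"
    ).view(batch_size, 4)
    schedule_metadata = torch.empty(0, dtype=torch.int32, device="cuda")
    logits = rocm_fp8_paged_mqa_logits(
        q_fp8,
        kv_cache_fp8,
        weights,
        context_lens,
        block_tables,
        schedule_metadata,
        max_model_len,
    )
    assert logits.shape == (batch_size * next_n, max_model_len)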