Adapt to DeepSeek V3.2 (replace deep_gemm / flash_mla / tilelang kernel calls with native PyTorch fallbacks)
python/sglang/srt/layers/attention/native_mla.py (new file, 121 lines)

@@ -0,0 +1,121 @@
import math
from typing import Optional, Tuple, List

import torch


def cdiv(x: int, y: int):
    return (x + y - 1) // y


def native_mla_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

    s_q, _, d_qk = q.size()
    s_kv = kv.size(0)
    topk = indices.size(-1)

    def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
        return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)

    indices = indices[:, 0, :]  # [s_q, topk]
    invalid_indices_mask = (indices < 0) | (indices >= s_kv)
    qs = q.float()  # [s_q, h_q, d_qk]
    kvs = kv[:, 0, :].float()  # [s_kv, d_qk]

    kvs = torch.index_select(
        kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()
    ).view(s_q, topk, d_qk)  # [s_q, topk, d_qk]
    attn_score = qs @ kvs.transpose(1, 2)  # [s_q, h_q, topk]
    attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf"))
    attn_score *= sm_scale * math.log2(math.e)
    max_logits = torch.max(attn_score, dim=-1)[0]  # [s_q, h_q]
    lse = log2sumexp2(attn_score, dim=-1)  # [s_q, h_q]
    attn_score = torch.exp2(attn_score - lse.unsqueeze(-1))  # [s_q, h_q, topk]
    result = attn_score @ kvs[:, :, :d_v]
    return (max_logits, lse, result)


def native_mla_with_kvcache(
    q: torch.Tensor,  # [batch_size, s_q, h_q, d]
    blocked_k: torch.Tensor,  # [?, block_size, h_kv, d]
    block_table: torch.Tensor,  # [batch_size, ?]
    cache_seqlens: torch.Tensor,  # [batch_size]
    dv: int,
    is_causal: bool,
    indices: Optional[torch.Tensor] = None,  # [batch_size, s_q, topk]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    A reference implementation in PyTorch.
    """

    def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor):
        mask = torch.zeros(s_q, s_k, dtype=torch.bool)
        for i in range(s_q):
            cur_indices = indices[i]
            valid_indices = cur_indices[cur_indices != -1]
            mask[i, valid_indices] = True
        return mask

    def scaled_dot_product_attention(
        batch_idx: int,
        query: torch.Tensor,  # [h_q, s_q, d]
        kv: torch.Tensor,  # [h_kv, s_k, d]
        dv: int,
        is_causal,
        indices: Optional[torch.Tensor],  # [s_q, topk]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        h_q = query.size(0)
        h_kv = kv.size(0)
        s_q = query.shape[-2]
        s_k = kv.shape[-2]
        query = query.float()
        kv = kv.float()
        if h_kv != 1:
            kv = kv.repeat_interleave(h_q // h_kv, dim=0)
        kv[kv != kv] = 0.0
        attn_weight = query @ kv.transpose(-2, -1)  # [h_q, s_q, s_k]
        if (is_causal and query.size(1) > 1) or indices is not None:
            mask = torch.ones(s_q, s_k, dtype=torch.bool)
            if is_causal:
                assert indices is None
                mask = mask.tril(diagonal=s_k - s_q)
            if indices is not None:
                mask &= get_topk_attn_mask(s_q, s_k, indices)
            attn_bias = torch.zeros(s_q, s_k, dtype=torch.float)
            attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
            attn_weight += attn_bias.to(query.dtype)
        attn_weight /= math.sqrt(query.size(-1))
        lse = attn_weight.logsumexp(dim=-1)  # [h_q, s_q]
        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
        output = attn_weight @ kv[..., :dv]  # [h_q, s_q, dv]
        # Correct for q tokens which have no attendable k
        lonely_q_mask = lse == float("-inf")
        output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0
        lse[lonely_q_mask] = float("+inf")

        return output, lse

    b, s_q, h_q, d = q.size()
    block_size = blocked_k.size(1)
    h_kv = blocked_k.size(2)
    cache_seqlens_cpu = cache_seqlens.cpu()
    out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
    lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32)
    for i in range(b):
        cur_len = cache_seqlens_cpu[i].item()
        cur_num_blocks = cdiv(cur_len, block_size)
        cur_block_indices = block_table[i][0:cur_num_blocks]
        cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
        cur_out, cur_lse = scaled_dot_product_attention(
            i,
            q[i].transpose(0, 1),
            cur_kv.transpose(0, 1),
            dv,
            is_causal,
            indices[i] if indices is not None else None,
        )
        out_ref[i] = cur_out.transpose(0, 1)
        lse_ref[i] = cur_lse
    out_ref = out_ref.to(torch.bfloat16)
    return out_ref, lse_ref
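A rough usage sketch for the new reference kernel (not taken from the commit; sizes are illustrative and assume the DeepSeek-style 512 + 64 latent/rope head split):

import torch

from sglang.srt.layers.attention.native_mla import native_mla_sparse_fwd

# Illustrative sizes: 4 query tokens, 64 cached KV tokens, 16 heads,
# 576-dim QK (512 latent + 64 rope), top-32 sparse indices per token.
s_q, s_kv, h_q, d_qk, d_v, topk = 4, 64, 16, 576, 512, 32
q = torch.randn(s_q, h_q, d_qk)
kv = torch.randn(s_kv, 1, d_qk)            # single latent KV "head"
indices = torch.randint(0, s_kv, (s_q, 1, topk))
indices[0, 0, :8] = -1                     # out-of-range entries are masked out

max_logits, lse, out = native_mla_sparse_fwd(
    q, kv, indices, sm_scale=d_qk ** -0.5, d_v=d_v
)
print(out.shape)   # torch.Size([4, 16, 512])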
python/sglang/srt/layers/attention/nsa/fallback_fp8.py (new file, 135 lines)

@@ -0,0 +1,135 @@
# fallback_fp8.py
# PyTorch fallback implementation for DeepGEMM-like fp8 logits ops
from sglang.srt.utils import ceil_div

import torch


@torch.no_grad()
def fallback_fp8_mqa_logits(
    q: torch.Tensor,
    kv: torch.Tensor,
    weights: torch.Tensor,
    ks: torch.Tensor,
    ke: torch.Tensor,
    cost_only: bool = False,
) -> torch.Tensor:
    """
    PyTorch fallback for fp8_mqa_logits. No real fp8 is used, just FP32.

    Args:
        q: (M, H, D) query
        kv: (N, D) key
        weights: (M, H)
        ks: (M,) int32
        ke: (M,) int32
    Returns:
        logits: (M, N) with -inf outside of the valid [ks, ke) range
    """
    seq_len_kv = kv.shape[0]

    if cost_only:
        # Only report how many (q, k) pairs fall inside the valid ranges.
        start = ks.clamp(min=0, max=seq_len_kv)
        end = ke.clamp(min=0, max=seq_len_kv)
        count_ones_per_row = (end - start).clamp(min=0)
        return count_ones_per_row.sum()

    q = q.float()
    k = kv.float()

    mask_lo = torch.arange(0, seq_len_kv, device=q.device)[None, :] >= ks[:, None]
    mask_hi = torch.arange(0, seq_len_kv, device=q.device)[None, :] < ke[:, None]
    mask = mask_lo & mask_hi

    score = torch.einsum("mhd,nd->hmn", q, k)
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))
    return logits

    # Equivalent per-row loop implementation, kept for reference:
    # M, H, D = q.shape
    # N = k.shape[0]
    # logits = torch.full((M, N), float("-inf"), dtype=torch.float32, device=q.device)
    # for i in range(M):
    #     start = max(ks[i].item(), 0)
    #     end = min(ke[i].item(), N)
    #     if start >= end:
    #         continue
    #     qi = q[i]          # (H, D)
    #     ki = k[start:end]  # (L, D)
    #     sim = torch.matmul(qi, ki.T)  # (H, L)
    #     weighted_sim = (sim.relu() * weights[i].unsqueeze(-1)).sum(dim=0)  # (L,)
    #     logits[i, start:end] = weighted_sim
    # return logits


@torch.no_grad()
def fallback_fp8_paged_mqa_logits(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """
    PyTorch fallback for fp8_paged_mqa_logits. No real fp8 is used, just FP32.

    Args:
        q: (B, N, H, D)
        kv_cache: (num_blocks, block_size, 1, D)
        weights: (B * N, H)
        context_lens: (B,)
        block_tables: (B, max_blocks)
        max_model_len: int
    Returns:
        logits: (B * N, max_model_len)
    """
    batch_size, next_n, heads, dim = q.size()
    num_block, block_size, _, dim = kv_cache.size()
    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float("-inf"),
        device=q.device,
        dtype=torch.float32,
    )
    context_lens = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens[i]
        q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
        weight_slice = (
            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
        )
        for block_rk in range(ceil_div(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            qx, kx = q[i], kv_cache[block_idx]
            k_offsets = torch.arange(
                block_rk * block_size, (block_rk + 1) * block_size, device=q.device
            )
            mask = (k_offsets[None, :] < context_len) & (
                k_offsets[None, :] <= q_offsets[:, None]
            )
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                    logits.dtype
                ),
                float("-inf"),
            )
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[
                i * next_n : (i + 1) * next_n,
                block_rk * block_size : (block_rk + 1) * block_size,
            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
    return logits

    # Equivalent einsum-based implementation, kept for reference (unreachable):
    # B, N, H, D = q.shape
    # block_size = kv_cache.shape[1]
    # logits = torch.full((B * N, max_model_len), float("-inf"), dtype=torch.float32, device=q.device)
    # for i in range(B):
    #     ctx_len = context_lens[i].item()
    #     q_offsets = torch.arange(ctx_len - N, ctx_len, device=q.device)
    #     weight_slice = weights[i * N:(i + 1) * N, :].transpose(0, 1).contiguous()
    #     for br in range((ctx_len + block_size - 1) // block_size):
    #         blk_idx = block_tables[i, br].item()
    #         if blk_idx < 0:
    #             continue
    #         qx = q[i]               # (N, H, D)
    #         kx = kv_cache[blk_idx]  # (block_size, 1, D)
    #         kx = kx.squeeze(1)      # (block_size, D)
    #         k_offsets = torch.arange(br * block_size, (br + 1) * block_size, device=q.device)
    #         mask = (k_offsets[None, :] < ctx_len) & (k_offsets[None, :] <= q_offsets[:, None])  # (N, block_size)
    #         s = torch.where(mask[None, :, :],
    #                         torch.einsum('nhd,ld->hnl', qx, kx),
    #                         torch.full((H, N, block_size), float("-inf"), device=q.device))
    #         s = s.relu() * weight_slice[..., None]
    #         logits_slice = s.sum(dim=0)  # (N, block_size)
    #         mask_block = (k_offsets[None, :] <= q_offsets[:, None])
    #         logits[i * N:(i + 1) * N, br * block_size:(br + 1) * block_size] = \
    #             torch.where(mask_block, logits_slice, float("-inf"))
    # return logits
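A rough usage sketch for the dense fallback (illustrative sizes; query row i attends to kv positions in [ks[i], ke[i]), and the snippet assumes the index mask is built on q.device as in the cleaned-up version above so it also runs on CPU):

import torch

from sglang.srt.layers.attention.nsa.fallback_fp8 import fallback_fp8_mqa_logits

M, N, H, D = 4, 32, 8, 128                 # illustrative sizes
q = torch.randn(M, H, D)
kv = torch.randn(N, D)
weights = torch.rand(M, H)
ks = torch.zeros(M, dtype=torch.int32)
ke = torch.tensor([8, 16, 24, 32], dtype=torch.int32)

logits = fallback_fp8_mqa_logits(q, kv, weights, ks, ke)
print(logits.shape)                         # torch.Size([4, 32])
print(torch.isinf(logits[0, 8:]).all())     # True: outside [ks, ke) is -inf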
@@ -3,6 +3,7 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 
+from sglang.srt.layers.attention.nsa.fallback_fp8 import fallback_fp8_mqa_logits, fallback_fp8_paged_mqa_logits
 import torch
 import torch.nn.functional as F
 from einops import rearrange
@@ -14,7 +15,7 @@ from sglang.srt.utils import add_prefix, is_npu
 
 if not is_npu():
     from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
-    import deep_gemm
+    #import deep_gemm
 
 from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
 from sglang.srt.layers.dp_attention import get_attention_tp_group
@@ -27,14 +28,14 @@ from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import add_prefix, align, is_cuda
 
-try:
-    import deep_gemm_v32
-except ImportError as e:
-    print("Error when importing deep_gemm_v32, try deep_gemm")
-    try:
-        import deep_gemm as deep_gemm_v32
-    except ImportError as e:
-        print("Error when importing deep_gemm, skip")
+# try:
+#     import deep_gemm_v32
+# except ImportError as e:
+#     print("Error when importing deep_gemm_v32, try deep_gemm")
+#     try:
+#         import deep_gemm as deep_gemm_v32
+#     except ImportError as e:
+#         print("Error when importing deep_gemm, skip")
 
 
 if TYPE_CHECKING:
@@ -81,16 +82,47 @@ class BaseIndexerMetadata(ABC):
         Don't assume it is the topk indices of the input logits.
         """
 
+def hadamard_transform_pytorch(x: torch.Tensor, scale: float) -> torch.Tensor:
+    """
+    A native PyTorch implementation of the Fast Hadamard Transform that mimics
+    the behavior of the custom CUDA kernel's call signature.
+
+    Args:
+        x (torch.Tensor): Input tensor of shape (*, N), where N is a power of 2.
+        scale (float): The normalization factor to multiply the result by.
+
+    Returns:
+        torch.Tensor: The Hadamard transformed tensor.
+    """
+    # Base case for recursion
+    if x.shape[-1] == 1:
+        return x
+
+    # Split the tensor into two halves
+    half_size = x.shape[-1] // 2
+    a = x[..., :half_size]
+    b = x[..., half_size:]
+
+    # Recursive calls
+    a_transformed = hadamard_transform_pytorch(a, scale=1.0)  # No scaling in intermediate steps
+    b_transformed = hadamard_transform_pytorch(b, scale=1.0)  # No scaling in intermediate steps
+
+    # Combine the results
+    combined = torch.cat([a_transformed + b_transformed, a_transformed - b_transformed], dim=-1)
+
+    # Apply the scale only at the final step
+    return combined * scale
+
+
 def rotate_activation(x: torch.Tensor) -> torch.Tensor:
     assert x.dtype == torch.bfloat16
-    from fast_hadamard_transform import hadamard_transform
+    #from fast_hadamard_transform import hadamard_transform
 
     hidden_size = x.size(-1)
     assert (
         hidden_size & (hidden_size - 1)
     ) == 0, "Hidden size must be a power of 2 for Hadamard transform."
-    return hadamard_transform(x, scale=hidden_size**-0.5)
+    return hadamard_transform_pytorch(x, scale=hidden_size**-0.5)
 
 
 class V32LayerNorm(nn.Module):
@@ -140,7 +172,7 @@ class Indexer(CustomOp):
         self.layer_id = layer_id
         self.alt_stream = alt_stream
         if not is_npu():
-            self.sm_count = deep_gemm.get_num_sms()
+            self.sm_count = torch.cuda.get_device_properties(0).multi_processor_count
             self.half_device_sm_count = align(self.sm_count // 2, 8)
 
         self.wq_b = ReplicatedLinear(
@@ -273,9 +305,7 @@ class Indexer(CustomOp):
         k_rope, _ = torch.split(
             key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
         )
 
         q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)
 
         query[..., : self.rope_head_dim] = q_rope
         key[..., : self.rope_head_dim] = k_rope
 
@@ -323,9 +353,9 @@ class Indexer(CustomOp):
         blocksize = page_size
         seqlens_32 = metadata.get_seqlens_int32()
         # NOTE(dark): 132 is SM count on H200/B200, not magic number
-        schedule_metadata = deep_gemm_v32.get_paged_mqa_logits_metadata(
-            seqlens_32, blocksize, self.sm_count
-        )
+        # schedule_metadata = deep_gemm_v32.get_paged_mqa_logits_metadata(
+        #     seqlens_32, blocksize, self.sm_count
+        # )
 
         assert len(q_fp8.shape) == 3
         q_fp8 = q_fp8.unsqueeze(1)  # the next_n dim is 1 now
@@ -339,15 +369,13 @@ class Indexer(CustomOp):
         assert len(weights.shape) == 3
         weights = weights.squeeze(2)
 
-        logits = deep_gemm_v32.fp8_paged_mqa_logits(
+        logits = fallback_fp8_paged_mqa_logits(
            q_fp8,
            kv_cache_fp8,
            weights,
            seqlens_32,
            block_tables,
-           schedule_metadata,
            max_seq_len,
-           clean_logits=False,
        )
 
         # NOTE(dark): logits should be cleaned in topk_transform
@@ -408,13 +436,12 @@ class Indexer(CustomOp):
         seq_lens_expanded = metadata.get_seqlens_expanded()
         ke = ks + seq_lens_expanded
 
-        logits = deep_gemm_v32.fp8_mqa_logits(
+        logits = fallback_fp8_mqa_logits(
            q_fp8,
-           kv_fp8,
+           k_fp8,
            weights,
            ks,
-           ke,
-           clean_logits=False,
+           ke
        )
 
         assert logits.shape[0] == len(seq_lens_expanded)
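The hadamard_transform_pytorch fallback added above follows the Sylvester recursion H_2n = [[H_n, H_n], [H_n, -H_n]]. A small standalone sanity check against an explicit 4x4 Hadamard matrix (the recursion is repeated here so the snippet runs on its own):

import torch

def hadamard_ref(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Same recursion as hadamard_transform_pytorch in the diff above.
    if x.shape[-1] == 1:
        return x
    half = x.shape[-1] // 2
    a = hadamard_ref(x[..., :half], scale=1.0)
    b = hadamard_ref(x[..., half:], scale=1.0)
    return torch.cat([a + b, a - b], dim=-1) * scale

H4 = torch.tensor([
    [1.0,  1.0,  1.0,  1.0],
    [1.0, -1.0,  1.0, -1.0],
    [1.0,  1.0, -1.0, -1.0],
    [1.0, -1.0, -1.0,  1.0],
])
x = torch.randn(3, 4)
scale = 4 ** -0.5
assert torch.allclose(hadamard_ref(x, scale), (x @ H4.T) * scale, atol=1e-6)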
@@ -1,22 +1,22 @@
 from typing import Optional, Tuple
 
-import tilelang
-import tilelang.language as T
+# import tilelang
+# import tilelang.language as T
 import torch
 
-tilelang.set_log_level("WARNING")
+# tilelang.set_log_level("WARNING")
 
-pass_configs = {
-    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-    tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
-}
+# pass_configs = {
+#     tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+#     tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+#     tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
+# }
 
 BF16 = "bfloat16"
 FP8 = "float8_e4m3"
 FP32 = "float32"
 
 
+'''
 def fast_log2_ceil(x):
     bits_x = T.reinterpret("uint32", x)
     exp_x = (bits_x >> 23) & 0xFF
@@ -32,7 +32,6 @@ def fast_pow2(x):
 def fast_round_scale(amax, fp8_max_inv):
     return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
 
 
 @tilelang.jit(pass_configs=pass_configs)
 def act_quant_kernel(
     N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
@@ -83,7 +82,6 @@ def act_quant_kernel(
 
     return act_quant_kernel_
 
 
 def act_quant(
     x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -753,7 +751,6 @@ def sparse_attention_fwd_kernel_v2(
 
     return main
 
 
 def tilelang_sparse_fwd(
     q: torch.Tensor,
     kv: torch.Tensor,
@@ -772,3 +769,45 @@ def tilelang_sparse_fwd(
         num_heads, d_v, tail_dim, topk, sm_scale=sm_scale
     )
     return kernel(q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0))  # type: ignore
+'''
+
+
+def act_quant(
+    x: torch.Tensor,
+    block_size: int = 128,
+    scale_fmt: Optional[str] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    PyTorch fallback for act_quant
+    Block-wise FP8 E4M3 quantization
+    """
+    if not x.is_contiguous():
+        x = x.contiguous()
+
+    N = x.size(-1)
+    assert N % block_size == 0, f"Last dim {N} must be divisible by block_size={block_size}"
+
+    # Reshape to blocks
+    x_2d = x.view(-1, N)
+    x_blocks = x_2d.view(-1, block_size)
+
+    # Compute absmax per block
+    amax = x_blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-4)
+
+    # FP8 E4M3 max value is ~448
+    fp8_max = 448.0
+    scale = amax / fp8_max
+
+    if scale_fmt is not None:
+        # Simulate rounded scale (power-of-2 rounding)
+        scale = torch.round(scale * 256) / 256
+
+    # Quantize and clamp
+    y_blocks = torch.clamp(torch.round(x_blocks / scale), -fp8_max, fp8_max)
+
+    # Convert to FP8
+    q = y_blocks.view_as(x_2d).to(torch.float8_e4m3fn)
+
+    # Reshape scale
+    s = scale.view(x_2d.size(0), N // block_size).to(torch.float32)
+    s = s.view(*x.shape[:-1], N // block_size)
+
+    return q.view_as(x), s
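A rough round-trip sketch for the act_quant fallback (illustrative sizes; assumes a PyTorch build that ships torch.float8_e4m3fn):

import torch

from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant

x = torch.randn(4, 256, dtype=torch.bfloat16)   # last dim divisible by block_size=128
q, s = act_quant(x, block_size=128)
print(q.dtype, q.shape)                         # torch.float8_e4m3fn torch.Size([4, 256])
print(s.shape)                                  # torch.Size([4, 2]): one scale per 128-wide block

# De-quantize by expanding each block scale back over its block.
x_hat = (q.float().view(4, 2, 128) * s.unsqueeze(-1)).view(4, 256)
print((x.float() - x_hat).abs().max())          # small reconstruction error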
@@ -105,7 +105,7 @@ def transform_index_page_table_decode_ref(
     torch.gather(
         page_table,
         dim=1,
-        index=topk_indices.clamp(min=0),
+        index=topk_indices.clamp(min=0).long(),
         out=result,
     )
     result[topk_indices < 0] = -1
@@ -10,7 +10,6 @@ from typing import (
     Tuple,
     TypeAlias,
     Union,
-    override,
 )
 
 import torch
@@ -101,19 +100,15 @@ class NSAMetadata:
 class NSAIndexerMetadata(BaseIndexerMetadata):
     attn_metadata: NSAMetadata
 
-    @override
     def get_seqlens_int32(self) -> torch.Tensor:
         return self.attn_metadata.cache_seqlens_int32
 
-    @override
     def get_page_table_64(self) -> torch.Tensor:
         return self.attn_metadata.real_page_table
 
-    @override
     def get_seqlens_expanded(self) -> torch.Tensor:
         return self.attn_metadata.nsa_seqlens_expanded
 
-    @override
     def topk_transform(
         self,
         logits: torch.Tensor,
@@ -524,21 +519,25 @@ class NativeSparseAttnBackend(AttentionBackend):
             extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
             page_size=1,
         )
-        if NSA_PREFILL_IMPL == "tilelang":
-            from sglang.srt.layers.attention.nsa.tilelang_kernel import (
-                tilelang_sparse_fwd,
-            )
+        # if NSA_PREFILL_IMPL == "tilelang":
+        #     from sglang.srt.layers.attention.nsa.tilelang_kernel import (
+        #         tilelang_sparse_fwd,
+        #     )
 
-            if q_rope is not None:
-                q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_tilelang(
-                q_all=q_all,
-                kv_cache=kv_cache,
-                page_table_1=page_table_1,
-                sm_scale=layer.scaling,
-                v_head_dim=layer.v_head_dim,
-            )
-        elif NSA_PREFILL_IMPL == "flashmla_prefill":
+        #     if q_rope is not None:
+        #         q_all = torch.cat([q_nope, q_rope], dim=-1)
+        #     return self._forward_tilelang(
+        #         q_all=q_all,
+        #         kv_cache=kv_cache,
+        #         page_table_1=page_table_1,
+        #         sm_scale=layer.scaling,
+        #         v_head_dim=layer.v_head_dim,
+        #     )
+        # elif NSA_PREFILL_IMPL == "flashmla_prefill":
+
+        # Skip tilelang dependencies
+        if NSA_PREFILL_IMPL == "tilelang" or NSA_PREFILL_IMPL == "flashmla_prefill":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
             return self._forward_flashmla_prefill(
@@ -733,9 +732,9 @@ class NativeSparseAttnBackend(AttentionBackend):
         page_table_1: torch.Tensor,
         sm_scale: float,
     ) -> torch.Tensor:
-        from flash_mla import flash_mla_sparse_fwd
-
-        o, _, _ = flash_mla_sparse_fwd(
+        #from flash_mla import flash_mla_sparse_fwd
+        from sglang.srt.layers.attention.native_mla import native_mla_sparse_fwd
+        _, _, o = native_mla_sparse_fwd(
             q=q_all,
             kv=kv_cache,
             indices=page_table_1.unsqueeze(1),
@@ -756,8 +755,8 @@ class NativeSparseAttnBackend(AttentionBackend):
         topk_indices,
         block_table,
     ) -> torch.Tensor:
-        from flash_mla import flash_mla_with_kvcache
-
+        #from flash_mla import flash_mla_with_kvcache
+        from sglang.srt.layers.attention.native_mla import native_mla_with_kvcache
         cache_seqlens = metadata.nsa_cache_seqlens_int32
 
         # TODO the 2nd dim is seq_len_q, need to be >1 when MTP
@@ -769,7 +768,7 @@ class NativeSparseAttnBackend(AttentionBackend):
             # inefficiently quantize the whole cache
             kv_cache = quantize_k_cache(kv_cache)
 
-        o, _ = flash_mla_with_kvcache(
+        o, _ = native_mla_with_kvcache(
             q=q_all,
             k_cache=kv_cache,
             cache_seqlens=cache_seqlens,
@@ -136,21 +136,21 @@ class RMSNorm(CustomOp):
             # NOTE: Remove this if aiter kernel supports discontinuous input
             x = x.contiguous()
             if residual is not None:
-                if _vllm_version < Version("0.9"):
-                    fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
-                    return x, residual
-                else:
-                    residual_out = torch.empty_like(x)
-                    output = torch.empty_like(x)
-                    fused_add_rms_norm(
-                        output,
-                        x,
-                        residual_out,
-                        residual,
-                        self.weight.data,
-                        self.variance_epsilon,
-                    )
-                    return output, residual_out
+                #if _vllm_version < Version("0.9"):
+                fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+                return x, residual
+                # else:
+                #     residual_out = torch.empty_like(x)
+                #     output = torch.empty_like(x)
+                #     fused_add_rms_norm(
+                #         output,
+                #         x,
+                #         residual_out,
+                #         residual,
+                #         self.weight.data,
+                #         self.variance_epsilon,
+                #     )
+                #     return output, residual_out
             out = torch.empty_like(x)
             rms_norm(out, x, self.weight.data, self.variance_epsilon)
             return out
@@ -765,7 +765,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
 
         rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
         query_rot = query_rot * cos + rotate_fn(query_rot) * sin
-        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+        cos_for_key = cos[:, 0, ...]
+        sin_for_key = sin[:, 0, ...]
+        key_rot = key_rot * cos_for_key + rotate_fn(key_rot) * sin_for_key
+        #key_rot = key_rot * cos + rotate_fn(key_rot) * sin
 
         if self.rotary_dim < self.head_size:
             query = torch.cat((query_rot, query_pass), dim=-1)