diff --git a/python/sglang/srt/layers/attention/native_mla.py b/python/sglang/srt/layers/attention/native_mla.py
new file mode 100644
index 000000000..68e9fa7aa
--- /dev/null
+++ b/python/sglang/srt/layers/attention/native_mla.py
@@ -0,0 +1,121 @@
+import math
+from typing import Optional, Tuple, List
+
+import torch
+
+
+def cdiv(x: int, y: int):
+    return (x + y - 1) // y
+
+
+def native_mla_sparse_fwd(
+    q: torch.Tensor,  # [s_q, h_q, d_qk]
+    kv: torch.Tensor,  # [s_kv, 1, d_qk]
+    indices: torch.Tensor,  # [s_q, 1, topk]
+    sm_scale: float,
+    d_v: int = 512,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    s_q, _, d_qk = q.size()
+    s_kv = kv.size(0)
+    topk = indices.size(-1)
+
+    def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
+        # log2(sum(2**a)), computed via the natural-log logsumexp
+        return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)
+
+    indices = indices[:, 0, :]  # [s_q, topk]
+    invalid_indices_mask = (indices < 0) | (indices >= s_kv)
+    qs = q.float()  # [s_q, h_q, d_qk]
+    kvs = kv[:, 0, :].float()  # [s_kv, d_qk]
+
+    # Gather the top-k KV rows per query token; invalid slots are redirected to row 0
+    # here and masked out of the softmax below.
+    kvs = torch.index_select(
+        kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()
+    ).view(s_q, topk, d_qk)  # [s_q, topk, d_qk]
+    attn_score = qs @ kvs.transpose(1, 2)  # [s_q, h_q, topk]
+    attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf"))
+    attn_score *= sm_scale * math.log2(math.e)  # base-2 logits
+    max_logits = torch.max(attn_score, dim=-1)[0]  # [s_q, h_q]
+    lse = log2sumexp2(attn_score, dim=-1)  # [s_q, h_q]
+    attn_score = torch.exp2(attn_score - lse.unsqueeze(-1))  # [s_q, h_q, topk]
+    result = attn_score @ kvs[:, :, :d_v]  # [s_q, h_q, d_v]
+    return (max_logits, lse, result)
+
+
+def native_mla_with_kvcache(
+    q: torch.Tensor,  # [batch_size, s_q, h_q, d]
+    blocked_k: torch.Tensor,  # [?, block_size, h_kv, d]
+    block_table: torch.Tensor,  # [batch_size, ?]
+    cache_seqlens: torch.Tensor,  # [batch_size]
+    dv: int,
+    is_causal: bool,
+    indices: Optional[torch.Tensor] = None,  # [batch_size, s_q, topk]
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    A pure-PyTorch reference implementation of MLA attention over a paged KV cache.
+    """
+
+    def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor):
+        mask = torch.zeros(s_q, s_k, dtype=torch.bool, device=indices.device)
+        for i in range(s_q):
+            cur_indices = indices[i]
+            valid_indices = cur_indices[cur_indices != -1]
+            mask[i, valid_indices] = True
+        return mask
+
+    def scaled_dot_product_attention(
+        batch_idx: int,
+        query: torch.Tensor,  # [h_q, s_q, d]
+        kv: torch.Tensor,  # [h_kv, s_k, d]
+        dv: int,
+        is_causal,
+        indices: Optional[torch.Tensor],  # [s_q, topk]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        h_q = query.size(0)
+        h_kv = kv.size(0)
+        s_q = query.shape[-2]
+        s_k = kv.shape[-2]
+        query = query.float()
+        kv = kv.float()
+        if h_kv != 1:
+            kv = kv.repeat_interleave(h_q // h_kv, dim=0)
+        kv[kv != kv] = 0.0  # zero out NaN entries (e.g. uninitialized cache slots)
+        attn_weight = query @ kv.transpose(-2, -1)  # [h_q, s_q, s_k]
+        if (is_causal and query.size(1) > 1) or indices is not None:
+            mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device)
+            if is_causal:
+                assert indices is None
+                mask = mask.tril(diagonal=s_k - s_q)
+            if indices is not None:
+                mask &= get_topk_attn_mask(s_q, s_k, indices)
+            attn_bias = torch.zeros(s_q, s_k, dtype=torch.float, device=query.device)
+            attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
+            attn_weight += attn_bias
+        attn_weight /= math.sqrt(query.size(-1))
+        lse = attn_weight.logsumexp(dim=-1)  # [h_q, s_q]
+        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
+        output = attn_weight @ kv[..., :dv]  # [h_q, s_q, dv]
+        # Correct q tokens which have no attendable k.
+        lonely_q_mask = lse == float("-inf")
+        output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0
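+
+
+# Illustrative smoke test (a minimal sketch: the shapes below are arbitrary and chosen
+# only to exercise the reference functions on CPU; the backend calls them with CUDA tensors).
+if __name__ == "__main__":
+    s_q, s_kv, h_q, d_qk, d_v, topk = 4, 16, 2, 576, 512, 8
+    q = torch.randn(s_q, h_q, d_qk)
+    kv = torch.randn(s_kv, 1, d_qk)
+    indices = torch.randint(0, s_kv, (s_q, 1, topk))
+    max_logits, lse, out = native_mla_sparse_fwd(q, kv, indices, sm_scale=d_qk**-0.5, d_v=d_v)
+    print(max_logits.shape, lse.shape, out.shape)  # (4, 2), (4, 2), (4, 2, 512)
+
+    block_size = 8
+    n_blocks = s_kv // block_size
+    q4 = torch.randn(1, s_q, h_q, d_qk, dtype=torch.bfloat16)
+    blocked_k = kv.view(n_blocks, block_size, 1, d_qk).to(torch.bfloat16)
+    block_table = torch.arange(n_blocks).unsqueeze(0)
+    cache_seqlens = torch.tensor([s_kv], dtype=torch.int32)
+    o, lse_ref = native_mla_with_kvcache(
+        q4, blocked_k, block_table, cache_seqlens, d_v, is_causal=False
+    )
+    print(o.shape, lse_ref.shape)  # (1, 4, 2, 512), (1, 2, 4)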
+        lse[lonely_q_mask] = float("+inf")
+
+        return output, lse
+
+    b, s_q, h_q, d = q.size()
+    block_size = blocked_k.size(1)
+    h_kv = blocked_k.size(2)
+    cache_seqlens_cpu = cache_seqlens.cpu()
+    out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device=q.device)
+    lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32, device=q.device)
+    for i in range(b):
+        cur_len = cache_seqlens_cpu[i].item()
+        cur_num_blocks = cdiv(cur_len, block_size)
+        cur_block_indices = block_table[i][0:cur_num_blocks]
+        cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
+        cur_out, cur_lse = scaled_dot_product_attention(
+            i,
+            q[i].transpose(0, 1),
+            cur_kv.transpose(0, 1),
+            dv,
+            is_causal,
+            indices[i] if indices is not None else None,
+        )
+        out_ref[i] = cur_out.transpose(0, 1)
+        lse_ref[i] = cur_lse
+    out_ref = out_ref.to(torch.bfloat16)
+    return out_ref, lse_ref
diff --git a/python/sglang/srt/layers/attention/nsa/fallback_fp8.py b/python/sglang/srt/layers/attention/nsa/fallback_fp8.py
new file mode 100644
index 000000000..febec7782
--- /dev/null
+++ b/python/sglang/srt/layers/attention/nsa/fallback_fp8.py
@@ -0,0 +1,135 @@
+# fallback_fp8.py
+# PyTorch fallback implementations for the DeepGEMM-style fp8 logits ops.
+# No real fp8 is used here; everything runs in plain FP32.
+import torch
+
+from sglang.srt.utils import ceil_div
+
+
+@torch.no_grad()
+def fallback_fp8_mqa_logits(
+    q: torch.Tensor,  # [seq_len_q, heads, dim]
+    kv: torch.Tensor,  # [seq_len_kv, dim]
+    weights: torch.Tensor,  # [seq_len_q, heads]
+    ks: torch.Tensor,  # [seq_len_q]
+    ke: torch.Tensor,  # [seq_len_q]
+    cost_only: bool = False,
+) -> torch.Tensor:
+    seq_len_kv = kv.shape[0]
+
+    if cost_only:
+        # Only report how many (q, k) pairs fall inside the valid [ks, ke) windows.
+        start = ks.clamp(min=0, max=seq_len_kv)
+        end = ke.clamp(min=0, max=seq_len_kv)
+        count_ones_per_row = (end - start).clamp(min=0)
+        return count_ones_per_row.sum()
+
+    k = kv
+    q = q.float()
+    k = k.float()
+
+    mask_lo = torch.arange(0, seq_len_kv, device=q.device)[None, :] >= ks[:, None]
+    mask_hi = torch.arange(0, seq_len_kv, device=q.device)[None, :] < ke[:, None]
+    mask = mask_lo & mask_hi
+
+    score = torch.einsum("mhd,nd->hmn", q, k)
+    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
+    logits = logits.masked_fill(~mask, float("-inf"))
+
+    # cost = mask.sum()
+    return logits
+
+    # Earlier per-row draft, kept for reference:
+    # """
+    # PyTorch fallback for fp8_mqa_logits.
+    # No real fp8 used, just FP32.
+    # Args:
+    #     q: (M, H, D) query
+    #     k: (N, D) key
+    #     weights: (M, H)
+    #     ks: (M,) int32
+    #     ke: (M,) int32
+    # Returns:
+    #     logits: (M, N) with -inf outside of valid range
+    # """
+    # M, H, D = q.shape
+    # N = k[0].shape[0]
+    # logits = torch.full((M, N), float("-inf"), dtype=torch.float32, device=q.device)
+    # # for i in range(M):
+    # #     start = max(ks[i].item(), 0)
+    # #     end = min(ke[i].item(), N)
+    # #     if start >= end:
+    # #         continue
+    # #     qi = q[i]  # (H, D)
+    # #     ki = k[start:end]  # (L, D)
+    # #     sim = torch.matmul(qi, ki.T)  # (H, L)
+    # #     weighted_sim = (sim.relu() * weights[i].unsqueeze(-1)).sum(dim=0)  # (L,)
+    # #     logits[i, start:end] = weighted_sim
+    # return logits
+
+
+@torch.no_grad()
+def fallback_fp8_paged_mqa_logits(
+    q: torch.Tensor,  # [batch_size, next_n, heads, dim]
+    kv_cache: torch.Tensor,  # [num_blocks, block_size, 1, dim]
+    weights: torch.Tensor,  # [batch_size * next_n, heads]
+    context_lens: torch.Tensor,  # [batch_size]
+    block_tables: torch.Tensor,  # [batch_size, max_blocks]
+    max_model_len: int,
+) -> torch.Tensor:
+    batch_size, next_n, heads, dim = q.size()
+    num_block, block_size, _, dim = kv_cache.size()
+    logits = torch.full(
+        [batch_size * next_n, max_model_len],
+        float("-inf"),
+        device=q.device,
+        dtype=torch.float32,
+    )
+    context_lens = context_lens.tolist()
+    for i in range(batch_size):
+        context_len = context_lens[i]
+        q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
+        weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous()
+        for block_rk in range(ceil_div(context_len, block_size)):
+            block_idx = block_tables[i][block_rk]
+            qx, kx = q[i], kv_cache[block_idx]
+            k_offsets = torch.arange(
+                block_rk * block_size, (block_rk + 1) * block_size, device=q.device
+            )
+            mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None])
+            s = torch.where(
+                mask[None, :, :],
+                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(logits.dtype),
+                float("-inf"),
+            )
+            s = torch.relu(s) * weight_slice[..., None]
+            s = s.sum(dim=0)
+            logits[
+                i * next_n:(i + 1) * next_n, block_rk * block_size:(block_rk + 1) * block_size
+            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
+    return logits
+
+    # Unreachable earlier draft, kept below for reference.
+    """
+    PyTorch fallback for fp8_paged_mqa_logits.
+    No real fp8 used, just FP32.
+    Args:
+        q: (B, N, H, D)
+        kv_cache: (num_blocks, block_size, 1, D)
+        weights: (B * N, H)
+        context_lens: (B,)
+        block_tables: (B, max_blocks)
+        max_model_len: int
+    Returns:
+        logits: (B * N, max_model_len)
+    """
+    B, N, H, D = q.shape
+    block_size = kv_cache.shape[1]
+    logits = torch.full((B * N, max_model_len), float("-inf"), dtype=torch.float32, device=q.device)
+
+    for i in range(B):
+        ctx_len = context_lens[i].item()
+        q_offsets = torch.arange(ctx_len - N, ctx_len, device=q.device)
+        weight_slice = weights[i * N:(i + 1) * N, :].transpose(0, 1).contiguous()
+
+        for br in range((ctx_len + block_size - 1) // block_size):
+            blk_idx = block_tables[i, br].item()
+            if blk_idx < 0:
+                continue
+            qx = q[i]  # (N, H, D)
+            kx = kv_cache[blk_idx]  # (block_size, 1, D)
+            kx = kx.squeeze(1)  # (block_size, D)
+            k_offsets = torch.arange(br * block_size, (br + 1) * block_size, device=q.device)
+
+            mask = (k_offsets[None, :] < ctx_len) & (k_offsets[None, :] <= q_offsets[:, None])  # (N, block_size)
+            s = torch.where(
+                mask[None, :, :],
+                torch.einsum("nhd,ld->hnl", qx, kx),
+                torch.full((H, N, block_size), float("-inf"), device=q.device),
+            )
+            s = s.relu() * weight_slice[..., None]
+            logits_slice = s.sum(dim=0)  # (N, block_size)
+
+            mask_block = k_offsets[None, :] <= q_offsets[:, None]
+            logits[i * N:(i + 1) * N, br * block_size:(br + 1) * block_size] = \
+                torch.where(mask_block, logits_slice, float("-inf"))
+
+    return logits
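+
+
+# Illustrative smoke test (a minimal sketch with tiny CPU-friendly shapes, chosen only to
+# exercise the fallbacks; the indexer calls them with fp8-quantized CUDA tensors instead).
+if __name__ == "__main__":
+    m, n, h, d = 4, 16, 2, 8
+    q = torch.randn(m, h, d)
+    kv = torch.randn(n, d)
+    weights = torch.rand(m, h)
+    ks = torch.zeros(m, dtype=torch.int64)
+    ke = torch.arange(1, m + 1, dtype=torch.int64) * 4
+    print(fallback_fp8_mqa_logits(q, kv, weights, ks, ke).shape)  # (4, 16)
+
+    bs, next_n, block_size, max_len = 2, 1, 4, 16
+    q_paged = torch.randn(bs, next_n, h, d)
+    kv_cache = torch.randn(bs * max_len // block_size, block_size, 1, d)
+    context_lens = torch.tensor([7, 13])
+    block_tables = torch.arange(bs * max_len // block_size).view(bs, -1)
+    paged_logits = fallback_fp8_paged_mqa_logits(
+        q_paged, kv_cache, weights[: bs * next_n], context_lens, block_tables, max_len
+    )
+    print(paged_logits.shape)  # (2, 16)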
diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
index 922cd2974..8e8a5e7ee 100644
--- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
+++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 
+from sglang.srt.layers.attention.nsa.fallback_fp8 import fallback_fp8_mqa_logits, fallback_fp8_paged_mqa_logits
 import torch
 import torch.nn.functional as F
 from einops import rearrange
@@ -14,7 +15,7 @@ from sglang.srt.utils import add_prefix, is_npu
 
 if not is_npu():
     from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
-    import deep_gemm
+    # import deep_gemm
 
 from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
 from sglang.srt.layers.dp_attention import get_attention_tp_group
@@ -27,14 +28,14 @@ from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import add_prefix, align, is_cuda
 
-try:
-    import deep_gemm_v32
-except ImportError as e:
-    print("Error when importing deep_gemm_v32, try deep_gemm")
-    try:
-        import deep_gemm as deep_gemm_v32
-    except ImportError as e:
-        print("Error when importing deep_gemm, skip")
+# try:
+#     import deep_gemm_v32
+# except ImportError as e:
+#     print("Error when importing deep_gemm_v32, try deep_gemm")
+#     try:
+#         import deep_gemm as deep_gemm_v32
+#     except ImportError as e:
+#         print("Error when importing deep_gemm, skip")
 
 if TYPE_CHECKING:
@@ -81,16 +82,47 @@ class BaseIndexerMetadata(ABC):
         Don't assume it is the topk indices of the input logits.
         """
 
+
+def hadamard_transform_pytorch(x: torch.Tensor, scale: float) -> torch.Tensor:
+    """
+    A native PyTorch implementation of the Fast Hadamard Transform that mimics
+    the call signature of the custom CUDA kernel.
+
+    Args:
+        x (torch.Tensor): Input tensor of shape (*, N), where N is a power of 2.
+        scale (float): The normalization factor to multiply the result by.
+
+    Returns:
+        torch.Tensor: The Hadamard-transformed tensor.
+    """
+    # Base case for recursion
+    if x.shape[-1] == 1:
+        return x
+
+    # Split the tensor into two halves
+    half_size = x.shape[-1] // 2
+    a = x[..., :half_size]
+    b = x[..., half_size:]
+
+    # Recursive calls (no scaling in intermediate steps)
+    a_transformed = hadamard_transform_pytorch(a, scale=1.0)
+    b_transformed = hadamard_transform_pytorch(b, scale=1.0)
+
+    # Combine the results
+    combined = torch.cat([a_transformed + b_transformed, a_transformed - b_transformed], dim=-1)
+
+    # Apply the scale only at the final step
+    return combined * scale
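+
+
+# Illustrative check (an assumed toy input, for documentation only): applying the transform
+# to e_0 of length 4 with scale 1/sqrt(4) yields a constant vector,
+#   hadamard_transform_pytorch(torch.tensor([1.0, 0.0, 0.0, 0.0]), scale=0.5)
+#   # -> tensor([0.5000, 0.5000, 0.5000, 0.5000])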
+
 
 def rotate_activation(x: torch.Tensor) -> torch.Tensor:
     assert x.dtype == torch.bfloat16
-    from fast_hadamard_transform import hadamard_transform
+    # from fast_hadamard_transform import hadamard_transform
 
     hidden_size = x.size(-1)
     assert (
         hidden_size & (hidden_size - 1)
     ) == 0, "Hidden size must be a power of 2 for Hadamard transform."
-    return hadamard_transform(x, scale=hidden_size**-0.5)
+    return hadamard_transform_pytorch(x, scale=hidden_size**-0.5)
 
 
 class V32LayerNorm(nn.Module):
@@ -140,7 +172,7 @@ class Indexer(CustomOp):
         self.layer_id = layer_id
         self.alt_stream = alt_stream
         if not is_npu():
-            self.sm_count = deep_gemm.get_num_sms()
+            self.sm_count = torch.cuda.get_device_properties(0).multi_processor_count
             self.half_device_sm_count = align(self.sm_count // 2, 8)
 
         self.wq_b = ReplicatedLinear(
@@ -273,9 +305,7 @@ class Indexer(CustomOp):
         k_rope, _ = torch.split(
             key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
         )
-        q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)
-
         query[..., : self.rope_head_dim] = q_rope
         key[..., : self.rope_head_dim] = k_rope
@@ -323,9 +353,9 @@ class Indexer(CustomOp):
         blocksize = page_size
         seqlens_32 = metadata.get_seqlens_int32()
         # NOTE(dark): 132 is SM count on H200/B200, not magic number
-        schedule_metadata = deep_gemm_v32.get_paged_mqa_logits_metadata(
-            seqlens_32, blocksize, self.sm_count
-        )
+        # schedule_metadata = deep_gemm_v32.get_paged_mqa_logits_metadata(
+        #     seqlens_32, blocksize, self.sm_count
+        # )
 
         assert len(q_fp8.shape) == 3
         q_fp8 = q_fp8.unsqueeze(1)  # the next_n dim is 1 now
@@ -339,15 +369,13 @@ class Indexer(CustomOp):
         assert len(weights.shape) == 3
         weights = weights.squeeze(2)
 
-        logits = deep_gemm_v32.fp8_paged_mqa_logits(
+        logits = fallback_fp8_paged_mqa_logits(
             q_fp8,
             kv_cache_fp8,
             weights,
             seqlens_32,
             block_tables,
-            schedule_metadata,
             max_seq_len,
-            clean_logits=False,
         )
 
         # NOTE(dark): logits should be cleaned in topk_transform
@@ -408,13 +436,12 @@ class Indexer(CustomOp):
         seq_lens_expanded = metadata.get_seqlens_expanded()
         ke = ks + seq_lens_expanded
 
-        logits = deep_gemm_v32.fp8_mqa_logits(
+        logits = fallback_fp8_mqa_logits(
             q_fp8,
-            kv_fp8,
+            k_fp8,
             weights,
             ks,
-            ke,
-            clean_logits=False,
+            ke
         )
 
         assert logits.shape[0] == len(seq_lens_expanded)
diff --git a/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py b/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py
index d2f271e17..afab3c08b 100644
--- a/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py
+++ b/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py
@@ -1,22 +1,22 @@
 from typing import Optional, Tuple
 
-import tilelang
-import tilelang.language as T
+# import tilelang
+# import tilelang.language as T
 import torch
 
-tilelang.set_log_level("WARNING")
+# tilelang.set_log_level("WARNING")
 
-pass_configs = {
-    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-    tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
-}
+# pass_configs = {
+#     tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+#     tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+#     tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
+# }
 
 BF16 = "bfloat16"
 FP8 = "float8_e4m3"
 FP32 = "float32"
-
+'''
 def fast_log2_ceil(x):
     bits_x = T.reinterpret("uint32", x)
     exp_x = (bits_x >> 23) & 0xFF
@@ -32,7 +32,6 @@ def fast_pow2(x):
 
 def fast_round_scale(amax, fp8_max_inv):
     return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
-
 @tilelang.jit(pass_configs=pass_configs)
 def act_quant_kernel(
     N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
@@ -83,7 +82,6 @@ def act_quant_kernel(
 
     return act_quant_kernel_
-
 def act_quant(
     x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -753,7 +751,6 @@ def sparse_attention_fwd_kernel_v2(
 
     return main
-
 def tilelang_sparse_fwd(
     q: torch.Tensor,
     kv: torch.Tensor,
@@ -772,3 +769,45 @@ def tilelang_sparse_fwd(
         num_heads, d_v, tail_dim, topk, sm_scale=sm_scale
    )
     return kernel(q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0))  # type: ignore
+'''
+
+
+def act_quant(
+    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    PyTorch fallback for act_quant: block-wise FP8 E4M3 quantization.
+    """
+    if not x.is_contiguous():
+        x = x.contiguous()
+
+    N = x.size(-1)
+    assert N % block_size == 0, f"Last dim {N} must be divisible by block_size={block_size}"
+
+    # Reshape to blocks
+    x_2d = x.view(-1, N)
+    x_blocks = x_2d.view(-1, block_size)
+
+    # Compute absmax per block
+    amax = x_blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-4)
+
+    # FP8 E4M3 max value is 448
+    fp8_max = 448.0
+    scale = amax / fp8_max
+
+    if scale_fmt is not None:
+        # Round the scale up to a power of two, matching the kernel's round_scale path
+        scale = torch.exp2(torch.ceil(torch.log2(scale)))
+
+    # Scale and clamp; the FP8 cast below performs the rounding
+    y_blocks = torch.clamp(x_blocks / scale, -fp8_max, fp8_max)
+
+    # Convert to FP8
+    q = y_blocks.view_as(x_2d).to(torch.float8_e4m3fn)
+
+    # Reshape scale: one float32 scale per (row, block)
+    s = scale.view(x_2d.size(0), N // block_size).to(torch.float32)
+    s = s.view(*x.shape[:-1], N // block_size)
+
+    return q.view_as(x), s
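+
+
+# Illustrative round-trip (an assumed toy input, for documentation only): dequantizing with
+# the returned per-block scales should approximately recover the input up to fp8 precision.
+#   x = torch.randn(4, 256, dtype=torch.bfloat16)
+#   q8, s = act_quant(x)   # q8: (4, 256) float8_e4m3fn, s: (4, 2) float32
+#   x_hat = (q8.float().view(4, 2, 128) * s.unsqueeze(-1)).view(4, 256)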
diff --git a/python/sglang/srt/layers/attention/nsa/transform_index.py b/python/sglang/srt/layers/attention/nsa/transform_index.py
index 442dd113d..8bf653750 100644
--- a/python/sglang/srt/layers/attention/nsa/transform_index.py
+++ b/python/sglang/srt/layers/attention/nsa/transform_index.py
@@ -105,7 +105,7 @@ def transform_index_page_table_decode_ref(
     torch.gather(
         page_table,
         dim=1,
-        index=topk_indices.clamp(min=0),
+        index=topk_indices.clamp(min=0).long(),
         out=result,
     )
     result[topk_indices < 0] = -1
diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py
index 54e62c94d..9a5b4b208 100644
--- a/python/sglang/srt/layers/attention/nsa_backend.py
+++ b/python/sglang/srt/layers/attention/nsa_backend.py
@@ -10,7 +10,6 @@ from typing import (
     Tuple,
     TypeAlias,
     Union,
-    override,
 )
 
 import torch
@@ -101,19 +100,15 @@ class NSAMetadata:
 class NSAIndexerMetadata(BaseIndexerMetadata):
     attn_metadata: NSAMetadata
 
-    @override
     def get_seqlens_int32(self) -> torch.Tensor:
         return self.attn_metadata.cache_seqlens_int32
 
-    @override
     def get_page_table_64(self) -> torch.Tensor:
         return self.attn_metadata.real_page_table
 
-    @override
     def get_seqlens_expanded(self) -> torch.Tensor:
         return self.attn_metadata.nsa_seqlens_expanded
 
-    @override
     def topk_transform(
         self,
         logits: torch.Tensor,
@@ -524,21 +519,25 @@ class NativeSparseAttnBackend(AttentionBackend):
             extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
             page_size=1,
         )
-        if NSA_PREFILL_IMPL == "tilelang":
-            from sglang.srt.layers.attention.nsa.tilelang_kernel import (
-                tilelang_sparse_fwd,
-            )
+        # if NSA_PREFILL_IMPL == "tilelang":
+        #     from sglang.srt.layers.attention.nsa.tilelang_kernel import (
+        #         tilelang_sparse_fwd,
+        #     )
 
-            if q_rope is not None:
-                q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_tilelang(
-                q_all=q_all,
-                kv_cache=kv_cache,
-                page_table_1=page_table_1,
-                sm_scale=layer.scaling,
-                v_head_dim=layer.v_head_dim,
-            )
-        elif NSA_PREFILL_IMPL == "flashmla_prefill":
+        # if q_rope is not None:
+        #     q_all = torch.cat([q_nope, q_rope], dim=-1)
+        #     return self._forward_tilelang(
+        #         q_all=q_all,
+        #         kv_cache=kv_cache,
+        #         page_table_1=page_table_1,
+        #         sm_scale=layer.scaling,
+        #         v_head_dim=layer.v_head_dim,
+        #     )
+        # elif NSA_PREFILL_IMPL == "flashmla_prefill":
+
+
+        # Skip tilelang dependencies
+        if NSA_PREFILL_IMPL == "tilelang" or NSA_PREFILL_IMPL == "flashmla_prefill":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
             return self._forward_flashmla_prefill(
@@ -733,9 +732,9 @@ class NativeSparseAttnBackend(AttentionBackend):
         page_table_1: torch.Tensor,
         sm_scale: float,
     ) -> torch.Tensor:
-        from flash_mla import flash_mla_sparse_fwd
-
-        o, _, _ = flash_mla_sparse_fwd(
+        # from flash_mla import flash_mla_sparse_fwd
+        from sglang.srt.layers.attention.native_mla import native_mla_sparse_fwd
+        _, _, o = native_mla_sparse_fwd(
             q=q_all,
             kv=kv_cache,
             indices=page_table_1.unsqueeze(1),
@@ -756,8 +755,8 @@ class NativeSparseAttnBackend(AttentionBackend):
         topk_indices,
         block_table,
     ) -> torch.Tensor:
-        from flash_mla import flash_mla_with_kvcache
-
+        # from flash_mla import flash_mla_with_kvcache
+        from sglang.srt.layers.attention.native_mla import native_mla_with_kvcache
         cache_seqlens = metadata.nsa_cache_seqlens_int32
 
         # TODO the 2nd dim is seq_len_q, need to be >1 when MTP
@@ -769,7 +768,7 @@ class NativeSparseAttnBackend(AttentionBackend):
             # inefficiently quantize the whole cache
             kv_cache = quantize_k_cache(kv_cache)
 
-        o, _ = flash_mla_with_kvcache(
+        o, _ = native_mla_with_kvcache(
             q=q_all,
             k_cache=kv_cache,
             cache_seqlens=cache_seqlens,
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 87a392d55..bcede84d8 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -136,21 +136,21 @@ class RMSNorm(CustomOp):
             # NOTE: Remove this if aiter kernel supports discontinuous input
             x = x.contiguous()
         if residual is not None:
-            if _vllm_version < Version("0.9"):
-                fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
-                return x, residual
-            else:
-                residual_out = torch.empty_like(x)
-                output = torch.empty_like(x)
-                fused_add_rms_norm(
-                    output,
-                    x,
-                    residual_out,
-                    residual,
-                    self.weight.data,
-                    self.variance_epsilon,
-                )
-                return output, residual_out
+            # if _vllm_version < Version("0.9"):
+            fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+            # else:
+            #     residual_out = torch.empty_like(x)
+            #     output = torch.empty_like(x)
+            #     fused_add_rms_norm(
+            #         output,
+            #         x,
+            #         residual_out,
+            #         residual,
+            #         self.weight.data,
+            #         self.variance_epsilon,
+            #     )
+            #     return output, residual_out
         out = torch.empty_like(x)
         rms_norm(out, x, self.weight.data, self.variance_epsilon)
         return out
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index f0e9e5a7b..e743a3c76 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -765,7 +765,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
 
         query_rot = query_rot * cos + rotate_fn(query_rot) * sin
-        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+        cos_for_key = cos[:, 0, ...]
+        sin_for_key = sin[:, 0, ...]
+        key_rot = key_rot * cos_for_key + rotate_fn(key_rot) * sin_for_key
+        # key_rot = key_rot * cos + rotate_fn(key_rot) * sin
 
         if self.rotary_dim < self.head_size:
             query = torch.cat((query_rot, query_pass), dim=-1)