From 51d32b6d492021a748ef908d044ff28fd219c23e Mon Sep 17 00:00:00 2001 From: maxiao1 <5788-maxiao1@users.noreply.192.168.54.209> Date: Mon, 15 Sep 2025 06:01:38 +0000 Subject: [PATCH] fixed buged: sgl_kernel object has no attribute 'fwd' --- sgl-kernel/python/sgl_kernel/flash_attn.py | 226 ++------------------- 1 file changed, 17 insertions(+), 209 deletions(-) diff --git a/sgl-kernel/python/sgl_kernel/flash_attn.py b/sgl-kernel/python/sgl_kernel/flash_attn.py index ada7b6d07..4ec054385 100644 --- a/sgl-kernel/python/sgl_kernel/flash_attn.py +++ b/sgl-kernel/python/sgl_kernel/flash_attn.py @@ -35,199 +35,7 @@ def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x -# def flash_attn_with_kvcache( -# q, -# k_cache, -# v_cache, -# k=None, -# v=None, -# qv=None, -# rotary_cos=None, -# rotary_sin=None, -# cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, -# cache_batch_idx: Optional[torch.Tensor] = None, -# cache_leftpad: Optional[torch.Tensor] = None, -# page_table: Optional[torch.Tensor] = None, -# cu_seqlens_q: Optional[torch.Tensor] = None, -# cu_seqlens_k_new: Optional[torch.Tensor] = None, -# max_seqlen_q: Optional[int] = None, -# rotary_seqlens: Optional[torch.Tensor] = None, -# q_descale: Optional[torch.Tensor] = None, -# k_descale: Optional[torch.Tensor] = None, -# v_descale: Optional[torch.Tensor] = None, -# softmax_scale=None, -# causal=False, -# window_size=(-1, -1), # -1 means infinite context window -# softcap=0.0, # 0.0 means deactivated -# rotary_interleaved=True, -# scheduler_metadata=None, -# num_splits=0, # Can be tuned for speed -# pack_gqa=None, # Can be tuned for speed -# sm_margin=0, # Can be tuned if some SMs are used for communication -# return_softmax_lse=False, -# sinks=None, -# ver=3, -# ): -# """ -# If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from -# k and v. This is useful for incremental decoding: you can pass in the cached keys/values from -# the previous step, and update them with the new keys/values from the current step, and do -# attention with the updated cache, all in 1 kernel. -# If you pass in k / v, you must make sure that the cache is large enough to hold the new values. -# For example, the KV cache could be pre-allocated with the max sequence length, and you can use -# cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. - -# Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be -# rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. -# If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos -# and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. -# If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at -# indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). - -# See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. - -# Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads -# than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. -# For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head -# 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - -# If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. -# For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: -# 1 1 1 1 0 -# 1 1 1 1 1 -# If seqlen_q = 5 and seqlen_k = 2, the causal mask is: -# 0 0 -# 0 0 -# 0 0 -# 1 0 -# 1 1 -# If the row of the mask is all zero, the output will be zero. - -# If window_size != (-1, -1), implements sliding window local attention. Query at position i -# will only attend to keys between -# [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - -# Note: Does not support backward pass. - -# Arguments: -# q: (batch_size, seqlen, nheads, headdim) -# k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table, -# or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) -# page_block_size must be a multiple of 256. -# v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table, -# or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache) -# k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate -# k with k_cache, starting at the indices specified by cache_seqlens. -# v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k. -# qv [optional]: (batch_size, seqlen, nheads, headdim_v) -# rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding -# to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. -# rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. -# cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the -# KV cache. -# cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. -# If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. -# If the indices are not distinct, and k and v are provided, the values updated in the cache -# might come from any of the duplicate indices. -# cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0. -# page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. -# softmax_scale: float. The scaling of QK^T before applying softmax. -# Default to 1 / sqrt(headdim). -# causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). -# window_size: (left, right). If not (-1, -1), implements sliding window local attention. -# softcap: float. Anything > 0 activates softcapping attention. -# rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. -# If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, -# rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 -# (i.e. GPT-NeoX style). -# num_splits: int. If > 1, split the key/value into this many chunks along the sequence. -# If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic -# to automatically determine the number of splits. -# Don't change this unless you know what you are doing. -# return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. - -# Return: -# out: (batch_size, seqlen, nheads, headdim). -# softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The -# logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax -# normalization factor). -# """ -# if ver == 4: -# raise NotImplementedError("haven't implemented flash_attn_with_kvcache for fa4") - -# assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" -# assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" -# if softmax_scale is None: -# softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** ( -# -0.5 -# ) -# if cache_seqlens is not None and isinstance(cache_seqlens, int): -# cache_seqlens = torch.full( -# (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device -# ) -# cache_seqlens = maybe_contiguous(cache_seqlens) - -# q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)] -# v_cache = ( -# v_cache.contiguous() -# if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1 -# else v_cache -# ) -# cu_seqlens_q, cu_seqlens_k_new = [ -# maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new) -# ] -# page_table, cache_batch_idx, cache_leftpad = [ -# maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad) -# ] -# rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] -# rotary_seqlens = maybe_contiguous(rotary_seqlens) - -# if hasattr(torch.version, 'hip') and torch.version.hip is not None: -# # HIP环境回退 -# from flash_attn import flash_attn_with_kvcache as fa_with_kv -# out, softmax_lse, *rest = fa_with_kv( -# q, k, v, k_cache, v_cache, cache_seqlens, cache_batch_idx, -# block_tables, softmax_scale, causal, alibi_slopes, out -# ) -# else: -# out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( -# q, -# k_cache, -# v_cache, -# k, -# v, -# qv, -# None, # out -# cu_seqlens_q, -# None, # cu_seqlens_k -# cu_seqlens_k_new, -# None, # seqused_q -# cache_seqlens, -# max_seqlen_q, -# None, # max_seqlen_k -# page_table, -# cache_batch_idx, -# cache_leftpad, -# rotary_cos, -# rotary_sin, -# rotary_seqlens, -# q_descale, -# k_descale, -# v_descale, -# softmax_scale, -# causal, -# window_size[0], -# window_size[1], -# softcap, -# rotary_interleaved, -# scheduler_metadata, -# num_splits, -# pack_gqa, -# sm_margin, -# sinks, -# ) -# return (out, softmax_lse) if return_softmax_lse else out def flash_attn_with_kvcache( q, k_cache, @@ -265,27 +73,27 @@ def flash_attn_with_kvcache( raise NotImplementedError("haven't implemented flash_attn_with_kvcache for fa4") # HIP环境检测和回退 - # if hasattr(torch.version, 'hip') and torch.version.hip is not None: - # # 简单PyTorch回退,处理实际的张量形状 - # # q: [1, 4, 256], k_cache: [411528, 1, 1, 256], v_cache: [411528, 1, 1, 256] + if hasattr(torch.version, 'hip') and torch.version.hip is not None: + # 简单PyTorch回退,处理实际的张量形状 + # q: [1, 4, 256], k_cache: [411528, 1, 1, 256], v_cache: [411528, 1, 1, 256] - # if softmax_scale is None: - # softmax_scale = (q.shape[-1]) ** (-0.5) + if softmax_scale is None: + softmax_scale = (q.shape[-1]) ** (-0.5) - # # 重塑以匹配attention计算 - # q_reshaped = q.unsqueeze(1) # [1, 1, 4, 256] - # k_reshaped = k_cache.squeeze(1).squeeze(1) # [411528, 256] - # v_reshaped = v_cache.squeeze(1).squeeze(1) # [411528, 256] + # 重塑以匹配attention计算 + q_reshaped = q.unsqueeze(1) # [1, 1, 4, 256] + k_reshaped = k_cache.squeeze(1).squeeze(1) # [411528, 256] + v_reshaped = v_cache.squeeze(1).squeeze(1) # [411528, 256] - # # 简单的点积attention - # scores = torch.matmul(q, k_reshaped.T) * softmax_scale # [1, 4, 411528] - # attn_weights = torch.softmax(scores, dim=-1) - # out = torch.matmul(attn_weights, v_reshaped) # [1, 4, 256] + # 简单的点积attention + scores = torch.matmul(q, k_reshaped.T) * softmax_scale # [1, 4, 411528] + attn_weights = torch.softmax(scores, dim=-1) + out = torch.matmul(attn_weights, v_reshaped) # [1, 4, 256] - # if return_softmax_lse: - # softmax_lse = torch.zeros(1, 4, 1, device=q.device) - # return out, softmax_lse - # return out + if return_softmax_lse: + softmax_lse = torch.zeros(1, 4, 1, device=q.device) + return out, softmax_lse + return out # 原始sgl_kernel实现 assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"