diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 469e4fde3..10d242ebe 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -686,7 +686,7 @@ class TritonAttnBackend(AttentionBackend): layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache=True, - sk=None, + sinks=None, ): # TODO: reuse the buffer across layers if layer.qk_head_dim != layer.v_head_dim: @@ -731,7 +731,7 @@ class TritonAttnBackend(AttentionBackend): layer.scaling, layer.logit_cap, sliding_window_size=sliding_window_size, - sk=sk, + sinks=sinks, ) return o @@ -743,7 +743,7 @@ class TritonAttnBackend(AttentionBackend): layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache=True, - sk=None, + sinks=None, ): # During torch.compile, there is a bug in rotary_emb that causes the # output value to have a 3D tensor shape. This reshapes the output correctly. @@ -780,7 +780,7 @@ class TritonAttnBackend(AttentionBackend): self.max_kv_splits, layer.scaling, layer.logit_cap, - sk=sk, + sinks=sinks, ) return o diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index 5e345586e..014eadab7 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -495,7 +495,7 @@ def _fwd_kernel_stage2( O, kv_indptr, num_kv_splits, - sk_ptr, + sink_ptr, stride_mid_ob, stride_mid_oh, stride_mid_os, @@ -505,7 +505,7 @@ def _fwd_kernel_stage2( MIN_BLOCK_KV: tl.constexpr, BLOCK_DV: tl.constexpr, Lv: tl.constexpr, - HAS_SK: tl.constexpr, + HAS_SINK: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -547,9 +547,9 @@ def _fwd_kernel_stage2( e_sum = e_sum * old_scale + exp_logic e_max = n_e_max - if HAS_SK: - cur_sk = tl.load(sk_ptr + cur_head) - e_sum += tl.exp(cur_sk - e_max) + if HAS_SINK: + cur_sink = tl.load(sink_ptr + cur_head) + e_sum += tl.exp(cur_sink - e_max) tl.store( O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, @@ -567,14 +567,14 @@ def _decode_softmax_reducev_fwd( kv_indptr, num_kv_splits, max_kv_splits, - sk=None, + sinks=None, ): batch, head_num = q.shape[0], q.shape[1] Lv = v_buffer.shape[-1] BLOCK_DV = triton.next_power_of_2(Lv) MAX_KV_SPLITS = max_kv_splits - HAS_SK = sk is not None + HAS_SINK = sinks is not None extra_kargs = {} if _is_hip: @@ -589,7 +589,7 @@ def _decode_softmax_reducev_fwd( o, kv_indptr, num_kv_splits, - sk, + sinks, logits.stride(0), logits.stride(1), logits.stride(2), @@ -599,7 +599,7 @@ def _decode_softmax_reducev_fwd( MIN_BLOCK_KV=_MIN_BLOCK_KV, BLOCK_DV=BLOCK_DV, Lv=Lv, - HAS_SK=HAS_SK, + HAS_SINK=HAS_SINK, num_warps=4, num_stages=2, **extra_kargs, @@ -619,7 +619,7 @@ def decode_attention_fwd_normal( max_kv_splits, sm_scale, logit_cap=0.0, - sk=None, + sinks=None, ): _decode_att_m_fwd( q, @@ -643,7 +643,7 @@ def decode_attention_fwd_normal( kv_indptr, num_kv_splits, max_kv_splits, - sk, + sinks, ) @@ -660,7 +660,7 @@ def decode_attention_fwd_grouped( max_kv_splits, sm_scale, logit_cap=0.0, - sk=None, + sinks=None, ): _decode_grouped_att_m_fwd( q, @@ -684,7 +684,7 @@ def decode_attention_fwd_grouped( kv_indptr, num_kv_splits, max_kv_splits, - sk, + sinks, ) @@ -701,7 +701,7 @@ def decode_attention_fwd( max_kv_splits, sm_scale, logit_cap=0.0, - sk=None, + sinks=None, ): assert max_kv_splits == attn_logits.shape[2] assert q.shape[0] <= kv_indptr.shape[0] - 1 @@ -724,7 +724,7 @@ def decode_attention_fwd( max_kv_splits, sm_scale, logit_cap=logit_cap, - sk=sk, + sinks=sinks, ) else: # GQA/MQA/MLA @@ -741,5 +741,5 @@ def decode_attention_fwd( max_kv_splits, sm_scale, logit_cap=logit_cap, - sk=sk, + sinks=sinks, ) diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py index e1b707f39..89f816a27 100644 --- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py @@ -51,7 +51,7 @@ def _fwd_kernel( kv_indices, mask_ptr, mask_indptr, - sk_ptr, + sink_ptr, sm_scale, kv_group_num, stride_qbs, @@ -79,7 +79,7 @@ def _fwd_kernel( IS_CAUSAL: tl.constexpr, SKIP_PREFIX_CUSTOM_MASK: tl.constexpr, STORE_TRANSPOSE: tl.constexpr, - HAS_SK: tl.constexpr, + HAS_SINK: tl.constexpr, ): cur_seq = tl.program_id(0) cur_head = tl.program_id(1) @@ -302,9 +302,9 @@ def _fwd_kernel( e_max = n_e_max - if HAS_SK: - cur_sk = tl.load(sk_ptr + cur_head) - deno += tl.exp(cur_sk - e_max) + if HAS_SINK: + cur_sink = tl.load(sink_ptr + cur_head) + deno += tl.exp(cur_sink - e_max) offs_o = ( (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) @@ -344,7 +344,7 @@ def extend_attention_fwd( logit_cap=0.0, skip_prefix_custom_mask=True, sliding_window_size=-1, - sk=None, + sinks=None, ): """ q_extend, k_extend, v_extend, o_extend: contiguous tensors @@ -410,7 +410,7 @@ def extend_attention_fwd( # Skip custom mask for prefix part SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask - HAS_SK = sk is not None + HAS_SINK = sinks is not None grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M)) num_stages = 1 @@ -431,7 +431,7 @@ def extend_attention_fwd( kv_indices, custom_mask, mask_indptr, - sk, + sinks, sm_scale, kv_group_num, q_extend.stride(0), @@ -458,7 +458,7 @@ def extend_attention_fwd( USE_CUSTOM_MASK=USE_CUSTOM_MASK, IS_CAUSAL=is_causal, SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK, - HAS_SK=HAS_SK, + HAS_SINK=HAS_SINK, STORE_TRANSPOSE=_is_hip, num_warps=num_warps, num_stages=num_stages, diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 4ca9c40c5..58b68fb38 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -301,7 +301,7 @@ class GptOssAttention(nn.Module): hidden_states, forward_batch, inner_state = intermediate_state if inner_state is None: return hidden_states - attn_output = self.attn(*inner_state, sk=self.sinks) + attn_output = self.attn(*inner_state, sinks=self.sinks) output, _ = self.o_proj(attn_output) return output