[gpt-oss] Add gpt-oss bf16 support

This commit is contained in:
2025-08-13 21:25:57 +08:00
parent 5d2e7edf78
commit 17ea2ec6aa
1232 changed files with 777 additions and 36 deletions

View File

@@ -28,6 +28,7 @@ def kernel_paged_attention_2d(
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x]
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size]
sink_ptr, # [num_query_heads]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
@@ -59,6 +60,7 @@ def kernel_paged_attention_2d(
stride_v_cache_3: tl.int64, # int
filter_by_query_len: tl.constexpr, # bool
query_start_len_ptr, # [num_seqs+1]
USE_SINKS: tl.constexpr, # bool
):
seq_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
@@ -95,7 +97,18 @@ def kernel_paged_attention_2d(
block_table_offset = seq_idx * block_table_stride
M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
if not USE_SINKS:
M = tl.full([num_queries_per_kv_padded],
float("-inf"),
dtype=tl.float32)
else:
M = tl.load(
sink_ptr + query_head_idx,
mask=head_mask,
other=float("-inf"),
).to(dtype=tl.float32)
# M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
dtype=tl.float32)
@@ -223,6 +236,8 @@ def chunked_prefill_paged_decode(
alibi_slopes=None,
sliding_window=None,
sm_scale=None,
# Optional tensor for sinks
sinks=None,
):
if sm_scale is None:
@@ -253,6 +268,7 @@ def chunked_prefill_paged_decode(
sliding_window=sliding_window,
sm_scale=sm_scale,
skip_decode=True,
sinks=sinks,
)
block_size = value_cache.shape[3]
@@ -285,7 +301,7 @@ def chunked_prefill_paged_decode(
block_size,
num_queries_per_kv,
max_seq_len, sliding_window,
kv_cache_dtype, alibi_slopes)
kv_cache_dtype, alibi_slopes, sinks,)
if use_custom:
_PARTITION_SIZE_ROCM = 256
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
@@ -334,6 +350,7 @@ def chunked_prefill_paged_decode(
query_ptr=query,
key_cache_ptr=key_cache,
value_cache_ptr=value_cache,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seq_lens,
alibi_slopes_ptr=alibi_slopes,
@@ -365,4 +382,5 @@ def chunked_prefill_paged_decode(
stride_v_cache_3=value_cache.stride(3),
filter_by_query_len=True,
query_start_len_ptr=query_start_loc,
USE_SINKS=sinks is not None,
)

View File

@@ -34,6 +34,7 @@ def kernel_unified_attention_2d(
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
sink_ptr, # [num_query_heads]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
@@ -53,6 +54,7 @@ def kernel_unified_attention_2d(
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_SOFTCAP: tl.constexpr, # bool
USE_SINKS: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
@@ -119,7 +121,16 @@ def kernel_unified_attention_2d(
block_table_offset = seq_idx * block_table_stride
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
if not USE_SINKS:
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
else:
M = tl.load(
sink_ptr + query_offset_1,
mask=query_mask_1,
other=float("-inf"),
).to(dtype=tl.float32)
# M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
@@ -260,6 +271,8 @@ def unified_attention(
k_descale,
v_descale,
alibi_slopes=None,
# Optional tensor for sinks
sinks=None,
):
assert causal, "Only causal attention is supported"
assert q_descale is None, "Q scales not supported"
@@ -267,6 +280,10 @@ def unified_attention(
block_size = v.shape[1]
assert q.element_size() >= 2 or block_size >= 32, \
"Block size must be at least 32 for fp8"
if sinks is not None:
assert sinks.shape[0] == q.shape[1], \
"Sinks must be num_query_heads size"
use_alibi_slopes = alibi_slopes is not None
@@ -299,6 +316,7 @@ def unified_attention(
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,
@@ -318,6 +336,7 @@ def unified_attention(
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),

Some files were not shown because too many files have changed in this diff. (Show More)