[gpt-oss] Add gpt-oss bf16 support

This commit is contained in:
2025-08-13 21:25:57 +08:00
parent 5d2e7edf78
commit 8ba49a7723
9 changed files with 777 additions and 36 deletions

View File

@@ -28,6 +28,7 @@ def kernel_paged_attention_2d(
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x]
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size]
sink_ptr, # [num_query_heads]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
@@ -59,6 +60,7 @@ def kernel_paged_attention_2d(
stride_v_cache_3: tl.int64, # int
filter_by_query_len: tl.constexpr, # bool
query_start_len_ptr, # [num_seqs+1]
USE_SINKS: tl.constexpr, # bool
):
seq_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
@@ -95,7 +97,18 @@ def kernel_paged_attention_2d(
block_table_offset = seq_idx * block_table_stride
M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
if not USE_SINKS:
M = tl.full([num_queries_per_kv_padded],
float("-inf"),
dtype=tl.float32)
else:
M = tl.load(
sink_ptr + query_head_idx,
mask=head_mask,
other=float("-inf"),
).to(dtype=tl.float32)
# M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
dtype=tl.float32)
@@ -223,6 +236,8 @@ def chunked_prefill_paged_decode(
alibi_slopes=None,
sliding_window=None,
sm_scale=None,
# Optional tensor for sinks
sinks=None,
):
if sm_scale is None:
@@ -253,6 +268,7 @@ def chunked_prefill_paged_decode(
sliding_window=sliding_window,
sm_scale=sm_scale,
skip_decode=True,
sinks=sinks,
)
block_size = value_cache.shape[3]
@@ -285,7 +301,7 @@ def chunked_prefill_paged_decode(
block_size,
num_queries_per_kv,
max_seq_len, sliding_window,
kv_cache_dtype, alibi_slopes)
kv_cache_dtype, alibi_slopes, sinks,)
if use_custom:
_PARTITION_SIZE_ROCM = 256
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
@@ -334,6 +350,7 @@ def chunked_prefill_paged_decode(
query_ptr=query,
key_cache_ptr=key_cache,
value_cache_ptr=value_cache,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seq_lens,
alibi_slopes_ptr=alibi_slopes,
@@ -365,4 +382,5 @@ def chunked_prefill_paged_decode(
stride_v_cache_3=value_cache.stride(3),
filter_by_query_len=True,
query_start_len_ptr=query_start_loc,
USE_SINKS=sinks is not None,
)