[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
@@ -28,6 +28,7 @@ def kernel_paged_attention_2d(
|
||||
query_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x]
|
||||
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size]
|
||||
sink_ptr, # [num_query_heads]
|
||||
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
alibi_slopes_ptr, # [num_query_heads]
|
||||
@@ -59,6 +60,7 @@ def kernel_paged_attention_2d(
|
||||
stride_v_cache_3: tl.int64, # int
|
||||
filter_by_query_len: tl.constexpr, # bool
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
USE_SINKS: tl.constexpr, # bool
|
||||
):
|
||||
seq_idx = tl.program_id(0)
|
||||
kv_head_idx = tl.program_id(1)
|
||||
@@ -95,7 +97,18 @@ def kernel_paged_attention_2d(
|
||||
|
||||
block_table_offset = seq_idx * block_table_stride
|
||||
|
||||
M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
|
||||
if not USE_SINKS:
|
||||
M = tl.full([num_queries_per_kv_padded],
|
||||
float("-inf"),
|
||||
dtype=tl.float32)
|
||||
else:
|
||||
M = tl.load(
|
||||
sink_ptr + query_head_idx,
|
||||
mask=head_mask,
|
||||
other=float("-inf"),
|
||||
).to(dtype=tl.float32)
|
||||
# M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
|
||||
|
||||
L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
|
||||
dtype=tl.float32)
|
||||
@@ -223,6 +236,8 @@ def chunked_prefill_paged_decode(
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
sm_scale=None,
|
||||
# Optional tensor for sinks
|
||||
sinks=None,
|
||||
):
|
||||
|
||||
if sm_scale is None:
|
||||
@@ -253,6 +268,7 @@ def chunked_prefill_paged_decode(
|
||||
sliding_window=sliding_window,
|
||||
sm_scale=sm_scale,
|
||||
skip_decode=True,
|
||||
sinks=sinks,
|
||||
)
|
||||
|
||||
block_size = value_cache.shape[3]
|
||||
@@ -285,7 +301,7 @@ def chunked_prefill_paged_decode(
|
||||
block_size,
|
||||
num_queries_per_kv,
|
||||
max_seq_len, sliding_window,
|
||||
kv_cache_dtype, alibi_slopes)
|
||||
kv_cache_dtype, alibi_slopes, sinks,)
|
||||
if use_custom:
|
||||
_PARTITION_SIZE_ROCM = 256
|
||||
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
|
||||
@@ -334,6 +350,7 @@ def chunked_prefill_paged_decode(
|
||||
query_ptr=query,
|
||||
key_cache_ptr=key_cache,
|
||||
value_cache_ptr=value_cache,
|
||||
sink_ptr=sinks,
|
||||
block_tables_ptr=block_table,
|
||||
seq_lens_ptr=seq_lens,
|
||||
alibi_slopes_ptr=alibi_slopes,
|
||||
@@ -365,4 +382,5 @@ def chunked_prefill_paged_decode(
|
||||
stride_v_cache_3=value_cache.stride(3),
|
||||
filter_by_query_len=True,
|
||||
query_start_len_ptr=query_start_loc,
|
||||
USE_SINKS=sinks is not None,
|
||||
)
|
||||
@@ -34,6 +34,7 @@ def kernel_unified_attention_2d(
|
||||
query_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
sink_ptr, # [num_query_heads]
|
||||
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
alibi_slopes_ptr, # [num_query_heads]
|
||||
@@ -53,6 +54,7 @@ def kernel_unified_attention_2d(
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||
USE_SOFTCAP: tl.constexpr, # bool
|
||||
USE_SINKS: tl.constexpr, # bool
|
||||
SLIDING_WINDOW: tl.constexpr, # int
|
||||
stride_k_cache_0: tl.int64, # int
|
||||
stride_k_cache_1: tl.int64, # int
|
||||
@@ -119,7 +121,16 @@ def kernel_unified_attention_2d(
|
||||
|
||||
block_table_offset = seq_idx * block_table_stride
|
||||
|
||||
if not USE_SINKS:
|
||||
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
else:
|
||||
M = tl.load(
|
||||
sink_ptr + query_offset_1,
|
||||
mask=query_mask_1,
|
||||
other=float("-inf"),
|
||||
).to(dtype=tl.float32)
|
||||
# M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
|
||||
L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
|
||||
|
||||
@@ -260,6 +271,8 @@ def unified_attention(
|
||||
k_descale,
|
||||
v_descale,
|
||||
alibi_slopes=None,
|
||||
# Optional tensor for sinks
|
||||
sinks=None,
|
||||
):
|
||||
assert causal, "Only causal attention is supported"
|
||||
assert q_descale is None, "Q scales not supported"
|
||||
@@ -268,6 +281,10 @@ def unified_attention(
|
||||
assert q.element_size() >= 2 or block_size >= 32, \
|
||||
"Block size must be at least 32 for fp8"
|
||||
|
||||
if sinks is not None:
|
||||
assert sinks.shape[0] == q.shape[1], \
|
||||
"Sinks must be num_query_heads size"
|
||||
|
||||
use_alibi_slopes = alibi_slopes is not None
|
||||
|
||||
block_size = v.shape[1]
|
||||
@@ -299,6 +316,7 @@ def unified_attention(
|
||||
query_ptr=q,
|
||||
key_cache_ptr=k,
|
||||
value_cache_ptr=v,
|
||||
sink_ptr=sinks,
|
||||
block_tables_ptr=block_table,
|
||||
seq_lens_ptr=seqused_k,
|
||||
alibi_slopes_ptr=alibi_slopes,
|
||||
@@ -318,6 +336,7 @@ def unified_attention(
|
||||
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
|
||||
USE_ALIBI_SLOPES=use_alibi_slopes,
|
||||
USE_SOFTCAP=(softcap > 0),
|
||||
USE_SINKS=(sinks is not None),
|
||||
SLIDING_WINDOW=(1 + window_size[0]),
|
||||
stride_k_cache_0=k.stride(0),
|
||||
stride_k_cache_1=k.stride(1),
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user