[gpt-oss] Add gpt-oss bf16 support

This commit is contained in:
2025-08-13 21:25:57 +08:00
parent 5d2e7edf78
commit 8ba49a7723
9 changed files with 777 additions and 36 deletions

View File

@@ -90,6 +90,7 @@ class TritonAttentionImpl(AttentionImpl):
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
use_irope: bool = False,
sinks: Optional[torch.Tensor] = None,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@@ -132,6 +133,13 @@ class TritonAttentionImpl(AttentionImpl):
self.fp8_dtype = current_platform.fp8_dtype()
self.force_prefill_decode_attn = \
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}.")
def forward(
self,
@@ -257,7 +265,8 @@ class TritonAttentionImpl(AttentionImpl):
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale)
sm_scale=self.scale,
sinks=self.sinks)
else:
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
@@ -280,6 +289,7 @@ class TritonAttentionImpl(AttentionImpl):
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
sinks=self.sinks,
)
return output