[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
@@ -126,7 +126,8 @@ def use_rocm_custom_paged_attention(
|
||||
max_seq_len: int,
|
||||
sliding_window: int,
|
||||
kv_cache_dtype: str,
|
||||
alibi_slopes: Optional[torch.Tensor] = None) -> bool:
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
sinks: Optional[torch.Tensor] = None) -> bool:
|
||||
|
||||
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
|
||||
@@ -143,7 +144,7 @@ def use_rocm_custom_paged_attention(
|
||||
and (gqa_ratio >= 1 and gqa_ratio <= 16)
|
||||
and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
|
||||
and envs.VLLM_ROCM_USE_AITER))
|
||||
and envs.VLLM_ROCM_USE_AITER) and sinks is None)
|
||||
|
||||
else:
|
||||
return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
|
||||
@@ -153,7 +154,7 @@ def use_rocm_custom_paged_attention(
|
||||
and (gqa_ratio >= 3 and gqa_ratio <= 16)
|
||||
and max_seq_len <= 32768 and alibi_slopes is None
|
||||
and kv_cache_dtype == "auto"
|
||||
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)
|
||||
|
||||
|
||||
class RocmPlatform(Platform):
|
||||
|
||||
Reference in New Issue
Block a user