misc: cache is_hopper_arch (#6799)

This commit is contained in:
Wenxuan Tan
2025-06-01 17:28:31 -05:00
committed by GitHub
parent 1da8d23051
commit c429919def

View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import functools
from typing import Dict, Tuple
import torch
@@ -41,6 +42,7 @@ def _to_tensor_scalar_tuple(x):
return (None, x)
@functools.lru_cache(maxsize=1)
def is_hopper_arch() -> bool:
# Hopper arch's compute capability == 9.0
device = torch.cuda.current_device()