misc: cache is_hopper_arch (#6799)
This commit is contained in:
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
import functools
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -41,6 +42,7 @@ def _to_tensor_scalar_tuple(x):
|
|||||||
return (None, x)
|
return (None, x)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=1)
|
||||||
def is_hopper_arch() -> bool:
|
def is_hopper_arch() -> bool:
|
||||||
# Hopper arch's compute capability == 9.0
|
# Hopper arch's compute capability == 9.0
|
||||||
device = torch.cuda.current_device()
|
device = torch.cuda.current_device()
|
||||||
|
|||||||
Reference in New Issue
Block a user