misc: cache is_hopper_arch (#6799)
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import functools
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
@@ -41,6 +42,7 @@ def _to_tensor_scalar_tuple(x):
|
||||
return (None, x)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def is_hopper_arch() -> bool:
|
||||
# Hopper arch's compute capability == 9.0
|
||||
device = torch.cuda.current_device()
|
||||
|
||||
Reference in New Issue
Block a user