# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass, field
from typing import Mapping, Optional, Tuple, Union

import torch


@dataclass
class DeviceLimit:
    """Peak hardware figures for one device model.

    A default-constructed instance is the "unknown device" entry: infinite
    memory bandwidth and an empty GEMM throughput table.
    """

    name: str = "default"  # pattern to match from `torch.cuda.get_device_name()`
    source: str = ""  # URL the numbers were taken from
    sm: Tuple[int, int] = (0, 0)  # compared with `torch.cuda.get_device_capability()`
    # bytes/s
    gmem_bandwidth: float = math.inf
    # dtype -> TFlop/s
    gemm_tflops: Mapping[torch.dtype, float] = field(default_factory=dict)


# For f32, we assume we can use tf32
# NOTE(review): bandwidth entries use binary multipliers (1024**n); confirm
# this matches the units quoted in the linked datasheets.
DEVICE_LIMITS: Tuple[DeviceLimit, ...] = (
    DeviceLimit(
        "H100",
        "https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet",  # noqa: E501
        sm=(9, 0),
        gmem_bandwidth=3.35 * (1024**4),  # NOTE: PCIe is 2 TB/s
        gemm_tflops={
            torch.float64: 67,
            # NOTE: NVIDIA gives all numbers "with 2:4 sparsity"
            # but we want the full GEMM numbers
            torch.float32: 989 // 2,
            torch.float16: 1979 // 2,
            torch.bfloat16: 1979 // 2,
            torch.int8: 3958 // 2,
        },
    ),
    DeviceLimit(
        "A100",
        "https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf",  # noqa: E501
        sm=(8, 0),
        gmem_bandwidth=2 * (1024**4),  # NOTE: PCIe is 1.5 TB/s
        gemm_tflops={
            torch.float64: 19.5,
            torch.float32: 156,
            torch.float16: 312,
            torch.bfloat16: 312,
            torch.int8: 624,
        },
    ),
    DeviceLimit(
        "A30",
        "https://www.nvidia.com/content/dam/en-zz/Solutions/data-center/products/a30-gpu/pdf/a30-datasheet.pdf",
        sm=(8, 0),
        gmem_bandwidth=933 * (1024**3),
        gemm_tflops={
            torch.float64: 10.3,
            torch.float32: 82,
            torch.float16: 165,
            torch.bfloat16: 165,
            torch.int8: 330,
        },
    ),
    DeviceLimit(
        "T4",
        "https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf",
        sm=(7, 5),
        gmem_bandwidth=300 * (1024**3),
        gemm_tflops={
            torch.float32: 8.1,
            torch.float16: 65,
            torch.int8: 130,
        },
    ),
    # Assuming SXM2
    DeviceLimit(
        "V100",
        "https://images.nvidia.com/content/technologies/volta/pdf/tesla-volta-v100-datasheet-letter-fnl-web.pdf",
        sm=(7, 0),
        gmem_bandwidth=900 * (1024**3),
        gemm_tflops={
            torch.float64: 7.8,
            torch.float32: 15.7,
            torch.float16: 125,
        },
    ),
    DeviceLimit(
        "P100",
        "https://images.nvidia.com/content/tesla/pdf/nvidia-tesla-p100-datasheet.pdf",
        sm=(6, 0),
        gmem_bandwidth=732 * (1024**3),
        gemm_tflops={
            torch.float64: 5.3,
            torch.float32: 10.6,
            torch.float16: 21.2,
        },
    ),
)


def get_device_limits(device: Union[torch.device, str, None]) -> DeviceLimit:
    """Return the known hardware limits for ``device``.

    Currently only implemented for GPUs.

    Args:
        device: A ``torch.device``, a device string such as ``"cuda:0"`` or
            ``"cpu"``, or ``None``.

    Returns:
        The first ``DEVICE_LIMITS`` entry whose ``sm`` equals the device's
        compute capability and whose ``name`` is a substring of the device
        name. For ``None``, non-CUDA devices, and CUDA devices with no
        matching entry, a default ``DeviceLimit()`` is returned.
    """
    # Device strings have no `.type` attribute; resolve them first so that
    # "cuda:0" / "cpu" behave exactly like the equivalent torch.device.
    if isinstance(device, str):
        device = torch.device(device)

    if device is None or device.type != "cuda":
        return DeviceLimit()

    device_sm = torch.cuda.get_device_capability(device)
    device_name = torch.cuda.get_device_name(device)
    for lim in DEVICE_LIMITS:
        # Several models share a compute capability (A100 and A30 are both
        # (8, 0) in the table), so the name pattern must match as well.
        if lim.sm == device_sm and lim.name in device_name:
            return lim
    return DeviceLimit()