114 lines
3.5 KiB
Python
114 lines
3.5 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import math
|
|
from dataclasses import dataclass, field
|
|
from typing import Mapping, Tuple
|
|
|
|
import torch
|
|
|
|
|
|
@dataclass
class DeviceLimit:
    """Peak hardware figures recorded for one GPU model.

    Constructed without arguments, the instance acts as a placeholder:
    the name is "default", the bandwidth is infinite and no GEMM
    throughput is recorded.
    """

    # Pattern to match from `torch.cuda.get_device_name()`.
    name: str = "default"
    # Reference the figures were taken from (empty when unknown).
    source: str = ""
    # Pair of ints compared against `torch.cuda.get_device_capability()`.
    sm: Tuple[int, int] = (0, 0)
    # Global memory bandwidth, in bytes/s.
    gmem_bandwidth: float = math.inf
    # Peak GEMM throughput for each dtype, in TFlop/s.
    gemm_tflops: Mapping[torch.dtype, float] = field(default_factory=dict)
|
|
|
|
|
|
# Table of known GPUs, scanned in order by `get_device_limits`.
# Each entry gives: name pattern, datasheet URL, compute capability,
# global memory bandwidth (bytes/s) and peak GEMM throughput (TFlop/s).
#
# For f32, we assume we can use tf32
#
# NOTE(review): the bandwidths below scale the datasheet TB/s and GB/s
# figures by powers of 1024; confirm whether decimal units
# (10**12 / 10**9) were intended.
DEVICE_LIMITS: Tuple[DeviceLimit, ...] = (
    DeviceLimit(
        "H100",
        "https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet",  # noqa: E501
        sm=(9, 0),
        gmem_bandwidth=3.35 * (1024**4),  # NOTE: PCIe is 2 TB/s
        gemm_tflops={
            torch.float64: 67,
            # NOTE: NVIDIA gives all numbers "with 2:4 sparsity"
            # but we want the full GEMM numbers.
            # True division keeps the half TFlop/s that `//` would drop
            # (989 / 2 == 494.5, 1979 / 2 == 989.5).
            torch.float32: 989 / 2,
            torch.float16: 1979 / 2,
            torch.bfloat16: 1979 / 2,
            torch.int8: 3958 / 2,
        },
    ),
    DeviceLimit(
        "A100",
        "https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf",  # noqa: E501
        sm=(8, 0),
        gmem_bandwidth=2 * (1024**4),  # NOTE: PCIe is 1.5 TB/s
        gemm_tflops={
            torch.float64: 19.5,
            torch.float32: 156,
            torch.float16: 312,
            torch.bfloat16: 312,
            torch.int8: 624,
        },
    ),
    DeviceLimit(
        "A30",
        "https://www.nvidia.com/content/dam/en-zz/Solutions/data-center/products/a30-gpu/pdf/a30-datasheet.pdf",  # noqa: E501
        sm=(8, 0),
        gmem_bandwidth=933 * (1024**3),
        gemm_tflops={
            torch.float64: 10.3,
            torch.float32: 82,
            torch.float16: 165,
            torch.bfloat16: 165,
            torch.int8: 330,
        },
    ),
    DeviceLimit(
        "T4",
        "https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf",  # noqa: E501
        sm=(7, 5),
        gmem_bandwidth=300 * (1024**3),
        gemm_tflops={
            torch.float32: 8.1,
            torch.float16: 65,
            torch.int8: 130,
        },
    ),
    # Assuming SXM2
    DeviceLimit(
        "V100",
        "https://images.nvidia.com/content/technologies/volta/pdf/tesla-volta-v100-datasheet-letter-fnl-web.pdf",  # noqa: E501
        sm=(7, 0),
        gmem_bandwidth=900 * (1024**3),
        gemm_tflops={
            torch.float64: 7.8,
            torch.float32: 15.7,
            torch.float16: 125,
        },
    ),
    DeviceLimit(
        "P100",
        "https://images.nvidia.com/content/tesla/pdf/nvidia-tesla-p100-datasheet.pdf",
        sm=(6, 0),
        gmem_bandwidth=732 * (1024**3),
        gemm_tflops={
            torch.float64: 5.3,
            torch.float32: 10.6,
            torch.float16: 21.2,
        },
    ),
)
|
|
|
|
|
|
def get_device_limits(device) -> DeviceLimit:
    """Look up the known hardware limits for ``device``.

    Currently only implemented for GPUs: when ``device`` is ``None`` or
    its ``type`` is not ``"cuda"``, a default ``DeviceLimit()`` is returned.

    For a CUDA device, the first ``DEVICE_LIMITS`` entry is returned whose
    ``sm`` equals the device's compute capability and whose ``name`` is a
    substring of the device name. If no entry matches, a default
    ``DeviceLimit()`` is returned.
    """
    if device is None or device.type != "cuda":
        return DeviceLimit()

    capability = torch.cuda.get_device_capability(device)
    full_name = torch.cuda.get_device_name(device)
    for candidate in DEVICE_LIMITS:
        # Both the capability and the name pattern must agree.
        if candidate.sm == capability and candidate.name in full_name:
            return candidate
    return DeviceLimit()
|