Add device support (#1607)

This commit is contained in:
Zhang, Liangang
2024-10-11 17:05:58 +08:00
committed by GitHub
parent 5476ccad8f
commit 8275049ce3
5 changed files with 96 additions and 52 deletions

View File

@@ -140,26 +140,41 @@ def calculate_time(show=False, min_cost_ms=0.0):
return wrapper
def get_available_gpu_memory(gpu_id, distributed=False):
def get_available_gpu_memory(device, gpu_id, distributed=False):
"""
Get available memory for cuda:gpu_id device.
When distributed is True, the available memory is the minimum available memory of all GPUs.
"""
num_gpus = torch.cuda.device_count()
assert gpu_id < num_gpus
if device == "cuda":
num_gpus = torch.cuda.device_count()
assert gpu_id < num_gpus
if torch.cuda.current_device() != gpu_id:
print(
f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
"which may cause useless memory allocation for torch CUDA context.",
)
if torch.cuda.current_device() != gpu_id:
print(
f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
"which may cause useless memory allocation for torch CUDA context.",
)
torch.cuda.empty_cache()
free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
torch.cuda.empty_cache()
free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
elif device == "xpu":
num_gpus = torch.xpu.device_count()
assert gpu_id < num_gpus
if torch.xpu.current_device() != gpu_id:
print(
f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ",
"which may cause useless memory allocation for torch XPU context.",
)
torch.xpu.empty_cache()
used_memory = torch.xpu.memory_allocated()
total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
free_gpu_memory = total_gpu_memory - used_memory
if distributed:
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
torch.device("cuda", gpu_id)
torch.device(device, gpu_id)
)
torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
free_gpu_memory = tensor.item()