Improve error handling (#433)

This commit is contained in:
Lianmin Zheng
2024-05-12 20:49:04 -07:00
committed by GitHub
parent 04c0b21488
commit 562b8857d8
6 changed files with 92 additions and 41 deletions

View File

@@ -2,7 +2,9 @@
import base64
import json
import sys
import threading
import traceback
import urllib.request
from io import BytesIO
from json import dumps
@@ -10,32 +12,10 @@ from json import dumps
import requests
def get_available_gpu_memory(gpu_id, distributed=True):
"""
Get available memory for cuda:gpu_id device.
When distributed is True, the available memory is the minimum available memory of all GPUs.
"""
import torch
num_gpus = torch.cuda.device_count()
assert gpu_id < num_gpus
if torch.cuda.current_device() != gpu_id:
print(
f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
"which may cause useless memory allocation for torch CUDA context.",
)
free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
if distributed:
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
torch.device("cuda", gpu_id)
)
torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
free_gpu_memory = tensor.item()
return free_gpu_memory / (1 << 30)
def get_exception_traceback():
etype, value, tb = sys.exc_info()
err_str = "".join(traceback.format_exception(etype, value, tb))
return err_str
def is_same_type(values):
@@ -190,4 +170,4 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
if not ret_value:
raise RuntimeError()
return ret_value[0]
return ret_value[0]