[minor] Improve code style and compatibility (#1961)
This commit is contained in:
@@ -23,6 +23,8 @@ import os
|
||||
import pickle
|
||||
import random
|
||||
import resource
|
||||
import shutil
|
||||
import signal
|
||||
import socket
|
||||
import time
|
||||
import warnings
|
||||
@@ -35,6 +37,7 @@ import psutil
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import triton
|
||||
import zmq
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from packaging import version as pkg_version
|
||||
@@ -379,6 +382,10 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
|
||||
if include_self:
|
||||
try:
|
||||
itself.kill()
|
||||
|
||||
# Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
|
||||
# so we send an additional signal to kill them.
|
||||
itself.send_signal(signal.SIGINT)
|
||||
except psutil.NoSuchProcess:
|
||||
pass
|
||||
|
||||
@@ -704,3 +711,44 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
|
||||
raise ValueError(f"Unsupported socket type: {socket_type}")
|
||||
|
||||
return socket
|
||||
|
||||
|
||||
def dump_to_file(dirpath, name, value):
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
if get_tensor_model_parallel_rank() != 0:
|
||||
return
|
||||
|
||||
os.makedirs(dirpath, exist_ok=True)
|
||||
if value.dtype is torch.bfloat16:
|
||||
value = value.float()
|
||||
value = value.cpu().numpy()
|
||||
output_filename = os.path.join(dirpath, f"pytorch_dump_{name}.npy")
|
||||
logger.info(f"Dump a tensor to {output_filename}. Shape = {value.shape}")
|
||||
np.save(output_filename, value)
|
||||
|
||||
|
||||
def is_triton_3():
|
||||
return triton.__version__.startswith("3.")
|
||||
|
||||
|
||||
def maybe_torch_compile(*args, **kwargs):
|
||||
"""
|
||||
torch.compile does not work for triton 2.2.0, which is needed in xlm1's jax.
|
||||
Therefore, we disable it here.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
if is_triton_3():
|
||||
return torch.compile(*args, **kwargs)(func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def delete_directory(dirpath):
|
||||
try:
|
||||
# This will remove the directory and all its contents
|
||||
shutil.rmtree(dirpath)
|
||||
except OSError as e:
|
||||
print(f"Warning: {dirpath} : {e.strerror}")
|
||||
|
||||
Reference in New Issue
Block a user