[minor] Improve code style and compatibility (#1961)

This commit is contained in:
Lianmin Zheng
2024-11-08 02:19:41 -08:00
committed by GitHub
parent 7ef0084b0d
commit a509552087
6 changed files with 109 additions and 35 deletions

View File

@@ -23,6 +23,8 @@ import os
import pickle
import random
import resource
import shutil
import signal
import socket
import time
import warnings
@@ -35,6 +37,7 @@ import psutil
import requests
import torch
import torch.distributed as dist
import triton
import zmq
from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version
@@ -379,6 +382,10 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
if include_self:
try:
itself.kill()
# Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
# so we send an additional signal to kill them.
itself.send_signal(signal.SIGINT)
except psutil.NoSuchProcess:
pass
@@ -704,3 +711,44 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
raise ValueError(f"Unsupported socket type: {socket_type}")
return socket
def dump_to_file(dirpath, name, value):
from vllm.distributed import get_tensor_model_parallel_rank
if get_tensor_model_parallel_rank() != 0:
return
os.makedirs(dirpath, exist_ok=True)
if value.dtype is torch.bfloat16:
value = value.float()
value = value.cpu().numpy()
output_filename = os.path.join(dirpath, f"pytorch_dump_{name}.npy")
logger.info(f"Dump a tensor to {output_filename}. Shape = {value.shape}")
np.save(output_filename, value)
def is_triton_3():
return triton.__version__.startswith("3.")
def maybe_torch_compile(*args, **kwargs):
"""
torch.compile does not work for triton 2.2.0, which is needed in xlm1's jax.
Therefore, we disable it here.
"""
def decorator(func):
if is_triton_3():
return torch.compile(*args, **kwargs)(func)
return func
return decorator
def delete_directory(dirpath):
try:
# This will remove the directory and all its contents
shutil.rmtree(dirpath)
except OSError as e:
print(f"Warning: {dirpath} : {e.strerror}")