Optimize broadcast & Reorg code (#1598)

This commit is contained in:
Lianmin Zheng
2024-10-07 13:05:53 -07:00
parent 3ff641132e
commit ebbc42d989
3 changed files with 55 additions and 47 deletions

View File

@@ -24,6 +24,7 @@ import random
import resource
import socket
import time
import warnings
from importlib.metadata import PackageNotFoundError, version
from io import BytesIO
from typing import Any, Dict, List, Optional, Union
@@ -333,6 +334,10 @@ def suppress_other_loggers():
logging.getLogger("vllm.selector").setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.ERROR)
warnings.filterwarnings(
"ignore", category=UserWarning, message="The given NumPy array is not writable"
)
def assert_pkg_version(pkg: str, min_version: str, message: str):
try:
@@ -615,7 +620,9 @@ def broadcast_pyobj(
else:
serialized_data = pickle.dumps(data)
size = len(serialized_data)
tensor_data = torch.ByteTensor(list(serialized_data))
tensor_data = torch.ByteTensor(
np.frombuffer(serialized_data, dtype=np.uint8)
)
tensor_size = torch.tensor([size], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)