Optimize broadcast & Reorg code (#1598)
This commit is contained in:
@@ -24,6 +24,7 @@ import random
|
||||
import resource
|
||||
import socket
|
||||
import time
|
||||
import warnings
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
@@ -333,6 +334,10 @@ def suppress_other_loggers():
|
||||
logging.getLogger("vllm.selector").setLevel(logging.WARN)
|
||||
logging.getLogger("vllm.utils").setLevel(logging.ERROR)
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore", category=UserWarning, message="The given NumPy array is not writable"
|
||||
)
|
||||
|
||||
|
||||
def assert_pkg_version(pkg: str, min_version: str, message: str):
|
||||
try:
|
||||
@@ -615,7 +620,9 @@ def broadcast_pyobj(
|
||||
else:
|
||||
serialized_data = pickle.dumps(data)
|
||||
size = len(serialized_data)
|
||||
tensor_data = torch.ByteTensor(list(serialized_data))
|
||||
tensor_data = torch.ByteTensor(
|
||||
np.frombuffer(serialized_data, dtype=np.uint8)
|
||||
)
|
||||
tensor_size = torch.tensor([size], dtype=torch.long)
|
||||
|
||||
dist.broadcast(tensor_size, src=0, group=dist_group)
|
||||
|
||||
Reference in New Issue
Block a user