Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

View File

@@ -0,0 +1,380 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Sequence
import numpy as np
import pytest
import torch
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
# Token-id space used for all randomly generated prompt/output tokens below.
VOCAB_SIZE = 1024
# Upper bound (exclusive) on the random number of output tokens per request.
NUM_OUTPUT_TOKENS = 20
# Upper bound (exclusive) on the random prompt length per request.
MAX_PROMPT_SIZE = 100
# Exercise at most two devices of whatever platform is active.
CUDA_DEVICES = [
    f"{current_platform.device_type}:{i}"
    for i in range(min(current_platform.device_count(), 2))
]
# NOTE(review): not referenced by any test visible in this file — possibly
# kept for parity with other worker tests; confirm before removing.
MAX_NUM_PROMPT_TOKENS = 64
def _compare_objs(obj1, obj2, skip: Sequence = ("logitsprocs", "batch_update_builder")) -> None:
    """Assert that every public, non-callable attribute of ``obj1`` matches
    the corresponding attribute of ``obj2``.

    Comparison is dispatched on the attribute's type: tensors and ndarrays
    via ``allclose``, known vLLM container types recursively, and everything
    else via ``==``.

    Args:
        obj1: First object under comparison.
        obj2: Second object; assumed to expose the same attribute names.
        skip: Attribute names excluded from comparison (defaults to
            logits-processor state, which is tested separately).
    """
    # All non-method attributes of obj1, with dunder names filtered out.
    attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
    attr_names = set(
        [a[0] for a in attrs if not (a[0].startswith("__") and a[0].endswith("__"))]
    )
    for attr_name in attr_names:
        if attr_name in skip:
            continue
        a = getattr(obj1, attr_name)
        b = getattr(obj2, attr_name)
        is_same = False
        if isinstance(a, torch.Tensor):
            # torch.allclose rejects empty tensors, so "both empty" counts
            # as equal and is handled explicitly.
            if a.numel() == 0 or b.numel() == 0:
                is_same = a.numel() == 0 and b.numel() == 0
            elif torch.allclose(a, b):
                is_same = True
        elif isinstance(a, np.ndarray):
            if np.allclose(a, b):
                is_same = True
        elif isinstance(a, MultiGroupBlockTable):
            # Recurse into each per-group block table pairwise.
            for a_i, b_i in zip(a.block_tables, b.block_tables):
                _compare_objs(a_i, b_i)
            is_same = True
        elif isinstance(a, (BlockTable, SamplingMetadata, PoolingMetadata)):
            _compare_objs(a, b)
            is_same = True  # if we make it here must be same
        elif a == b:
            is_same = True
        elif isinstance(a, CpuGpuBuffer):
            # NOTE(review): this branch is only reached when ``a == b`` was
            # False above — presumably CpuGpuBuffer falls through via default
            # identity equality; confirm it does not define __eq__.
            is_same = np.allclose(a.np, b.np) and torch.allclose(a.gpu, b.gpu)
        assert is_same, (
            f"Attribute {attr_name} is different in {obj1} and {obj2}: {a} != {b}"
        )
def _remove_requests(
    input_batch: InputBatch, batch_size: int, reqs: list[CachedRequestState]
) -> set[str]:
    """Randomly evict a subset of requests from ``input_batch``.

    Draws up to ``batch_size - 1`` random slot indices (duplicates collapse
    via the set), removes the corresponding requests from the batch, and
    returns the set of removed request ids.
    """
    num_to_remove = np.random.randint(0, batch_size)
    # Duplicate draws are deduplicated, so fewer than num_to_remove
    # requests may actually be removed.
    indices_to_remove: set[int] = {
        np.random.randint(0, batch_size) for _ in range(num_to_remove)
    }
    removed_req_ids: set[str] = set()
    for idx in indices_to_remove:
        req_id = reqs[idx].req_id
        input_batch.remove_request(req_id)
        removed_req_ids.add(req_id)
    return removed_req_ids
def _construct_expected_sampling_metadata(
    reqs: list[CachedRequestState],
    req_ids_retained: set[str],
    req_id_index_in_input_batch: dict[str, int],
    device: torch.device,
) -> SamplingMetadata:
    """
    Constructs and returns the expected SamplingMetadata for this
    batch.

    Args:
        reqs: All requests originally added to the batch, including ones
            that were later removed.
        req_ids_retained: Request ids that survived removal; any request
            whose id is not in this set is skipped.
        req_id_index_in_input_batch: Maps each retained request id to its
            slot index in the condensed input batch.
        device: Device on which the expected tensors are allocated.
    """
    num_reqs = len(req_ids_retained)
    # Per-slot defaults; each retained request overwrites its own slot below.
    output_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
    prompt_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
    presence_penalties = [0.0 for _ in range(num_reqs)]
    frequency_penalties = [0.0 for _ in range(num_reqs)]
    repetition_penalties = [1.0 for _ in range(num_reqs)]
    top_k = [0 for _ in range(num_reqs)]
    top_p = [0.0 for _ in range(num_reqs)]
    temperature = [0.0 for _ in range(num_reqs)]
    min_tokens = {}
    logit_bias = [None] * num_reqs
    allowed_token_ids_mask = torch.zeros(
        num_reqs, VOCAB_SIZE, dtype=torch.bool, device=device
    )
    bad_words_token_ids = {}
    for req in reqs:
        if req.req_id not in req_ids_retained:
            continue
        # Scatter this request's sampling parameters into its batch slot.
        index_in_input_batch = req_id_index_in_input_batch[req.req_id]
        output_token_ids[index_in_input_batch] = req.output_token_ids
        prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
        presence_penalties[index_in_input_batch] = req.sampling_params.presence_penalty
        frequency_penalties[index_in_input_batch] = (
            req.sampling_params.frequency_penalty
        )
        repetition_penalties[index_in_input_batch] = (
            req.sampling_params.repetition_penalty
        )
        top_k[index_in_input_batch] = req.sampling_params.top_k
        top_p[index_in_input_batch] = req.sampling_params.top_p
        temperature[index_in_input_batch] = req.sampling_params.temperature
        min_tokens[index_in_input_batch] = (
            req.sampling_params.min_tokens,
            req.sampling_params.all_stop_token_ids,
        )
        logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
        if req.sampling_params.allowed_token_ids:
            allowed_token_ids_mask[index_in_input_batch][
                req.sampling_params.allowed_token_ids
            ] = True
        if req.sampling_params.bad_words_token_ids:
            bad_words_token_ids[index_in_input_batch] = (
                req.sampling_params.bad_words_token_ids
            )
    return SamplingMetadata(
        temperature=torch.tensor(temperature, dtype=torch.float, device=device),
        all_greedy=False,
        all_random=True,
        # top_p / top_k collapse to None when no request constrains them,
        # mirroring InputBatch's own behavior.
        top_p=None
        if all(x == 1.0 for x in top_p)
        else torch.tensor(top_p, dtype=torch.float, device=device),
        top_k=None
        if all(x == 0 for x in top_k)
        else torch.tensor(top_k, dtype=torch.int, device=device),
        generators={},
        max_num_logprobs=0,
        # Prompts are padded with VOCAB_SIZE, an id outside the vocab.
        prompt_token_ids=make_tensor_with_pad(
            prompt_token_ids,
            pad=VOCAB_SIZE,
            device=torch.device(device),
            dtype=torch.int64,
        ),
        frequency_penalties=torch.tensor(
            frequency_penalties, dtype=torch.float, device=device
        ),
        presence_penalties=torch.tensor(
            presence_penalties, dtype=torch.float, device=device
        ),
        repetition_penalties=torch.tensor(
            repetition_penalties, dtype=torch.float, device=device
        ),
        output_token_ids=output_token_ids,
        spec_token_ids=[[] for _ in range(len(output_token_ids))],
        # Penalties are disabled only when every request uses the defaults.
        no_penalties=(
            all(x == 0 for x in presence_penalties)
            and all(x == 0 for x in frequency_penalties)
            and all(x == 1 for x in repetition_penalties)
        ),
        allowed_token_ids_mask=allowed_token_ids_mask,
        bad_words_token_ids=bad_words_token_ids,
        logitsprocs=LogitsProcessors(),
    )
def _create_sampling_params():
    """Build a ``SamplingParams`` whose sampling knobs are all randomized.

    Every field exercised by these tests (penalties, top-k/top-p,
    min/stop tokens, and a single-token logit bias) gets a random value.
    """
    kwargs = {
        "top_k": np.random.randint(1, 10),
        "top_p": np.random.uniform(0.0, 1.0),
        "presence_penalty": np.random.uniform(-2.0, 2.0),
        "repetition_penalty": np.random.uniform(0.0, 2.0),
        "frequency_penalty": np.random.uniform(-2.0, 2.0),
        "min_tokens": np.random.randint(1, 10),
    }
    # Random-length (possibly empty) list of random stop-token ids.
    kwargs["stop_token_ids"] = [
        np.random.randint(0, VOCAB_SIZE) for _ in range(np.random.randint(10))
    ]
    # Bias a single fixed token id by a random amount.
    kwargs["logit_bias"] = {0: np.random.uniform(-3.0, 3.0)}
    return SamplingParams(**kwargs)
def _construct_cached_request_state(req_id_suffix: int):
    """Build a ``CachedRequestState`` with random prompt/output token lists.

    The request id is ``req_id_<suffix>``; sampling params are randomized
    via ``_create_sampling_params``.
    """

    def _random_tokens(max_len: int) -> list[int]:
        # Random-length (possibly empty) list of in-vocab token ids.
        return [
            np.random.randint(0, VOCAB_SIZE)
            for _ in range(np.random.randint(0, max_len))
        ]

    prompt = _random_tokens(MAX_PROMPT_SIZE)
    outputs = _random_tokens(NUM_OUTPUT_TOKENS)
    return CachedRequestState(
        req_id=f"req_id_{req_id_suffix}",
        prompt_token_ids=prompt,
        sampling_params=_create_sampling_params(),
        pooling_params=None,
        mm_features=[],
        block_ids=([],),
        generator=None,
        num_computed_tokens=len(outputs),
        output_token_ids=outputs,
    )
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32, 64])
def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
    """
    Tests the logic for managing sampling metadata in the InputBatch.

    This test involves adding a set of requests to the InputBatch,
    followed by removing a subset of them. Afterward, the batch is compacted,
    and the `make_sampling_metadata` method is invoked on the batch. The
    output of `make_sampling_metadata` is then compared against the expected
    results to ensure correctness.

    Note: Ignore logits processor logic, which is tested separately
    """
    input_batch: InputBatch = InputBatch(
        max_num_reqs=batch_size,
        max_model_len=1024,
        max_num_batched_tokens=1024,
        device=torch.device(device),
        pin_memory=is_pin_memory_available(),
        vocab_size=1024,
        block_sizes=[1],
        kernel_block_sizes=[1],
    )
    reqs: list[CachedRequestState] = []
    req_id_reqs = {}
    req_id_output_token_ids = {}
    # Add requests; slots are expected to be assigned in insertion order.
    for req_index in range(batch_size):
        req: CachedRequestState = _construct_cached_request_state(req_index)
        assigned_req_index = input_batch.add_request(req)
        assert req_index == assigned_req_index
        reqs.append(req)
        req_id_reqs[req.req_id] = req
        req_id_output_token_ids[req.req_id] = req.output_token_ids
    # Remove some requests
    req_ids_to_remove = _remove_requests(input_batch, batch_size, reqs)
    req_ids_retained = set(req_id_reqs.keys()) - req_ids_to_remove
    # Compact the input batch so retained requests occupy contiguous slots.
    input_batch.condense()
    # Generate the sampling metadata
    sampling_metadata = input_batch._make_sampling_metadata()
    # Create expected output, indexed by the batch's post-condense layout.
    expected_sampling_metadata = _construct_expected_sampling_metadata(
        reqs, req_ids_retained, input_batch.req_id_to_index, device=torch.device(device)
    )

    def same(t1: torch.Tensor | None, t2: torch.Tensor | None) -> bool:
        # Equality helper for fields that may legitimately be None
        # (top_p/top_k collapse to None when unconstrained).
        return (t1 is None and t2 is None) or (
            t1 is not None and t2 is not None and torch.allclose(t1, t2)
        )

    # Assert the actual and expected output.
    assert torch.allclose(
        expected_sampling_metadata.temperature, sampling_metadata.temperature
    )
    assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p)
    assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k)
    assert torch.allclose(
        expected_sampling_metadata.frequency_penalties,
        sampling_metadata.frequency_penalties,
    )
    assert torch.allclose(
        expected_sampling_metadata.presence_penalties,
        sampling_metadata.presence_penalties,
    )
    assert torch.allclose(
        expected_sampling_metadata.repetition_penalties,
        sampling_metadata.repetition_penalties,
    )
    assert torch.allclose(
        expected_sampling_metadata.prompt_token_ids, sampling_metadata.prompt_token_ids
    )
    assert (
        expected_sampling_metadata.output_token_ids
        == sampling_metadata.output_token_ids
    )
    assert expected_sampling_metadata.no_penalties == sampling_metadata.no_penalties
    # NOTE(review): truthiness of a multi-element tensor raises at runtime;
    # this presumably only works because the mask is None/falsy when no
    # request restricts token ids — consider `is not None`. TODO confirm.
    if sampling_metadata.allowed_token_ids_mask:
        assert torch.allclose(
            expected_sampling_metadata.allowed_token_ids_mask,
            sampling_metadata.allowed_token_ids_mask,
        )
    assert (
        expected_sampling_metadata.bad_words_token_ids
        == sampling_metadata.bad_words_token_ids
    )
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("swap_list", [((0, 1),)])
def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: list):
    """
    Tests the `swap_states` logic of the InputBatch.

    Two batches are populated with the same requests: on the first,
    `swap_states` is applied for each index pair in `swap_list`; the
    reference batch is filled directly in the post-swap order. After
    refreshing metadata on both, every attribute of the two batches must
    compare equal.

    Note: Ignore logits processor logic, which is tested separately
    """
    input_batch: InputBatch = InputBatch(
        max_num_reqs=batch_size,
        max_model_len=1024,
        max_num_batched_tokens=1024,
        device=torch.device(device),
        pin_memory=is_pin_memory_available(),
        vocab_size=1024,
        block_sizes=[1],
        kernel_block_sizes=[1],
    )
    ref_input_batch: InputBatch = InputBatch(
        max_num_reqs=batch_size,
        max_model_len=1024,
        max_num_batched_tokens=1024,
        device=torch.device(device),
        pin_memory=is_pin_memory_available(),
        vocab_size=1024,
        block_sizes=[1],
        kernel_block_sizes=[1],
    )
    reqs: list[CachedRequestState] = []
    req_id_reqs = {}
    req_id_output_token_ids = {}
    # Add requests to the batch under test.
    for req_index in range(batch_size):
        req: CachedRequestState = _construct_cached_request_state(req_index)
        assigned_req_index = input_batch.add_request(req)
        assert assigned_req_index == req_index
        reqs.append(req)
        req_id_reqs[req.req_id] = req
        req_id_output_token_ids[req.req_id] = req.output_token_ids
    # Apply each swap to the batch while mirroring it on a plain Python list.
    reordered_reqs = reqs.copy()
    for swap_pair in swap_list:
        reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = (
            reordered_reqs[swap_pair[1]],
            reordered_reqs[swap_pair[0]],
        )
        input_batch.swap_states(swap_pair[0], swap_pair[1])
    # Build the reference batch directly in the post-swap order.
    for req_index in range(batch_size):
        req = reordered_reqs[req_index]
        assigned_req_index = ref_input_batch.add_request(req)
        assert assigned_req_index == req_index
    input_batch.refresh_metadata()
    ref_input_batch.refresh_metadata()
    # Attribute-by-attribute comparison of the two batches.
    _compare_objs(input_batch, ref_input_batch)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,204 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.config import ProfilerConfig
from vllm.profiler.wrapper import WorkerProfiler
class ConcreteWorkerProfiler(WorkerProfiler):
    """Minimal WorkerProfiler subclass for testing.

    Records how many times the backend start/stop hooks run, and can be
    switched into a mode where starting the backend fails.
    """

    def __init__(self, profiler_config: ProfilerConfig):
        # Counters must exist before the base-class __init__ runs, since
        # the base class may drive the profiler lifecycle during init.
        self.start_call_count = 0
        self.stop_call_count = 0
        self.should_fail_start = False
        super().__init__(profiler_config)

    def _start(self) -> None:
        if not self.should_fail_start:
            self.start_call_count += 1
            return
        raise RuntimeError("Simulated start failure")

    def _stop(self) -> None:
        self.stop_call_count += 1
@pytest.fixture
def default_profiler_config():
    """Torch-profiler config with no start delay and no iteration cap."""
    config_kwargs = dict(
        profiler="torch",
        torch_profiler_dir="/tmp/mock",
        delay_iterations=0,
        max_iterations=0,
    )
    return ProfilerConfig(**config_kwargs)
def test_immediate_start_stop(default_profiler_config):
    """With zero delay, start()/stop() take effect immediately."""
    profiler = ConcreteWorkerProfiler(default_profiler_config)

    profiler.start()
    # Backend started on the spot: request active AND running.
    assert profiler._running is True
    assert profiler._active is True
    assert profiler.start_call_count == 1

    profiler.stop()
    # Fully torn down again.
    assert profiler._running is False
    assert profiler._active is False
    assert profiler.stop_call_count == 1
def test_delayed_start(default_profiler_config):
    """A delay of N iterations defers the backend start by N step() calls."""
    default_profiler_config.delay_iterations = 2
    profiler = ConcreteWorkerProfiler(default_profiler_config)

    profiler.start()
    # The request was accepted (_active) but the backend hasn't started.
    assert profiler._active is True
    assert profiler._running is False
    assert profiler.start_call_count == 0

    profiler.step()  # first delay iteration — still waiting
    assert profiler._running is False

    profiler.step()  # second delay iteration — threshold reached
    assert profiler._running is True
    assert profiler.start_call_count == 1
def test_max_iterations(default_profiler_config):
    """The profiler stops itself once max_iterations is exceeded."""
    default_profiler_config.max_iterations = 2
    profiler = ConcreteWorkerProfiler(default_profiler_config)

    profiler.start()
    assert profiler._running is True

    # Iterations 1 and 2 stay within the cap.
    for _ in range(2):
        profiler.step()
        assert profiler._running is True

    # Iteration 3 exceeds the cap and triggers the auto-stop.
    profiler.step()
    assert profiler._running is False
    assert profiler.stop_call_count == 1
def test_delayed_start_and_max_iters(default_profiler_config):
    """Delayed start composes correctly with the max-iteration auto-stop."""
    default_profiler_config.delay_iterations = 2
    default_profiler_config.max_iterations = 2
    profiler = ConcreteWorkerProfiler(default_profiler_config)

    profiler.start()

    profiler.step()  # still inside the delay window
    assert profiler._running is False
    assert profiler._active is True

    profiler.step()  # delay satisfied: backend starts, first profiled iter
    assert profiler._profiling_for_iters == 1
    assert profiler._running is True
    assert profiler._active is True

    profiler.step()  # second profiled iteration, exactly at the cap
    assert profiler._profiling_for_iters == 2
    assert profiler._running is True

    profiler.step()  # exceeds the cap -> auto-stop
    assert profiler._running is False
    assert profiler.stop_call_count == 1
def test_idempotency(default_profiler_config):
    """Repeated start()/stop() calls are no-ops after the first."""
    profiler = ConcreteWorkerProfiler(default_profiler_config)

    # A second start while already running must not restart the backend.
    profiler.start()
    profiler.start()
    assert profiler.start_call_count == 1

    # Likewise, a second stop while already stopped must not re-stop.
    profiler.stop()
    profiler.stop()
    assert profiler.stop_call_count == 1
def test_step_inactive(default_profiler_config):
    """step() is inert while no start request is active."""
    default_profiler_config.delay_iterations = 2
    profiler = ConcreteWorkerProfiler(default_profiler_config)

    # Stepping past the would-be delay threshold without calling start()
    # must not trigger the backend.
    for _ in range(2):
        profiler.step()
    assert profiler.start_call_count == 0
def test_start_failure(default_profiler_config):
    """A failing backend start leaves the request active but not running."""
    profiler = ConcreteWorkerProfiler(default_profiler_config)
    profiler.should_fail_start = True

    profiler.start()  # _start raises; the wrapper swallows the exception

    assert profiler._running is False  # backend never came up
    assert profiler._active is True  # the user's request remains pending
    assert profiler.start_call_count == 0  # counter increments only on success
def test_shutdown(default_profiler_config):
    """shutdown() invokes stop only when the profiler is running."""
    profiler = ConcreteWorkerProfiler(default_profiler_config)

    # Nothing running yet: shutdown must not call the stop hook.
    profiler.shutdown()
    assert profiler.stop_call_count == 0

    # Running: shutdown stops exactly once.
    profiler.start()
    profiler.shutdown()
    assert profiler.stop_call_count == 1
def test_mixed_delay_and_stop(default_profiler_config):
    """stop() during the delay window cancels the pending start."""
    default_profiler_config.delay_iterations = 5
    profiler = ConcreteWorkerProfiler(default_profiler_config)

    profiler.start()
    profiler.step()
    profiler.step()

    # User cancels while the delay is still counting down.
    profiler.stop()
    assert profiler._active is False

    # The cancelled request must never start, no matter how much we step.
    for _ in range(3):
        profiler.step()
    assert profiler.start_call_count == 0

View File

@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.v1.worker.utils import bind_kv_cache
def test_bind_kv_cache():
    """bind_kv_cache wires each layer's cache tensor into its Attention
    module and appends it to runner_kv_caches in layer order."""
    from vllm.attention.layer import Attention

    layer_names = [f"layers.{i}.self_attn" for i in range(4)]
    ctx = {name: Attention(32, 128, 0.1) for name in layer_names}
    kv_cache = {name: torch.zeros((1,)) for name in layer_names}
    runner_kv_caches: list[torch.Tensor] = []

    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    for i, name in enumerate(layer_names):
        # Identity (not equality): the very same tensor must be bound.
        assert ctx[name].kv_cache[0] is kv_cache[name]
        assert runner_kv_caches[i] is kv_cache[name]
def test_bind_kv_cache_non_attention():
    """Binding works when attention layers are sparse/non-contiguous."""
    from vllm.attention.layer import Attention

    # example from Jamba PP=2: only two layers carry attention.
    layer_names = ["model.layers.20.attn", "model.layers.28.attn"]
    ctx = {name: Attention(32, 128, 0.1) for name in layer_names}
    kv_cache = {name: torch.zeros((1,)) for name in layer_names}
    runner_kv_caches: list[torch.Tensor] = []

    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    for i, name in enumerate(layer_names):
        # Identity (not equality): the very same tensor must be bound.
        assert ctx[name].kv_cache[0] is kv_cache[name]
        assert runner_kv_caches[i] is kv_cache[name]

View File

@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing as mp
import os
import tempfile
from multiprocessing.queues import Queue
from unittest.mock import patch
import pytest
import torch
from vllm.engine.arg_utils import EngineArgs
from vllm.utils.mem_utils import MemorySnapshot
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
# Global queue to track operation order across processes.
# Installed per-process by worker_process(); read by track_operation().
_QUEUE: Queue | None = None
def track_operation(operation: str, rank: int):
    """Record an (operation, rank) event on the shared queue, if installed."""
    queue = _QUEUE
    if queue is None:
        # No queue installed in this process: tracking is a no-op.
        return
    queue.put((operation, rank))
def make_operation_tracker(operation_name: str, original_func):
    """Wrap ``original_func`` so each invocation is recorded first.

    Args:
        operation_name: Label recorded for every invocation.
        original_func: Callable to delegate to after recording.

    Returns:
        A function that records (operation_name, RANK) via track_operation
        and then forwards all arguments to ``original_func``.
    """

    def tracked(*args, **kwargs):
        # RANK is set by worker_process(); -1 means "not in a worker".
        current_rank = int(os.environ.get("RANK", "-1"))
        track_operation(operation_name, current_rank)
        return original_func(*args, **kwargs)

    return tracked
def worker_process(
    rank: int,
    world_size: int,
    distributed_init_method: str,
    queue: Queue,
    error_queue: Queue,
):
    """Worker process that initializes a GPU worker with proper tracking.

    Patches three operations (distributed init, memory snapshot, NCCL
    all-reduce) so that their invocation order can be asserted by the parent
    process. Success is signaled on ``queue``; any exception is reported on
    ``error_queue`` and re-raised.
    """
    global _QUEUE
    # Install the shared queue so track_operation() works in this process.
    _QUEUE = queue
    try:
        # Set environment variables consumed by the tracker and by torch
        # distributed initialization.
        os.environ["RANK"] = str(rank)
        os.environ["LOCAL_RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        # Create vLLM config with small model
        vllm_config = EngineArgs(
            model="facebook/opt-125m", tensor_parallel_size=2, load_format="dummy"
        ).create_engine_config()
        # Create worker
        worker = Worker(
            vllm_config=vllm_config,
            local_rank=rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
        )
        # Get original functions before patching, so the wrappers can still
        # delegate to the real implementations.
        original_init_worker = init_worker_distributed_environment
        original_memory_snapshot_init = MemorySnapshot.__init__
        original_all_reduce = torch.distributed.all_reduce
        # Apply minimal patches to track operation order
        init_patch = patch(
            "vllm.v1.worker.gpu_worker.init_worker_distributed_environment",
            side_effect=make_operation_tracker(
                "init_distributed", original_init_worker
            ),
        )
        memory_patch = patch.object(
            MemorySnapshot,
            "__init__",
            make_operation_tracker("memory_snapshot", original_memory_snapshot_init),
        )
        all_reduce_patch = patch(
            "torch.distributed.all_reduce",
            side_effect=make_operation_tracker("nccl_all_reduce", original_all_reduce),
        )
        with init_patch, memory_patch, all_reduce_patch:
            # Initialize device (this is where we test the order)
            worker.init_device()
            # Load model to ensure everything works
            worker.load_model()
        # Signal success
        queue.put(("success", rank))
    except Exception as e:
        # Report the failure to the parent, then re-raise so the process
        # exits with a non-zero status.
        error_queue.put((rank, str(e), type(e).__name__))
        raise
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="Need at least 2 GPUs for tensor parallelism"
)
def test_init_distributed_is_called_before_memory_snapshot():
    """Test that distributed env is setup before memory snapshot.

    This test makes sure during worker initialization, the initial memory
    snapshot is taken after distributed env is setup to include all the buffers
    allocated by distributed env.
    """
    world_size = 2
    # Create a temporary file for distributed init (file:// rendezvous).
    with tempfile.NamedTemporaryFile(delete=False) as f:
        distributed_init_method = f"file://{f.name}"
    # Create queues for inter-process communication
    ctx = mp.get_context("spawn")
    operation_queue = ctx.Queue()
    error_queue = ctx.Queue()
    # Start worker processes
    processes = []
    for rank in range(world_size):
        p = ctx.Process(
            target=worker_process,
            args=(
                rank,
                world_size,
                distributed_init_method,
                operation_queue,
                error_queue,
            ),
        )
        p.start()
        processes.append(p)
    # Wait for all processes to complete.
    # NOTE(review): join(timeout=...) does not kill a hung process or check
    # its exit code; a timeout here would only surface later as missing
    # queue entries — consider asserting p.exitcode == 0.
    for p in processes:
        p.join(timeout=60)  # 60 second timeout
    # Check for errors reported by the workers.
    errors = []
    while not error_queue.empty():
        rank, error_msg, error_type = error_queue.get()
        errors.append(f"Rank {rank}: {error_type}: {error_msg}")
    if errors:
        pytest.fail("Worker processes failed:\n" + "\n".join(errors))
    # Collect all operations from the queue
    operations = []
    while not operation_queue.empty():
        operations.append(operation_queue.get())
    # Verify we got operations from both ranks
    print(f"Collected operations: {operations}")
    # Check operations for each rank
    for rank in range(world_size):
        rank_ops = [op for op, r in operations if r == rank]
        print(f"\nRank {rank} operations: {rank_ops}")
        # Raises ValueError if the operation is not found
        init_distributed = rank_ops.index("init_distributed")
        nccl_all_reduce = rank_ops.index("nccl_all_reduce")
        memory_snapshot = rank_ops.index("memory_snapshot")
        # Verify order: init_distributed should happen before memory_snapshot
        assert init_distributed < nccl_all_reduce < memory_snapshot, (
            f"Rank {rank}: init_distributed (index {init_distributed}) "
            f"must happen before nccl_all_reduce (index {nccl_all_reduce}) "
            f"and memory_snapshot (index {memory_snapshot})"
        )
    # Clean up the rendezvous file.
    # NOTE(review): this is skipped on any earlier failure — a
    # try/finally (or addfinalizer) would avoid leaking the temp file.
    os.unlink(distributed_init_method.replace("file://", ""))