@@ -1,81 +0,0 @@
|
|||||||
import datetime
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from torch.distributed.device_mesh import init_device_mesh
|
|
||||||
|
|
||||||
from sglang.srt.entrypoints.verl_engine import VerlEngine
|
|
||||||
|
|
||||||
|
|
||||||
def run():
|
|
||||||
"""
|
|
||||||
Example command:
|
|
||||||
```
|
|
||||||
torchrun --nproc_per_node=8 offline_batch_inference_torchrun.py
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
local_rank = int(os.environ["LOCAL_RANK"])
|
|
||||||
rank = int(os.environ["RANK"])
|
|
||||||
world_size = int(os.environ["WORLD_SIZE"])
|
|
||||||
|
|
||||||
def _log(text):
|
|
||||||
t = datetime.datetime.now().strftime("%H:%M:%S")
|
|
||||||
print(f"[{t}] [rank={rank}] {text}")
|
|
||||||
|
|
||||||
_log(
|
|
||||||
f'start {local_rank=} {rank=} {world_size=} {sys.argv=} {os.environ.get("CUDA_VISIBLE_DEVICES")}'
|
|
||||||
)
|
|
||||||
|
|
||||||
tp_size = 4
|
|
||||||
dp_size = 2
|
|
||||||
assert world_size == tp_size * dp_size
|
|
||||||
|
|
||||||
device_mesh_kwargs = dict(
|
|
||||||
mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"]
|
|
||||||
)
|
|
||||||
device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)
|
|
||||||
_log(f"{device_mesh_cpu=}")
|
|
||||||
|
|
||||||
tp_rank = device_mesh_cpu.get_local_rank("tp")
|
|
||||||
dp_rank = device_mesh_cpu.get_local_rank("dp")
|
|
||||||
_log(f"{tp_rank=} {tp_size=} ; {dp_rank=} {dp_size=}")
|
|
||||||
|
|
||||||
model_name, mem_fraction_static = "meta-llama/Llama-3.2-1B-Instruct", 0.1
|
|
||||||
# model_name, mem_fraction_static = "meta-llama/Llama-3.1-70B-Instruct", 0.9 # test large models
|
|
||||||
# model_name, mem_fraction_static = "deepseek-ai/DeepSeek-V2-Lite", 0.8
|
|
||||||
|
|
||||||
for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
|
|
||||||
if k in os.environ:
|
|
||||||
del os.environ[k]
|
|
||||||
|
|
||||||
fragment = VerlEngine(
|
|
||||||
model_path=model_name,
|
|
||||||
mem_fraction_static=mem_fraction_static,
|
|
||||||
device_mesh_cpu=device_mesh_cpu["tp"],
|
|
||||||
base_gpu_id=dp_rank,
|
|
||||||
gpu_id_step=dp_size,
|
|
||||||
port=30000,
|
|
||||||
# for DeepSeek-V2-Lite + DP Attention
|
|
||||||
# enable_dp_attention=True, port=30000 + dp_rank * 100,
|
|
||||||
)
|
|
||||||
_log(f"{fragment=}")
|
|
||||||
|
|
||||||
prompt_all = [
|
|
||||||
["1+1=2, 1+2=3, 1+3=4, 1+4=", "9-1=8, 8-1=7, 7-1="],
|
|
||||||
["2*1=2, 2*2=4, 2*3=", "8/2=4, 6/2="],
|
|
||||||
]
|
|
||||||
prompt = prompt_all[dp_rank]
|
|
||||||
|
|
||||||
output = fragment.generate(
|
|
||||||
prompt=prompt,
|
|
||||||
sampling_params=dict(max_new_tokens=16, temperature=0.0),
|
|
||||||
)
|
|
||||||
_log(f"{prompt=} {output=}")
|
|
||||||
|
|
||||||
fragment.shutdown()
|
|
||||||
_log(f"End script")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
run()
|
|
||||||
@@ -64,11 +64,9 @@ class HttpServerEngineAdapter(EngineBase):
|
|||||||
|
|
||||||
def _make_request(self, endpoint: str, payload: Optional[dict] = None):
|
def _make_request(self, endpoint: str, payload: Optional[dict] = None):
|
||||||
"""Make a POST request to the specified endpoint with the given payload.
|
"""Make a POST request to the specified endpoint with the given payload.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
endpoint: The API endpoint to call
|
endpoint: The API endpoint to call
|
||||||
payload: The JSON payload to send (default: empty dict)
|
payload: The JSON payload to send (default: empty dict)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The JSON response from the server
|
The JSON response from the server
|
||||||
"""
|
"""
|
||||||
@@ -85,7 +83,6 @@ class HttpServerEngineAdapter(EngineBase):
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Update model weights from tensor data. The HTTP server will only post meta data, and the real weights will be copied directly from GPUs.
|
Update model weights from tensor data. The HTTP server will only post meta data, and the real weights will be copied directly from GPUs.
|
||||||
|
|
||||||
Note: The model should be on GPUs rather than CPU for this functionality to work properly.
|
Note: The model should be on GPUs rather than CPU for this functionality to work properly.
|
||||||
If you encounter issues, ensure your model is loaded on GPU devices rather than CPU.
|
If you encounter issues, ensure your model is loaded on GPU devices rather than CPU.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,179 +0,0 @@
|
|||||||
# Copyright 2023-2024 SGLang Team
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
import os
|
|
||||||
from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from PIL.Image import Image
|
|
||||||
from torch.distributed.tensor import DeviceMesh, DTensor
|
|
||||||
|
|
||||||
from sglang.srt.entrypoints.engine import Engine
|
|
||||||
from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
|
|
||||||
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
|
|
||||||
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
|
||||||
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
|
|
||||||
|
|
||||||
|
|
||||||
class VerlEngine:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
device_mesh_cpu: DeviceMesh,
|
|
||||||
nnodes: int = 1,
|
|
||||||
backend: Literal["engine", "server"] = "engine",
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
monkey_patch_torch_reductions()
|
|
||||||
self._device_mesh_cpu = device_mesh_cpu
|
|
||||||
self._tp_rank = device_mesh_cpu.get_local_rank()
|
|
||||||
self._rank = device_mesh_cpu.get_rank()
|
|
||||||
self._tp_size = device_mesh_cpu.size()
|
|
||||||
tp_size_per_node = self._tp_size // nnodes
|
|
||||||
node_rank = self._tp_rank // tp_size_per_node
|
|
||||||
first_rank_in_node = self._tp_rank % tp_size_per_node == 0
|
|
||||||
|
|
||||||
# Common engine keyword arguments
|
|
||||||
engine_kwargs = dict(
|
|
||||||
**kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
|
|
||||||
)
|
|
||||||
|
|
||||||
if backend == "engine":
|
|
||||||
if first_rank_in_node:
|
|
||||||
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
|
|
||||||
self._engine = Engine(**engine_kwargs)
|
|
||||||
else:
|
|
||||||
self._engine = None
|
|
||||||
|
|
||||||
elif backend == "server":
|
|
||||||
if self._tp_rank == 0:
|
|
||||||
self._engine = HttpServerEngineAdapter(**engine_kwargs)
|
|
||||||
else:
|
|
||||||
self._engine = None
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported backend: {backend}")
|
|
||||||
|
|
||||||
dist.barrier(group=self._device_mesh_cpu.get_group())
|
|
||||||
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
|
||||||
prompt: Optional[Union[List[str], str]] = None,
|
|
||||||
sampling_params: Optional[Union[List[Dict], Dict]] = None,
|
|
||||||
# The token ids for text; one can either specify text or input_ids.
|
|
||||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
|
||||||
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
|
|
||||||
# Can be formatted as:
|
|
||||||
# - Single image for a single request
|
|
||||||
# - List of images (one per request in a batch)
|
|
||||||
# - List of lists of images (multiple images per request)
|
|
||||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
|
||||||
image_data: Optional[
|
|
||||||
Union[
|
|
||||||
List[List[Union[Image, str]]],
|
|
||||||
List[Union[Image, str]],
|
|
||||||
Union[Image, str],
|
|
||||||
]
|
|
||||||
] = None,
|
|
||||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
|
||||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
|
||||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
|
||||||
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
|
|
||||||
lora_path: Optional[List[Optional[str]]] = None,
|
|
||||||
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
|
||||||
) -> Dict:
|
|
||||||
"""
|
|
||||||
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
|
|
||||||
Please refer to `GenerateReqInput` for the documentation.
|
|
||||||
"""
|
|
||||||
if self._tp_rank == 0:
|
|
||||||
output = self._engine.generate(
|
|
||||||
prompt=prompt,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
input_ids=input_ids,
|
|
||||||
image_data=image_data,
|
|
||||||
return_logprob=return_logprob,
|
|
||||||
logprob_start_len=logprob_start_len,
|
|
||||||
top_logprobs_num=top_logprobs_num,
|
|
||||||
token_ids_logprob=token_ids_logprob,
|
|
||||||
lora_path=lora_path,
|
|
||||||
custom_logit_processor=custom_logit_processor,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
output = None
|
|
||||||
|
|
||||||
# Most naive implementation, can extract tensor and send via gloo if too slow
|
|
||||||
[output] = broadcast_pyobj(
|
|
||||||
data=[output],
|
|
||||||
rank=self._rank,
|
|
||||||
dist_group=self._device_mesh_cpu.get_group(),
|
|
||||||
src=self._device_mesh_cpu.mesh[0].item(),
|
|
||||||
force_cpu_device=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
def update_weights_from_tensor(
|
|
||||||
self,
|
|
||||||
named_tensors: Iterable[Tuple[str, torch.Tensor]],
|
|
||||||
load_format: Optional[str] = None,
|
|
||||||
):
|
|
||||||
# Most naive implementation, can optimize a lot if it is bottleneck
|
|
||||||
for tensor_index, (name, tensor) in enumerate(named_tensors):
|
|
||||||
serialized_tensor = MultiprocessingSerializer.serialize(
|
|
||||||
_preprocess_tensor_for_update_weights(tensor)
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._tp_rank == 0:
|
|
||||||
gathered_serialized_tensors = [None for _ in range(self._tp_size)]
|
|
||||||
else:
|
|
||||||
gathered_serialized_tensors = None
|
|
||||||
dist.gather_object(
|
|
||||||
obj=serialized_tensor,
|
|
||||||
object_gather_list=gathered_serialized_tensors,
|
|
||||||
dst=self._device_mesh_cpu.mesh.tolist()[0],
|
|
||||||
group=self._device_mesh_cpu.get_group(),
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._tp_rank == 0:
|
|
||||||
self._engine.update_weights_from_tensor(
|
|
||||||
named_tensors=[
|
|
||||||
(
|
|
||||||
name,
|
|
||||||
LocalSerializedTensor(values=gathered_serialized_tensors),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
load_format=load_format,
|
|
||||||
flush_cache=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._tp_rank == 0:
|
|
||||||
self._engine.flush_cache()
|
|
||||||
|
|
||||||
def release_memory_occupation(self):
|
|
||||||
if self._tp_rank == 0:
|
|
||||||
self._engine.release_memory_occupation()
|
|
||||||
|
|
||||||
def resume_memory_occupation(self):
|
|
||||||
if self._tp_rank == 0:
|
|
||||||
self._engine.resume_memory_occupation()
|
|
||||||
|
|
||||||
def shutdown(self):
|
|
||||||
if self._engine is not None:
|
|
||||||
self._engine.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
def _preprocess_tensor_for_update_weights(tensor: torch.Tensor):
|
|
||||||
if isinstance(tensor, DTensor):
|
|
||||||
return tensor.full_tensor()
|
|
||||||
return tensor
|
|
||||||
@@ -144,7 +144,6 @@ suites = {
|
|||||||
TestFile("test_moe_ep.py", 181),
|
TestFile("test_moe_ep.py", 181),
|
||||||
TestFile("test_patch_torch.py", 19),
|
TestFile("test_patch_torch.py", 19),
|
||||||
TestFile("test_update_weights_from_distributed.py", 103),
|
TestFile("test_update_weights_from_distributed.py", 103),
|
||||||
TestFile("test_verl_engine_2_gpu.py", 64),
|
|
||||||
TestFile("test_release_memory_occupation.py", 44),
|
TestFile("test_release_memory_occupation.py", 44),
|
||||||
],
|
],
|
||||||
"per-commit-2-gpu-amd": [
|
"per-commit-2-gpu-amd": [
|
||||||
@@ -157,7 +156,6 @@ suites = {
|
|||||||
"per-commit-4-gpu": [
|
"per-commit-4-gpu": [
|
||||||
TestFile("test_local_attn.py", 250),
|
TestFile("test_local_attn.py", 250),
|
||||||
TestFile("test_pp_single_node.py", 150),
|
TestFile("test_pp_single_node.py", 150),
|
||||||
TestFile("test_verl_engine_4_gpu.py", 64),
|
|
||||||
],
|
],
|
||||||
"per-commit-4-gpu-amd": [
|
"per-commit-4-gpu-amd": [
|
||||||
TestFile("test_pp_single_node.py", 150),
|
TestFile("test_pp_single_node.py", 150),
|
||||||
|
|||||||
@@ -1,415 +0,0 @@
|
|||||||
import multiprocessing
|
|
||||||
import multiprocessing as mp
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
import unittest
|
|
||||||
from multiprocessing import Process
|
|
||||||
|
|
||||||
import requests
|
|
||||||
import torch
|
|
||||||
from openai import OpenAI
|
|
||||||
from torch.distributed.device_mesh import init_device_mesh
|
|
||||||
from torch.distributed.fsdp import CPUOffload
|
|
||||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
||||||
from torch.distributed.fsdp import MixedPrecision
|
|
||||||
from torch.distributed.fsdp.api import (
|
|
||||||
ShardedStateDictConfig,
|
|
||||||
ShardingStrategy,
|
|
||||||
StateDictType,
|
|
||||||
)
|
|
||||||
from transformers import AutoModelForCausalLM
|
|
||||||
|
|
||||||
from sglang.srt.entrypoints.verl_engine import VerlEngine
|
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
|
||||||
from sglang.srt.server_args import ServerArgs
|
|
||||||
from sglang.srt.utils import is_port_available
|
|
||||||
from sglang.test.runners import (
|
|
||||||
HFRunner,
|
|
||||||
SRTRunner,
|
|
||||||
check_close_model_outputs,
|
|
||||||
get_dtype_str,
|
|
||||||
)
|
|
||||||
from sglang.test.test_utils import CustomTestCase, find_available_port, is_in_ci
|
|
||||||
|
|
||||||
_MAX_NEW_TOKENS = 8
|
|
||||||
_PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="]
|
|
||||||
_TORCH_DTYPE = torch.float16
|
|
||||||
|
|
||||||
# Set to false to temporarily debug issues unrelated to weight update
|
|
||||||
_ENABLE_UPDATE_WEIGHTS = True
|
|
||||||
|
|
||||||
CI_MODELS = [
|
|
||||||
dict(model_path="meta-llama/Llama-3.1-8B-Instruct"),
|
|
||||||
# Fail to run gemma-2-2b after transformers==4.48.3 -> 4.50.0
|
|
||||||
# dict(model_path="google/gemma-2-2b"),
|
|
||||||
]
|
|
||||||
ALL_OTHER_MODELS = [
|
|
||||||
dict(model_path="meta-llama/Llama-3.2-1B-Instruct", tp_size=1),
|
|
||||||
dict(model_path="Qwen/Qwen2-1.5B"),
|
|
||||||
# dict(
|
|
||||||
# model_path="Qwen/Qwen2.5-14B-Instruct",
|
|
||||||
# mem_fraction_static=0.4,
|
|
||||||
# tp_size=8,
|
|
||||||
# tight_memory=True,
|
|
||||||
# decode_tolerance=1.3,
|
|
||||||
# ), # test_generation_models.py same config (qwen + tp=8) gives 1.22 decode error
|
|
||||||
dict(model_path="HuggingFaceTB/SmolLM-135M-Instruct", tp_size=3),
|
|
||||||
# dict(model_path="allenai/OLMo-1B-0724-hf"),
|
|
||||||
# dict(
|
|
||||||
# model_path="THUDM/glm-4-9b-chat",
|
|
||||||
# mem_fraction_static=0.1,
|
|
||||||
# tp_size=8,
|
|
||||||
# tight_memory=True,
|
|
||||||
# ),
|
|
||||||
# dict(model_path="allenai/OLMo-2-1124-7B-Instruct"),
|
|
||||||
# dict(
|
|
||||||
# model_path="ibm-granite/granite-3.0-2b-instruct",
|
|
||||||
# prefill_tolerance=0.22,
|
|
||||||
# decode_tolerance=0.22,
|
|
||||||
# ),
|
|
||||||
]
|
|
||||||
|
|
||||||
# This port is used for HTTP API communication with the VerlEngine server
|
|
||||||
# It handles client requests for text generation, weight updates, and memory management
|
|
||||||
# This port must be available and not used by other processes
|
|
||||||
PORT = find_available_port(2345)
|
|
||||||
|
|
||||||
# Master port is used for PyTorch's distributed communication setup
|
|
||||||
# It enables tensor-parallel processes to communicate with each other
|
|
||||||
# Default is 23456, but we find an available port dynamically in assert_fragment_e2e_execution
|
|
||||||
# This port is critical for torch.distributed.init_process_group to function properly
|
|
||||||
# Each test needs a unique master_port to avoid conflicts between parallel test executions
|
|
||||||
# master_port = find_available_port(23456) # This is set in assert_fragment_e2e_execution method
|
|
||||||
|
|
||||||
|
|
||||||
class TestVerlEngine(CustomTestCase):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
multiprocessing.set_start_method("spawn")
|
|
||||||
|
|
||||||
def assert_fragment_e2e_execution(
|
|
||||||
self,
|
|
||||||
index: int,
|
|
||||||
model_path: str,
|
|
||||||
mem_fraction_static: float = 0.4,
|
|
||||||
tp_size: int = 2,
|
|
||||||
tight_memory: bool = False,
|
|
||||||
prefill_tolerance: float = 0.1,
|
|
||||||
decode_tolerance: float = 0.1,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Tests VerlEngine with tensor parallelism across multiple processes.
|
|
||||||
|
|
||||||
Spawns tp_size processes to test distributed execution, including:
|
|
||||||
- Model inference via direct API and HTTP server
|
|
||||||
- Weight updating functionality
|
|
||||||
- Memory management (release/resume)
|
|
||||||
|
|
||||||
The test validates output correctness against a reference implementation
|
|
||||||
within specified tolerance bounds.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
index: int - Test index for logging
|
|
||||||
model_path: str - HuggingFace model identifier
|
|
||||||
mem_fraction_static: float - Memory fraction for static tensors
|
|
||||||
tp_size: int - Number of tensor parallel processes
|
|
||||||
tight_memory: bool - Enable memory optimization
|
|
||||||
prefill_tolerance: float - Max error for prefill computation
|
|
||||||
decode_tolerance: float - Max error for decoding computation
|
|
||||||
"""
|
|
||||||
|
|
||||||
master_port = find_available_port(23456)
|
|
||||||
|
|
||||||
print(f"assert_fragment_e2e_execution START {index=} {model_path=}")
|
|
||||||
|
|
||||||
processes = []
|
|
||||||
output_reader, output_writer = mp.Pipe(duplex=False)
|
|
||||||
for tp_rank in range(tp_size):
|
|
||||||
p = Process(
|
|
||||||
target=_run_subprocess,
|
|
||||||
kwargs=dict(
|
|
||||||
tp_rank=tp_rank,
|
|
||||||
tp_size=tp_size,
|
|
||||||
master_port=master_port,
|
|
||||||
output_writer=output_writer,
|
|
||||||
model_path=model_path,
|
|
||||||
mem_fraction_static=mem_fraction_static,
|
|
||||||
tight_memory=tight_memory,
|
|
||||||
prefill_tolerance=prefill_tolerance,
|
|
||||||
decode_tolerance=decode_tolerance,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
p.start()
|
|
||||||
processes.append(p)
|
|
||||||
|
|
||||||
for _ in range(tp_size):
|
|
||||||
self.assertTrue(
|
|
||||||
output_reader.recv(),
|
|
||||||
f"Subprocess has error, please see logs above. ({index=} {model_path=})",
|
|
||||||
)
|
|
||||||
|
|
||||||
for p in processes:
|
|
||||||
p.join()
|
|
||||||
|
|
||||||
def test_models(self):
|
|
||||||
"""
|
|
||||||
Orchestrates end-to-end testing across configured model sets.
|
|
||||||
|
|
||||||
In CI environments: Randomly selects one model for faster testing.
|
|
||||||
In development: Tests all configured models for comprehensive validation.
|
|
||||||
|
|
||||||
Each model configuration specifies model path, memory settings,
|
|
||||||
tensor-parallel size, and error tolerance bounds.
|
|
||||||
"""
|
|
||||||
test_models = ALL_OTHER_MODELS
|
|
||||||
if is_in_ci():
|
|
||||||
# Randomly select one model in CI for faster testing
|
|
||||||
test_models = [random.choice(ALL_OTHER_MODELS)]
|
|
||||||
# Test all models in development environment
|
|
||||||
print(f"Development environment: Testing all {len(ALL_OTHER_MODELS)} models")
|
|
||||||
for index, model_info in enumerate(test_models):
|
|
||||||
self.assert_fragment_e2e_execution(index=index, **model_info)
|
|
||||||
|
|
||||||
|
|
||||||
def _run_subprocess(
|
|
||||||
tp_rank: int,
|
|
||||||
tp_size: int,
|
|
||||||
master_port: int,
|
|
||||||
output_writer,
|
|
||||||
model_path: str,
|
|
||||||
mem_fraction_static: float,
|
|
||||||
tight_memory: bool,
|
|
||||||
prefill_tolerance: float,
|
|
||||||
decode_tolerance: float,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Executes a single tensor-parallel process for testing VerlEngine.
|
|
||||||
|
|
||||||
Performs the core test operations:
|
|
||||||
1. Initializes distributed environment
|
|
||||||
2. Loads HuggingFace model for reference
|
|
||||||
3. Tests VerlEngine API (generation, memory management, weight updates)
|
|
||||||
4. Tests OpenAI-compatible endpoints on rank 0
|
|
||||||
|
|
||||||
Reports success/failure via output_writer pipe.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
tp_rank: int - Process rank in tensor parallel group
|
|
||||||
tp_size: int - Total processes in tensor parallel group
|
|
||||||
master_port: int - Port for distributed communication
|
|
||||||
output_writer - Pipe for result communication
|
|
||||||
model_path: str - HuggingFace model identifier
|
|
||||||
mem_fraction_static: float - Static memory allocation fraction
|
|
||||||
tight_memory: bool - Memory optimization flag
|
|
||||||
prefill_tolerance: float - Acceptable prefill error
|
|
||||||
decode_tolerance: float - Acceptable decode error
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
print(f"subprocess[{tp_rank=}] Start {os.environ.get('CUDA_VISIBLE_DEVICES')=}")
|
|
||||||
|
|
||||||
os.environ["MASTER_ADDR"] = "localhost"
|
|
||||||
os.environ["MASTER_PORT"] = str(master_port)
|
|
||||||
torch.distributed.init_process_group(rank=tp_rank, world_size=tp_size)
|
|
||||||
torch.cuda.set_device(tp_rank)
|
|
||||||
|
|
||||||
mesh_kwargs = dict(mesh_shape=(tp_size, 1), mesh_dim_names=["tp", "pp"])
|
|
||||||
inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
|
|
||||||
inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
|
|
||||||
# Print basic information about this subprocess including:
|
|
||||||
# - Current tensor-parallel rank
|
|
||||||
# - Device mesh configuration for both CUDA and CPU
|
|
||||||
# - This subprocess's role in testing tensor-parallel execution
|
|
||||||
# - How it contributes to the distributed model testing
|
|
||||||
print(
|
|
||||||
f"subprocess[{tp_rank=}] initialized for VerlEngine testing - "
|
|
||||||
f"Role: Shard {tp_rank+1}/{tp_size} of tensor-parallel model execution | "
|
|
||||||
f"Device meshes: CUDA={inference_device_mesh_device}, CPU={inference_device_mesh_cpu}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# hf model is used for comparison
|
|
||||||
hf_model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_path, torch_dtype=_TORCH_DTYPE, trust_remote_code=True
|
|
||||||
).cuda()
|
|
||||||
hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True)
|
|
||||||
|
|
||||||
hf_outputs = HFRunner.forward_generation_raw(
|
|
||||||
base_model=hf_model,
|
|
||||||
prompts=_PROMPTS,
|
|
||||||
max_new_tokens=_MAX_NEW_TOKENS,
|
|
||||||
tokenizer=hf_tokenizer,
|
|
||||||
lora_paths=None,
|
|
||||||
torch_dtype=_TORCH_DTYPE,
|
|
||||||
output_str_only=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if _ENABLE_UPDATE_WEIGHTS:
|
|
||||||
if tight_memory:
|
|
||||||
# If tight_memory is True, we need to move the model to CPU to save memory
|
|
||||||
hf_model.cpu()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# test update weights
|
|
||||||
print(f"subprocess[{tp_rank=}] get_fsdp_state_dict", flush=True)
|
|
||||||
fsdp_state_dict = _get_fsdp_state_dict(hf_model=hf_model, tp_size=tp_size)
|
|
||||||
|
|
||||||
engine = VerlEngine(
|
|
||||||
model_path=model_path,
|
|
||||||
load_format="dummy" if _ENABLE_UPDATE_WEIGHTS else "auto",
|
|
||||||
mem_fraction_static=mem_fraction_static,
|
|
||||||
random_seed=42,
|
|
||||||
trust_remote_code=True,
|
|
||||||
dtype=get_dtype_str(_TORCH_DTYPE),
|
|
||||||
device_mesh_cpu=inference_device_mesh_cpu["tp"],
|
|
||||||
backend="server",
|
|
||||||
enable_memory_saver=True,
|
|
||||||
port=PORT,
|
|
||||||
)
|
|
||||||
# test direct generate API with multiple different requests
|
|
||||||
print(
|
|
||||||
f"subprocess[{tp_rank=}] testing direct generate API with multiple requests"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Request 1: Basic generation with temperature
|
|
||||||
print(f"subprocess[{tp_rank=}] test request 1: Basic generation")
|
|
||||||
direct_response = engine.generate(
|
|
||||||
prompt="Hello, world!",
|
|
||||||
sampling_params={"temperature": 0.7, "max_new_tokens": 20},
|
|
||||||
)
|
|
||||||
print(f"Response 1: {direct_response}")
|
|
||||||
|
|
||||||
# Request 2: Zero temperature (greedy) generation
|
|
||||||
print(f"subprocess[{tp_rank=}] test request 2: Greedy generation")
|
|
||||||
direct_response = engine.generate(
|
|
||||||
prompt="Complete this sequence: 1, 2, 3,",
|
|
||||||
sampling_params={"temperature": 0.0, "max_new_tokens": 10},
|
|
||||||
)
|
|
||||||
print(f"Response 2: {direct_response}")
|
|
||||||
|
|
||||||
# Request 3: Batch generation
|
|
||||||
print(f"subprocess[{tp_rank=}] test request 3: Batch generation")
|
|
||||||
batch_response = engine.generate(
|
|
||||||
prompt=["Translate 'hello' to French:", "Translate 'goodbye' to Spanish:"],
|
|
||||||
sampling_params={"temperature": 0.8, "max_new_tokens": 15},
|
|
||||||
)
|
|
||||||
print(f"Response 3: {batch_response}")
|
|
||||||
|
|
||||||
# test memory occupation APIs
|
|
||||||
print(f"subprocess[{tp_rank=}] testing memory occupation APIs")
|
|
||||||
engine.release_memory_occupation()
|
|
||||||
print("Memory released")
|
|
||||||
# time.sleep(1)
|
|
||||||
engine.resume_memory_occupation()
|
|
||||||
print("Memory resumed")
|
|
||||||
|
|
||||||
# openai API test for reference
|
|
||||||
torch.distributed.barrier()
|
|
||||||
if tp_rank == 0:
|
|
||||||
client = OpenAI(api_key="None", base_url=f"http://localhost:{PORT}/v1")
|
|
||||||
print(client.models.list().data[0].id)
|
|
||||||
|
|
||||||
# Multiple HTTP API requests
|
|
||||||
print("Testing HTTP API with multiple requests")
|
|
||||||
|
|
||||||
# Request 1
|
|
||||||
url = f"http://localhost:{PORT}/generate"
|
|
||||||
data = {"text": "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="}
|
|
||||||
response = requests.post(url, json=data)
|
|
||||||
print(f"HTTP Response 1: {response.json()}")
|
|
||||||
|
|
||||||
# Request 2
|
|
||||||
data = {
|
|
||||||
"text": "The capital of France is",
|
|
||||||
"sampling_params": {"temperature": 0.2},
|
|
||||||
}
|
|
||||||
response = requests.post(url, json=data)
|
|
||||||
print(f"HTTP Response 2: {response.json()}")
|
|
||||||
|
|
||||||
# Request 3
|
|
||||||
data = {
|
|
||||||
"text": "List three colors:",
|
|
||||||
"sampling_params": {"top_p": 0.95, "max_new_tokens": 25},
|
|
||||||
}
|
|
||||||
response = requests.post(url, json=data)
|
|
||||||
print(f"HTTP Response 3: {response.json()}")
|
|
||||||
|
|
||||||
if _ENABLE_UPDATE_WEIGHTS:
|
|
||||||
print(f"subprocess[{tp_rank=}] call update_weights_from_tensor", flush=True)
|
|
||||||
|
|
||||||
engine.update_weights_from_tensor(
|
|
||||||
[(k, v) for k, v in fsdp_state_dict.items()]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Final generation test after weight update
|
|
||||||
print(f"subprocess[{tp_rank=}] testing generation after weight update")
|
|
||||||
direct_response = engine.generate(
|
|
||||||
prompt="After weight update: Hello, world!",
|
|
||||||
sampling_params={"temperature": 0.7, "max_new_tokens": 20},
|
|
||||||
)
|
|
||||||
print(f"Post-update response: {direct_response}")
|
|
||||||
|
|
||||||
execution_ok = True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"subprocess[{tp_rank=}] has error: {e}", flush=True)
|
|
||||||
traceback.print_exc()
|
|
||||||
execution_ok = False
|
|
||||||
|
|
||||||
output_writer.send(execution_ok)
|
|
||||||
output_writer.close()
|
|
||||||
|
|
||||||
engine.shutdown()
|
|
||||||
print(f"subprocess[{tp_rank=}] end", flush=True)
|
|
||||||
|
|
||||||
|
|
||||||
# Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py
|
|
||||||
def _get_fsdp_state_dict(hf_model, tp_size: int):
|
|
||||||
"""
|
|
||||||
Creates a sharded state dictionary for weight update testing.
|
|
||||||
|
|
||||||
Wraps the HuggingFace model with FSDP (FullyShardedDataParallel),
|
|
||||||
configures precision settings, and returns a sharded state dict
|
|
||||||
for testing VerlEngine's weight update capabilities.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
hf_model - HuggingFace model to wrap
|
|
||||||
tp_size: int - Number of tensor-parallel shards
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict - Sharded state dict for update_weights_from_tensor
|
|
||||||
"""
|
|
||||||
device_mesh = init_device_mesh(
|
|
||||||
"cuda", mesh_shape=(tp_size,), mesh_dim_names=["fsdp"]
|
|
||||||
)
|
|
||||||
|
|
||||||
mixed_precision = MixedPrecision(
|
|
||||||
param_dtype=torch.bfloat16,
|
|
||||||
reduce_dtype=torch.float32,
|
|
||||||
buffer_dtype=torch.float32,
|
|
||||||
)
|
|
||||||
fsdp_model = FSDP(
|
|
||||||
hf_model,
|
|
||||||
use_orig_params=True,
|
|
||||||
auto_wrap_policy=None,
|
|
||||||
device_id=torch.cuda.current_device(),
|
|
||||||
sharding_strategy=ShardingStrategy.FULL_SHARD,
|
|
||||||
mixed_precision=mixed_precision,
|
|
||||||
cpu_offload=CPUOffload(offload_params=False),
|
|
||||||
sync_module_states=False,
|
|
||||||
device_mesh=device_mesh,
|
|
||||||
)
|
|
||||||
print(f"{fsdp_model=}")
|
|
||||||
|
|
||||||
FSDP.set_state_dict_type(
|
|
||||||
fsdp_model,
|
|
||||||
state_dict_type=StateDictType.SHARDED_STATE_DICT,
|
|
||||||
state_dict_config=ShardedStateDictConfig(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return fsdp_model.state_dict()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
Reference in New Issue
Block a user