[CI] update verlengine ci to 4-gpu test (#6007)
This commit is contained in:
2
.github/workflows/pr-test.yml
vendored
2
.github/workflows/pr-test.yml
vendored
@@ -103,7 +103,7 @@ jobs:
|
|||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
|
|
||||||
- name: Run test
|
- name: Run test
|
||||||
timeout-minutes: 20
|
timeout-minutes: 30
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite per-commit-4-gpu
|
python3 run_suite.py --suite per-commit-4-gpu
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ suites = {
|
|||||||
TestFile("test_moe_ep.py", 181),
|
TestFile("test_moe_ep.py", 181),
|
||||||
TestFile("test_patch_torch.py", 19),
|
TestFile("test_patch_torch.py", 19),
|
||||||
TestFile("test_update_weights_from_distributed.py", 103),
|
TestFile("test_update_weights_from_distributed.py", 103),
|
||||||
TestFile("test_verl_engine.py", 64),
|
TestFile("test_verl_engine_2_gpu.py", 64),
|
||||||
],
|
],
|
||||||
"per-commit-2-gpu-amd": [
|
"per-commit-2-gpu-amd": [
|
||||||
TestFile("test_mla_tp.py", 170),
|
TestFile("test_mla_tp.py", 170),
|
||||||
@@ -109,6 +109,7 @@ suites = {
|
|||||||
"per-commit-4-gpu": [
|
"per-commit-4-gpu": [
|
||||||
TestFile("test_local_attn.py", 250),
|
TestFile("test_local_attn.py", 250),
|
||||||
TestFile("test_pp_single_node.py", 150),
|
TestFile("test_pp_single_node.py", 150),
|
||||||
|
TestFile("test_verl_engine_4_gpu.py", 64),
|
||||||
],
|
],
|
||||||
"per-commit-8-gpu": [
|
"per-commit-8-gpu": [
|
||||||
# Disabled deepep tests temporarily because it takes too much time.
|
# Disabled deepep tests temporarily because it takes too much time.
|
||||||
|
|||||||
276
test/srt/test_verl_engine_2_gpu.py
Normal file
276
test/srt/test_verl_engine_2_gpu.py
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
import multiprocessing
|
||||||
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import traceback
|
||||||
|
import unittest
|
||||||
|
from multiprocessing import Process
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.distributed.device_mesh import init_device_mesh
|
||||||
|
from torch.distributed.fsdp import CPUOffload
|
||||||
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||||
|
from torch.distributed.fsdp import MixedPrecision
|
||||||
|
from torch.distributed.fsdp.api import (
|
||||||
|
ShardedStateDictConfig,
|
||||||
|
ShardingStrategy,
|
||||||
|
StateDictType,
|
||||||
|
)
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.verl_engine import VerlEngine
|
||||||
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
|
from sglang.srt.utils import is_port_available
|
||||||
|
from sglang.test.runners import (
|
||||||
|
HFRunner,
|
||||||
|
SRTRunner,
|
||||||
|
check_close_model_outputs,
|
||||||
|
get_dtype_str,
|
||||||
|
)
|
||||||
|
from sglang.test.test_utils import CustomTestCase, find_available_port, is_in_ci
|
||||||
|
|
||||||
|
_MAX_NEW_TOKENS = 8
|
||||||
|
_PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="]
|
||||||
|
_TORCH_DTYPE = torch.float16
|
||||||
|
|
||||||
|
# Set to false to temporarily debug issues unrelated to weight update
|
||||||
|
_ENABLE_UPDATE_WEIGHTS = True
|
||||||
|
# _ENABLE_UPDATE_WEIGHTS = False
|
||||||
|
|
||||||
|
# TODO maybe we should add more other models? should we keep it in sync with test_generation_models.py?
|
||||||
|
ALL_MODELS = [
|
||||||
|
dict(model_path="meta-llama/Llama-3.2-1B-Instruct"),
|
||||||
|
dict(model_path="Qwen/Qwen2-1.5B"),
|
||||||
|
dict(model_path="allenai/OLMo-1B-0724-hf"),
|
||||||
|
dict(model_path="allenai/OLMo-2-1124-7B-Instruct"),
|
||||||
|
dict(
|
||||||
|
model_path="ibm-granite/granite-3.0-2b-instruct",
|
||||||
|
prefill_tolerance=0.22,
|
||||||
|
decode_tolerance=0.22,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestVerlEngine(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
multiprocessing.set_start_method("spawn")
|
||||||
|
|
||||||
|
def assert_fragment_e2e_execution(
|
||||||
|
self,
|
||||||
|
index: int,
|
||||||
|
model_path: str,
|
||||||
|
mem_fraction_static: float = 0.4,
|
||||||
|
dp_size: int = 1,
|
||||||
|
tp_size: int = 2,
|
||||||
|
tight_memory: bool = False,
|
||||||
|
prefill_tolerance: float = 0.1,
|
||||||
|
decode_tolerance: float = 0.1,
|
||||||
|
):
|
||||||
|
master_port = find_available_port(23456)
|
||||||
|
|
||||||
|
print(f"assert_fragment_e2e_execution START {index=} {model_path=}")
|
||||||
|
|
||||||
|
processes = []
|
||||||
|
output_reader, output_writer = mp.Pipe(duplex=False)
|
||||||
|
world_size = dp_size * tp_size
|
||||||
|
for rank in range(world_size):
|
||||||
|
p = Process(
|
||||||
|
target=_run_subprocess,
|
||||||
|
kwargs=dict(
|
||||||
|
rank=rank,
|
||||||
|
dp_size=dp_size,
|
||||||
|
tp_size=tp_size,
|
||||||
|
master_port=master_port,
|
||||||
|
output_writer=output_writer,
|
||||||
|
model_path=model_path,
|
||||||
|
mem_fraction_static=mem_fraction_static,
|
||||||
|
tight_memory=tight_memory,
|
||||||
|
prefill_tolerance=prefill_tolerance,
|
||||||
|
decode_tolerance=decode_tolerance,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p.start()
|
||||||
|
processes.append(p)
|
||||||
|
|
||||||
|
for _ in range(tp_size):
|
||||||
|
self.assertTrue(
|
||||||
|
output_reader.recv(),
|
||||||
|
f"Subprocess has error, please see logs above. ({index=} {model_path=})",
|
||||||
|
)
|
||||||
|
|
||||||
|
for p in processes:
|
||||||
|
p.join()
|
||||||
|
|
||||||
|
def test_ci_models(self):
|
||||||
|
ci_models = [random.choice(ALL_MODELS)]
|
||||||
|
for index, model_info in enumerate(ci_models):
|
||||||
|
self.assert_fragment_e2e_execution(index=index, **model_info)
|
||||||
|
|
||||||
|
def test_others(self):
|
||||||
|
if is_in_ci():
|
||||||
|
return
|
||||||
|
|
||||||
|
for index, model_info in enumerate(ALL_MODELS):
|
||||||
|
self.assert_fragment_e2e_execution(index=index, **model_info)
|
||||||
|
|
||||||
|
# def test_adhoc(self):
|
||||||
|
# self.assert_fragment_e2e_execution(index=0, model_path="meta-llama/Llama-3.2-1B-Instruct")
|
||||||
|
|
||||||
|
|
||||||
|
def _run_subprocess(
|
||||||
|
rank: int,
|
||||||
|
dp_size: int,
|
||||||
|
tp_size: int,
|
||||||
|
master_port: int,
|
||||||
|
output_writer,
|
||||||
|
model_path: str,
|
||||||
|
mem_fraction_static: float,
|
||||||
|
tight_memory: bool,
|
||||||
|
prefill_tolerance: float,
|
||||||
|
decode_tolerance: float,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
print(f"subprocess[{rank=}] Start {os.environ.get('CUDA_VISIBLE_DEVICES')=}")
|
||||||
|
|
||||||
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
|
os.environ["MASTER_PORT"] = str(master_port)
|
||||||
|
torch.distributed.init_process_group(rank=rank, world_size=dp_size * tp_size)
|
||||||
|
torch.cuda.set_device(rank)
|
||||||
|
|
||||||
|
base_gpu_id = rank // tp_size * tp_size
|
||||||
|
|
||||||
|
mesh_kwargs = dict(
|
||||||
|
mesh_shape=(dp_size, tp_size, 1), mesh_dim_names=["dp", "tp", "pp"]
|
||||||
|
)
|
||||||
|
inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
|
||||||
|
inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
|
||||||
|
print(
|
||||||
|
f"subprocess[{rank=},{base_gpu_id=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# hf model is used for comparison
|
||||||
|
hf_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_path, torch_dtype=_TORCH_DTYPE, trust_remote_code=True
|
||||||
|
).cuda()
|
||||||
|
hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True)
|
||||||
|
|
||||||
|
hf_outputs = HFRunner.forward_generation_raw(
|
||||||
|
base_model=hf_model,
|
||||||
|
prompts=_PROMPTS,
|
||||||
|
max_new_tokens=_MAX_NEW_TOKENS,
|
||||||
|
tokenizer=hf_tokenizer,
|
||||||
|
lora_paths=None,
|
||||||
|
torch_dtype=_TORCH_DTYPE,
|
||||||
|
output_str_only=False,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"subprocess[{rank=}] call hf.forward {hf_outputs=}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if _ENABLE_UPDATE_WEIGHTS:
|
||||||
|
if tight_memory:
|
||||||
|
hf_model.cpu()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# test update weights
|
||||||
|
print(f"subprocess[{rank=}] get_fsdp_state_dict", flush=True)
|
||||||
|
fsdp_state_dict = _get_fsdp_state_dict(
|
||||||
|
hf_model=hf_model, world_size=dp_size * tp_size
|
||||||
|
)
|
||||||
|
|
||||||
|
engine = VerlEngine(
|
||||||
|
model_path=model_path,
|
||||||
|
load_format="dummy" if _ENABLE_UPDATE_WEIGHTS else "auto",
|
||||||
|
mem_fraction_static=mem_fraction_static,
|
||||||
|
random_seed=42,
|
||||||
|
base_gpu_id=base_gpu_id,
|
||||||
|
trust_remote_code=True,
|
||||||
|
dtype=get_dtype_str(_TORCH_DTYPE),
|
||||||
|
device_mesh_cpu=inference_device_mesh_cpu["tp"],
|
||||||
|
)
|
||||||
|
print(f"subprocess[{rank=}] {engine=}", flush=True)
|
||||||
|
|
||||||
|
if _ENABLE_UPDATE_WEIGHTS:
|
||||||
|
print(f"subprocess[{rank=}] call update_weights_from_tensor", flush=True)
|
||||||
|
engine.update_weights_from_tensor(
|
||||||
|
[(k, v) for k, v in fsdp_state_dict.items()]
|
||||||
|
)
|
||||||
|
|
||||||
|
for enable_batch in [False, True]:
|
||||||
|
if enable_batch:
|
||||||
|
fn = SRTRunner.batch_forward_generation_raw
|
||||||
|
else:
|
||||||
|
fn = SRTRunner.forward_generation_raw
|
||||||
|
|
||||||
|
srt_outputs = fn(
|
||||||
|
prompts=_PROMPTS,
|
||||||
|
max_new_tokens=_MAX_NEW_TOKENS,
|
||||||
|
lora_paths=None,
|
||||||
|
engine=engine,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"subprocess[{rank=}] call srt.forward {enable_batch=} {srt_outputs=}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
check_close_model_outputs(
|
||||||
|
hf_outputs=hf_outputs,
|
||||||
|
srt_outputs=srt_outputs,
|
||||||
|
prefill_tolerance=prefill_tolerance,
|
||||||
|
decode_tolerance=decode_tolerance,
|
||||||
|
rouge_l_tolerance=1,
|
||||||
|
check_logprobs=not enable_batch,
|
||||||
|
debug_text=f"{enable_batch=} {rank=}",
|
||||||
|
)
|
||||||
|
|
||||||
|
execution_ok = True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"subprocess[{rank=}] has error: {e}", flush=True)
|
||||||
|
traceback.print_exc()
|
||||||
|
execution_ok = False
|
||||||
|
|
||||||
|
output_writer.send(execution_ok)
|
||||||
|
output_writer.close()
|
||||||
|
|
||||||
|
engine.shutdown()
|
||||||
|
print(f"subprocess[{rank=}] end", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py
|
||||||
|
def _get_fsdp_state_dict(hf_model, world_size: int):
|
||||||
|
device_mesh = init_device_mesh(
|
||||||
|
"cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]
|
||||||
|
)
|
||||||
|
|
||||||
|
mixed_precision = MixedPrecision(
|
||||||
|
param_dtype=torch.bfloat16,
|
||||||
|
reduce_dtype=torch.float32,
|
||||||
|
buffer_dtype=torch.float32,
|
||||||
|
)
|
||||||
|
fsdp_model = FSDP(
|
||||||
|
hf_model,
|
||||||
|
use_orig_params=True,
|
||||||
|
auto_wrap_policy=None,
|
||||||
|
device_id=torch.cuda.current_device(),
|
||||||
|
sharding_strategy=ShardingStrategy.FULL_SHARD,
|
||||||
|
mixed_precision=mixed_precision,
|
||||||
|
cpu_offload=CPUOffload(offload_params=False),
|
||||||
|
sync_module_states=False,
|
||||||
|
device_mesh=device_mesh,
|
||||||
|
)
|
||||||
|
print(f"{fsdp_model=}")
|
||||||
|
|
||||||
|
FSDP.set_state_dict_type(
|
||||||
|
fsdp_model,
|
||||||
|
state_dict_type=StateDictType.SHARDED_STATE_DICT,
|
||||||
|
state_dict_config=ShardedStateDictConfig(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return fsdp_model.state_dict()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -38,35 +38,27 @@ _ENABLE_UPDATE_WEIGHTS = True
|
|||||||
# _ENABLE_UPDATE_WEIGHTS = False
|
# _ENABLE_UPDATE_WEIGHTS = False
|
||||||
|
|
||||||
# TODO maybe we should add more other models? should we keep it in sync with test_generation_models.py?
|
# TODO maybe we should add more other models? should we keep it in sync with test_generation_models.py?
|
||||||
CI_MODELS = [
|
ALL_MODELS = [
|
||||||
dict(model_path="meta-llama/Llama-3.1-8B-Instruct"),
|
dict(
|
||||||
# Fail to run gemma-2-2b after transformers==4.48.3 -> 4.50.0
|
model_path="Qwen/Qwen2.5-0.5B",
|
||||||
# dict(model_path="google/gemma-2-2b"),
|
dp_size=2,
|
||||||
]
|
tp_size=2, # default to 2
|
||||||
ALL_OTHER_MODELS = [
|
),
|
||||||
dict(model_path="meta-llama/Llama-3.2-1B-Instruct"),
|
|
||||||
dict(model_path="Qwen/Qwen2-1.5B"),
|
|
||||||
dict(
|
dict(
|
||||||
model_path="Qwen/Qwen2.5-14B-Instruct",
|
model_path="Qwen/Qwen2.5-14B-Instruct",
|
||||||
mem_fraction_static=0.4,
|
mem_fraction_static=0.7,
|
||||||
tp_size=8,
|
dp_size=2,
|
||||||
|
tp_size=2,
|
||||||
tight_memory=True,
|
tight_memory=True,
|
||||||
decode_tolerance=1.3,
|
decode_tolerance=1.3,
|
||||||
), # test_generation_models.py same config (qwen + tp=8) gives 1.22 decode error
|
), # test_generation_models.py same config (qwen + tp=8) gives 1.22 decode error
|
||||||
dict(model_path="HuggingFaceTB/SmolLM-135M-Instruct", tp_size=3),
|
|
||||||
dict(model_path="allenai/OLMo-1B-0724-hf"),
|
|
||||||
dict(
|
dict(
|
||||||
model_path="THUDM/glm-4-9b-chat",
|
model_path="THUDM/glm-4-9b-chat",
|
||||||
mem_fraction_static=0.1,
|
mem_fraction_static=0.5,
|
||||||
tp_size=8,
|
dp_size=2,
|
||||||
|
tp_size=2,
|
||||||
tight_memory=True,
|
tight_memory=True,
|
||||||
),
|
),
|
||||||
dict(model_path="allenai/OLMo-2-1124-7B-Instruct"),
|
|
||||||
dict(
|
|
||||||
model_path="ibm-granite/granite-3.0-2b-instruct",
|
|
||||||
prefill_tolerance=0.22,
|
|
||||||
decode_tolerance=0.22,
|
|
||||||
),
|
|
||||||
# Fail to run these models in test_generation_models.py, need to fix that first
|
# Fail to run these models in test_generation_models.py, need to fix that first
|
||||||
# dict(model_path="openai-community/gpt2"),
|
# dict(model_path="openai-community/gpt2"),
|
||||||
# dict(model_path="microsoft/Phi-3-small-8k-instruct"),
|
# dict(model_path="microsoft/Phi-3-small-8k-instruct"),
|
||||||
@@ -83,6 +75,7 @@ class TestVerlEngine(CustomTestCase):
|
|||||||
index: int,
|
index: int,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
mem_fraction_static: float = 0.4,
|
mem_fraction_static: float = 0.4,
|
||||||
|
dp_size: int = 1,
|
||||||
tp_size: int = 2,
|
tp_size: int = 2,
|
||||||
tight_memory: bool = False,
|
tight_memory: bool = False,
|
||||||
prefill_tolerance: float = 0.1,
|
prefill_tolerance: float = 0.1,
|
||||||
@@ -94,11 +87,13 @@ class TestVerlEngine(CustomTestCase):
|
|||||||
|
|
||||||
processes = []
|
processes = []
|
||||||
output_reader, output_writer = mp.Pipe(duplex=False)
|
output_reader, output_writer = mp.Pipe(duplex=False)
|
||||||
for tp_rank in range(tp_size):
|
world_size = dp_size * tp_size
|
||||||
|
for rank in range(world_size):
|
||||||
p = Process(
|
p = Process(
|
||||||
target=_run_subprocess,
|
target=_run_subprocess,
|
||||||
kwargs=dict(
|
kwargs=dict(
|
||||||
tp_rank=tp_rank,
|
rank=rank,
|
||||||
|
dp_size=dp_size,
|
||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
master_port=master_port,
|
master_port=master_port,
|
||||||
output_writer=output_writer,
|
output_writer=output_writer,
|
||||||
@@ -122,7 +117,8 @@ class TestVerlEngine(CustomTestCase):
|
|||||||
p.join()
|
p.join()
|
||||||
|
|
||||||
def test_ci_models(self):
|
def test_ci_models(self):
|
||||||
for index, model_info in enumerate(CI_MODELS):
|
ci_models = [random.choice(ALL_MODELS)]
|
||||||
|
for index, model_info in enumerate(ci_models):
|
||||||
self.assert_fragment_e2e_execution(index=index, **model_info)
|
self.assert_fragment_e2e_execution(index=index, **model_info)
|
||||||
|
|
||||||
def test_others(self):
|
def test_others(self):
|
||||||
@@ -137,7 +133,8 @@ class TestVerlEngine(CustomTestCase):
|
|||||||
|
|
||||||
|
|
||||||
def _run_subprocess(
|
def _run_subprocess(
|
||||||
tp_rank: int,
|
rank: int,
|
||||||
|
dp_size: int,
|
||||||
tp_size: int,
|
tp_size: int,
|
||||||
master_port: int,
|
master_port: int,
|
||||||
output_writer,
|
output_writer,
|
||||||
@@ -148,18 +145,22 @@ def _run_subprocess(
|
|||||||
decode_tolerance: float,
|
decode_tolerance: float,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
print(f"subprocess[{tp_rank=}] Start {os.environ.get('CUDA_VISIBLE_DEVICES')=}")
|
print(f"subprocess[{rank=}] Start {os.environ.get('CUDA_VISIBLE_DEVICES')=}")
|
||||||
|
|
||||||
os.environ["MASTER_ADDR"] = "localhost"
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
os.environ["MASTER_PORT"] = str(master_port)
|
os.environ["MASTER_PORT"] = str(master_port)
|
||||||
torch.distributed.init_process_group(rank=tp_rank, world_size=tp_size)
|
torch.distributed.init_process_group(rank=rank, world_size=dp_size * tp_size)
|
||||||
torch.cuda.set_device(tp_rank)
|
torch.cuda.set_device(rank)
|
||||||
|
|
||||||
mesh_kwargs = dict(mesh_shape=(tp_size, 1), mesh_dim_names=["tp", "pp"])
|
base_gpu_id = rank // tp_size * tp_size
|
||||||
|
|
||||||
|
mesh_kwargs = dict(
|
||||||
|
mesh_shape=(dp_size, tp_size, 1), mesh_dim_names=["dp", "tp", "pp"]
|
||||||
|
)
|
||||||
inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
|
inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
|
||||||
inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
|
inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
|
||||||
print(
|
print(
|
||||||
f"subprocess[{tp_rank=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
|
f"subprocess[{rank=},{base_gpu_id=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# hf model is used for comparison
|
# hf model is used for comparison
|
||||||
@@ -178,7 +179,7 @@ def _run_subprocess(
|
|||||||
output_str_only=False,
|
output_str_only=False,
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f"subprocess[{tp_rank=}] call hf.forward {hf_outputs=}",
|
f"subprocess[{rank=}] call hf.forward {hf_outputs=}",
|
||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -188,22 +189,25 @@ def _run_subprocess(
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
# test update weights
|
# test update weights
|
||||||
print(f"subprocess[{tp_rank=}] get_fsdp_state_dict", flush=True)
|
print(f"subprocess[{rank=}] get_fsdp_state_dict", flush=True)
|
||||||
fsdp_state_dict = _get_fsdp_state_dict(hf_model=hf_model, tp_size=tp_size)
|
fsdp_state_dict = _get_fsdp_state_dict(
|
||||||
|
hf_model=hf_model, world_size=dp_size * tp_size
|
||||||
|
)
|
||||||
|
|
||||||
engine = VerlEngine(
|
engine = VerlEngine(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
load_format="dummy" if _ENABLE_UPDATE_WEIGHTS else "auto",
|
load_format="dummy" if _ENABLE_UPDATE_WEIGHTS else "auto",
|
||||||
mem_fraction_static=mem_fraction_static,
|
mem_fraction_static=mem_fraction_static,
|
||||||
random_seed=42,
|
random_seed=42,
|
||||||
|
base_gpu_id=base_gpu_id,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
dtype=get_dtype_str(_TORCH_DTYPE),
|
dtype=get_dtype_str(_TORCH_DTYPE),
|
||||||
device_mesh_cpu=inference_device_mesh_cpu["tp"],
|
device_mesh_cpu=inference_device_mesh_cpu["tp"],
|
||||||
)
|
)
|
||||||
print(f"subprocess[{tp_rank=}] {engine=}", flush=True)
|
print(f"subprocess[{rank=}] {engine=}", flush=True)
|
||||||
|
|
||||||
if _ENABLE_UPDATE_WEIGHTS:
|
if _ENABLE_UPDATE_WEIGHTS:
|
||||||
print(f"subprocess[{tp_rank=}] call update_weights_from_tensor", flush=True)
|
print(f"subprocess[{rank=}] call update_weights_from_tensor", flush=True)
|
||||||
engine.update_weights_from_tensor(
|
engine.update_weights_from_tensor(
|
||||||
[(k, v) for k, v in fsdp_state_dict.items()]
|
[(k, v) for k, v in fsdp_state_dict.items()]
|
||||||
)
|
)
|
||||||
@@ -221,7 +225,7 @@ def _run_subprocess(
|
|||||||
engine=engine,
|
engine=engine,
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f"subprocess[{tp_rank=}] call srt.forward {enable_batch=} {srt_outputs=}",
|
f"subprocess[{rank=}] call srt.forward {enable_batch=} {srt_outputs=}",
|
||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -232,13 +236,13 @@ def _run_subprocess(
|
|||||||
decode_tolerance=decode_tolerance,
|
decode_tolerance=decode_tolerance,
|
||||||
rouge_l_tolerance=1,
|
rouge_l_tolerance=1,
|
||||||
check_logprobs=not enable_batch,
|
check_logprobs=not enable_batch,
|
||||||
debug_text=f"{enable_batch=} {tp_rank=}",
|
debug_text=f"{enable_batch=} {rank=}",
|
||||||
)
|
)
|
||||||
|
|
||||||
execution_ok = True
|
execution_ok = True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"subprocess[{tp_rank=}] has error: {e}", flush=True)
|
print(f"subprocess[{rank=}] has error: {e}", flush=True)
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
execution_ok = False
|
execution_ok = False
|
||||||
|
|
||||||
@@ -246,13 +250,13 @@ def _run_subprocess(
|
|||||||
output_writer.close()
|
output_writer.close()
|
||||||
|
|
||||||
engine.shutdown()
|
engine.shutdown()
|
||||||
print(f"subprocess[{tp_rank=}] end", flush=True)
|
print(f"subprocess[{rank=}] end", flush=True)
|
||||||
|
|
||||||
|
|
||||||
# Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py
|
# Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py
|
||||||
def _get_fsdp_state_dict(hf_model, tp_size: int):
|
def _get_fsdp_state_dict(hf_model, world_size: int):
|
||||||
device_mesh = init_device_mesh(
|
device_mesh = init_device_mesh(
|
||||||
"cuda", mesh_shape=(tp_size,), mesh_dim_names=["fsdp"]
|
"cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]
|
||||||
)
|
)
|
||||||
|
|
||||||
mixed_precision = MixedPrecision(
|
mixed_precision = MixedPrecision(
|
||||||
Reference in New Issue
Block a user