[Feature] SPMD for SGLang + Verl (#3852)
This commit is contained in:
81
examples/runtime/engine/offline_batch_inference_torchrun.py
Normal file
81
examples/runtime/engine/offline_batch_inference_torchrun.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import datetime
|
||||
import os
|
||||
import sys
|
||||
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
|
||||
from sglang.srt.entrypoints.verl_engine import VerlEngine
|
||||
|
||||
|
||||
def run():
    """Run SPMD offline batch inference with SGLang's VerlEngine under torchrun.

    Each torchrun-spawned process becomes one rank of a (tp=4, dp=2, pp=1)
    device mesh; ranks in the same "tp" group share one engine fragment, and
    each DP group generates from its own prompt list.

    Example command:
    ```
    torchrun --nproc_per_node=8 offline_batch_inference_torchrun.py
    ```
    """
    # Rank/world topology is provided by torchrun via environment variables.
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    def _log(text):
        # Timestamped, rank-tagged print so interleaved multi-process output
        # can be attributed to its rank.
        t = datetime.datetime.now().strftime("%H:%M:%S")
        print(f"[{t}] [rank={rank}] {text}")

    _log(
        f'start {local_rank=} {rank=} {world_size=} {sys.argv=} {os.environ.get("CUDA_VISIBLE_DEVICES")}'
    )

    tp_size = 4
    dp_size = 2
    # Validate with an explicit raise rather than `assert`, which is stripped
    # when Python runs with -O and would let a bad launch proceed silently.
    if world_size != tp_size * dp_size:
        raise ValueError(
            f"world_size ({world_size}) must equal tp_size * dp_size "
            f"({tp_size} * {dp_size} = {tp_size * dp_size})"
        )

    device_mesh_kwargs = dict(
        mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"]
    )
    # CPU mesh: used for rank bookkeeping/grouping, not GPU placement.
    device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)
    _log(f"{device_mesh_cpu=}")

    tp_rank = device_mesh_cpu.get_local_rank("tp")
    dp_rank = device_mesh_cpu.get_local_rank("dp")
    _log(f"{tp_rank=} {tp_size=} ; {dp_rank=} {dp_size=}")

    model_name, mem_fraction_static = "meta-llama/Llama-3.2-1B-Instruct", 0.1
    # model_name, mem_fraction_static = "meta-llama/Llama-3.1-70B-Instruct", 0.9 # test large models
    # model_name, mem_fraction_static = "deepseek-ai/DeepSeek-V2-Lite", 0.8

    # NOTE(review): presumably cleared so the engine's own process-group setup
    # does not clash with torchrun's elastic agent store — confirm.
    for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
        os.environ.pop(k, None)

    # One engine fragment per TP group; base_gpu_id/gpu_id_step interleave the
    # two DP groups across the node's GPUs.
    fragment = VerlEngine(
        model_path=model_name,
        mem_fraction_static=mem_fraction_static,
        device_mesh_cpu=device_mesh_cpu["tp"],
        base_gpu_id=dp_rank,
        gpu_id_step=dp_size,
        port=30000,
        # for DeepSeek-V2-Lite + DP Attention
        # enable_dp_attention=True, port=30000 + dp_rank * 100,
    )
    _log(f"{fragment=}")

    # Each DP group gets its own prompt batch.
    prompt_all = [
        ["1+1=2, 1+2=3, 1+3=4, 1+4=", "9-1=8, 8-1=7, 7-1="],
        ["2*1=2, 2*2=4, 2*3=", "8/2=4, 6/2="],
    ]
    prompt = prompt_all[dp_rank]

    # Greedy decoding (temperature=0.0) keeps the example output deterministic.
    output = fragment.generate(
        prompt=prompt,
        sampling_params=dict(max_new_tokens=16, temperature=0.0),
    )
    _log(f"{prompt=} {output=}")

    fragment.shutdown()
    _log("End script")
|
||||
|
||||
|
||||
# Entry point: torchrun spawns one process per rank; each executes run().
if __name__ == "__main__":
    run()
|
||||
Reference in New Issue
Block a user