[Feature] SPMD for SGLang + Verl (#3852)
This commit is contained in:
6
.github/workflows/pr-test.yml
vendored
6
.github/workflows/pr-test.yml
vendored
@@ -149,6 +149,12 @@ jobs:
|
|||||||
cd test/srt
|
cd test/srt
|
||||||
python3 test_update_weights_from_distributed.py
|
python3 test_update_weights_from_distributed.py
|
||||||
|
|
||||||
|
- name: Test VerlEngine
|
||||||
|
timeout-minutes: 10
|
||||||
|
run: |
|
||||||
|
cd test/srt
|
||||||
|
python3 test_verl_engine.py
|
||||||
|
|
||||||
- name: Test expert parallelism (EP=2)
|
- name: Test expert parallelism (EP=2)
|
||||||
timeout-minutes: 10
|
timeout-minutes: 10
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
81
examples/runtime/engine/offline_batch_inference_torchrun.py
Normal file
81
examples/runtime/engine/offline_batch_inference_torchrun.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from torch.distributed.device_mesh import init_device_mesh
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.verl_engine import VerlEngine
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
"""
|
||||||
|
Example command:
|
||||||
|
```
|
||||||
|
torchrun --nproc_per_node=8 offline_batch_inference_torchrun.py
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
local_rank = int(os.environ["LOCAL_RANK"])
|
||||||
|
rank = int(os.environ["RANK"])
|
||||||
|
world_size = int(os.environ["WORLD_SIZE"])
|
||||||
|
|
||||||
|
def _log(text):
|
||||||
|
t = datetime.datetime.now().strftime("%H:%M:%S")
|
||||||
|
print(f"[{t}] [rank={rank}] {text}")
|
||||||
|
|
||||||
|
_log(
|
||||||
|
f'start {local_rank=} {rank=} {world_size=} {sys.argv=} {os.environ.get("CUDA_VISIBLE_DEVICES")}'
|
||||||
|
)
|
||||||
|
|
||||||
|
tp_size = 4
|
||||||
|
dp_size = 2
|
||||||
|
assert world_size == tp_size * dp_size
|
||||||
|
|
||||||
|
device_mesh_kwargs = dict(
|
||||||
|
mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"]
|
||||||
|
)
|
||||||
|
device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)
|
||||||
|
_log(f"{device_mesh_cpu=}")
|
||||||
|
|
||||||
|
tp_rank = device_mesh_cpu.get_local_rank("tp")
|
||||||
|
dp_rank = device_mesh_cpu.get_local_rank("dp")
|
||||||
|
_log(f"{tp_rank=} {tp_size=} ; {dp_rank=} {dp_size=}")
|
||||||
|
|
||||||
|
model_name, mem_fraction_static = "meta-llama/Llama-3.2-1B-Instruct", 0.1
|
||||||
|
# model_name, mem_fraction_static = "meta-llama/Llama-3.1-70B-Instruct", 0.9 # test large models
|
||||||
|
# model_name, mem_fraction_static = "deepseek-ai/DeepSeek-V2-Lite", 0.8
|
||||||
|
|
||||||
|
for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
|
||||||
|
if k in os.environ:
|
||||||
|
del os.environ[k]
|
||||||
|
|
||||||
|
fragment = VerlEngine(
|
||||||
|
model_path=model_name,
|
||||||
|
mem_fraction_static=mem_fraction_static,
|
||||||
|
device_mesh_cpu=device_mesh_cpu["tp"],
|
||||||
|
base_gpu_id=dp_rank,
|
||||||
|
gpu_id_step=dp_size,
|
||||||
|
port=30000,
|
||||||
|
# for DeepSeek-V2-Lite + DP Attention
|
||||||
|
# enable_dp_attention=True, port=30000 + dp_rank * 100,
|
||||||
|
)
|
||||||
|
_log(f"{fragment=}")
|
||||||
|
|
||||||
|
prompt_all = [
|
||||||
|
["1+1=2, 1+2=3, 1+3=4, 1+4=", "9-1=8, 8-1=7, 7-1="],
|
||||||
|
["2*1=2, 2*2=4, 2*3=", "8/2=4, 6/2="],
|
||||||
|
]
|
||||||
|
prompt = prompt_all[dp_rank]
|
||||||
|
|
||||||
|
output = fragment.generate(
|
||||||
|
prompt=prompt,
|
||||||
|
sampling_params=dict(max_new_tokens=16, temperature=0.0),
|
||||||
|
)
|
||||||
|
_log(f"{prompt=} {output=}")
|
||||||
|
|
||||||
|
fragment.shutdown()
|
||||||
|
_log(f"End script")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -271,10 +271,18 @@ class Engine:
|
|||||||
self.tokenizer_manager.update_weights_from_distributed(obj, None)
|
self.tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
|
def update_weights_from_tensor(
|
||||||
"""Update weights from distributed source."""
|
self,
|
||||||
|
named_tensors: List[Tuple[str, torch.Tensor]],
|
||||||
|
load_format: Optional[str] = None,
|
||||||
|
flush_cache: bool = True,
|
||||||
|
):
|
||||||
|
"""Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
|
||||||
|
to avoid duplicated operations such as clearing cache."""
|
||||||
obj = UpdateWeightsFromTensorReqInput(
|
obj = UpdateWeightsFromTensorReqInput(
|
||||||
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
|
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors),
|
||||||
|
load_format=load_format,
|
||||||
|
flush_cache=flush_cache,
|
||||||
)
|
)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return loop.run_until_complete(
|
return loop.run_until_complete(
|
||||||
@@ -384,7 +392,10 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
|
|||||||
)
|
)
|
||||||
for tp_rank in tp_rank_range:
|
for tp_rank in tp_rank_range:
|
||||||
reader, writer = mp.Pipe(duplex=False)
|
reader, writer = mp.Pipe(duplex=False)
|
||||||
gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
|
gpu_id = (
|
||||||
|
server_args.base_gpu_id
|
||||||
|
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
|
||||||
|
)
|
||||||
proc = mp.Process(
|
proc = mp.Process(
|
||||||
target=run_scheduler_process,
|
target=run_scheduler_process,
|
||||||
args=(server_args, port_args, gpu_id, tp_rank, None, writer),
|
args=(server_args, port_args, gpu_id, tp_rank, None, writer),
|
||||||
|
|||||||
145
python/sglang/srt/entrypoints/verl_engine.py
Normal file
145
python/sglang/srt/entrypoints/verl_engine.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
# Copyright 2023-2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import os
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.distributed.tensor import DeviceMesh, DTensor
|
||||||
|
|
||||||
|
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
|
||||||
|
from sglang.srt.server import Engine
|
||||||
|
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
|
||||||
|
|
||||||
|
|
||||||
|
class VerlEngine:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
device_mesh_cpu: DeviceMesh,
|
||||||
|
nnodes: int = 1,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self._device_mesh_cpu = device_mesh_cpu
|
||||||
|
self._tp_rank = device_mesh_cpu.get_local_rank()
|
||||||
|
self._tp_size = device_mesh_cpu.size()
|
||||||
|
tp_size_per_node = self._tp_size // nnodes
|
||||||
|
node_rank = self._tp_rank // tp_size_per_node
|
||||||
|
first_rank_in_node = self._tp_rank % tp_size_per_node == 0
|
||||||
|
|
||||||
|
if first_rank_in_node:
|
||||||
|
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
|
||||||
|
self._engine = Engine(
|
||||||
|
**kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._engine = None
|
||||||
|
|
||||||
|
dist.barrier(group=self._device_mesh_cpu.get_group())
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||||
|
prompt: Optional[Union[List[str], str]] = None,
|
||||||
|
sampling_params: Optional[Union[List[Dict], Dict]] = None,
|
||||||
|
# The token ids for text; one can either specify text or input_ids.
|
||||||
|
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
||||||
|
# The image input. It can be a file name, a url, or base64 encoded string.
|
||||||
|
# See also python/sglang/srt/utils.py:load_image.
|
||||||
|
image_data: Optional[Union[List[str], str]] = None,
|
||||||
|
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||||
|
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||||
|
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||||
|
lora_path: Optional[List[Optional[str]]] = None,
|
||||||
|
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
|
||||||
|
Please refer to `GenerateReqInput` for the documentation.
|
||||||
|
"""
|
||||||
|
if self._tp_rank == 0:
|
||||||
|
output = self._engine.generate(
|
||||||
|
prompt=prompt,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
input_ids=input_ids,
|
||||||
|
image_data=image_data,
|
||||||
|
return_logprob=return_logprob,
|
||||||
|
logprob_start_len=logprob_start_len,
|
||||||
|
top_logprobs_num=top_logprobs_num,
|
||||||
|
lora_path=lora_path,
|
||||||
|
custom_logit_processor=custom_logit_processor,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output = None
|
||||||
|
|
||||||
|
# Most naive implementation, can extract tensor and send via gloo if too slow
|
||||||
|
[output] = broadcast_pyobj(
|
||||||
|
data=[output],
|
||||||
|
rank=self._tp_rank,
|
||||||
|
dist_group=self._device_mesh_cpu.get_group(),
|
||||||
|
src=self._device_mesh_cpu.mesh[0].item(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def update_weights_from_tensor(
|
||||||
|
self,
|
||||||
|
named_tensors: List[Tuple[str, torch.Tensor]],
|
||||||
|
load_format: Optional[str] = None,
|
||||||
|
):
|
||||||
|
# Most naive implementation, can optimize a lot if it is bottleneck
|
||||||
|
for tensor_index, (name, tensor) in enumerate(named_tensors):
|
||||||
|
serialized_tensor = MultiprocessingSerializer.serialize(
|
||||||
|
_preprocess_tensor_for_update_weights(tensor)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._tp_rank == 0:
|
||||||
|
gathered_serialized_tensors = [None for _ in range(self._tp_size)]
|
||||||
|
else:
|
||||||
|
gathered_serialized_tensors = None
|
||||||
|
dist.gather_object(
|
||||||
|
obj=serialized_tensor,
|
||||||
|
object_gather_list=gathered_serialized_tensors,
|
||||||
|
dst=self._device_mesh_cpu.mesh.tolist()[0],
|
||||||
|
group=self._device_mesh_cpu.get_group(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._tp_rank == 0:
|
||||||
|
self._engine.update_weights_from_tensor(
|
||||||
|
named_tensors=[
|
||||||
|
(
|
||||||
|
name,
|
||||||
|
LocalSerializedTensor(values=gathered_serialized_tensors),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
load_format=load_format,
|
||||||
|
flush_cache=tensor_index == len(named_tensors) - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def release_memory_occupation(self):
|
||||||
|
if self._tp_rank == 0:
|
||||||
|
self._engine.release_memory_occupation()
|
||||||
|
|
||||||
|
def resume_memory_occupation(self):
|
||||||
|
if self._tp_rank == 0:
|
||||||
|
self._engine.resume_memory_occupation()
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
if self._engine is not None:
|
||||||
|
self._engine.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
def _preprocess_tensor_for_update_weights(tensor: torch.Tensor):
|
||||||
|
if isinstance(tensor, DTensor):
|
||||||
|
return tensor.full_tensor()
|
||||||
|
return tensor
|
||||||
@@ -121,7 +121,7 @@ class DataParallelController:
|
|||||||
args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
|
args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
|
||||||
)
|
)
|
||||||
threads.append(thread)
|
threads.append(thread)
|
||||||
base_gpu_id += server_args.tp_size
|
base_gpu_id += server_args.tp_size * server_args.gpu_id_step
|
||||||
|
|
||||||
# Free all sockets before starting the threads to launch TP workers
|
# Free all sockets before starting the threads to launch TP workers
|
||||||
for sock in sockets:
|
for sock in sockets:
|
||||||
@@ -177,7 +177,11 @@ class DataParallelController:
|
|||||||
rank_port_args.nccl_port = port_args.nccl_port
|
rank_port_args.nccl_port = port_args.nccl_port
|
||||||
|
|
||||||
reader, writer = mp.Pipe(duplex=False)
|
reader, writer = mp.Pipe(duplex=False)
|
||||||
gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
|
gpu_id = (
|
||||||
|
server_args.base_gpu_id
|
||||||
|
+ base_gpu_id
|
||||||
|
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
|
||||||
|
)
|
||||||
proc = mp.Process(
|
proc = mp.Process(
|
||||||
target=run_scheduler_process,
|
target=run_scheduler_process,
|
||||||
args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
|
args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
|
||||||
|
|||||||
@@ -449,6 +449,8 @@ class UpdateWeightsFromDistributedReqOutput:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class UpdateWeightsFromTensorReqInput:
|
class UpdateWeightsFromTensorReqInput:
|
||||||
serialized_named_tensors: bytes # indeed Dict[str, torch.Tensor]
|
serialized_named_tensors: bytes # indeed Dict[str, torch.Tensor]
|
||||||
|
load_format: Optional[str]
|
||||||
|
flush_cache: bool
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -1760,8 +1760,9 @@ class Scheduler:
|
|||||||
success, message = self.tp_worker.update_weights_from_tensor(recv_req)
|
success, message = self.tp_worker.update_weights_from_tensor(recv_req)
|
||||||
# TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
|
# TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
|
||||||
if success:
|
if success:
|
||||||
flash_cache_success = self.flush_cache()
|
if recv_req.flush_cache:
|
||||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
flash_cache_success = self.flush_cache()
|
||||||
|
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||||
else:
|
else:
|
||||||
logger.error(message)
|
logger.error(message)
|
||||||
return UpdateWeightsFromTensorReqOutput(success, message)
|
return UpdateWeightsFromTensorReqOutput(success, message)
|
||||||
|
|||||||
@@ -205,7 +205,10 @@ class TpModelWorker:
|
|||||||
|
|
||||||
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
||||||
success, message = self.model_runner.update_weights_from_tensor(
|
success, message = self.model_runner.update_weights_from_tensor(
|
||||||
MultiprocessingSerializer.deserialize(recv_req.serialized_named_tensors)
|
named_tensors=MultiprocessingSerializer.deserialize(
|
||||||
|
recv_req.serialized_named_tensors
|
||||||
|
),
|
||||||
|
load_format=recv_req.load_format,
|
||||||
)
|
)
|
||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ import gc
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional, Tuple
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -56,10 +57,12 @@ from sglang.srt.mem_cache.memory_pool import (
|
|||||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader import get_model
|
from sglang.srt.model_loader import get_model
|
||||||
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
|
MultiprocessingSerializer,
|
||||||
enable_show_time_cost,
|
enable_show_time_cost,
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
init_custom_process_group,
|
init_custom_process_group,
|
||||||
@@ -514,8 +517,21 @@ class ModelRunner:
|
|||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
return False, error_msg
|
return False, error_msg
|
||||||
|
|
||||||
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
|
def update_weights_from_tensor(
|
||||||
self.model.load_weights(named_tensors)
|
self,
|
||||||
|
named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
|
||||||
|
load_format: Optional[str] = None,
|
||||||
|
):
|
||||||
|
named_tensors = [
|
||||||
|
(name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
|
||||||
|
for name, tensor in named_tensors
|
||||||
|
]
|
||||||
|
if load_format == "direct":
|
||||||
|
_model_load_weights_direct(self.model, named_tensors)
|
||||||
|
elif load_format is None:
|
||||||
|
self.model.load_weights(named_tensors)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unknown load_format={load_format}")
|
||||||
return True, "Success"
|
return True, "Success"
|
||||||
|
|
||||||
def get_weights_by_name(
|
def get_weights_by_name(
|
||||||
@@ -836,3 +852,26 @@ class ModelRunner:
|
|||||||
if rope_scaling is None:
|
if rope_scaling is None:
|
||||||
return False
|
return False
|
||||||
return rope_scaling.get("type", None) == "mrope"
|
return rope_scaling.get("type", None) == "mrope"
|
||||||
|
|
||||||
|
|
||||||
|
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
|
||||||
|
params_dict = dict(model.named_parameters())
|
||||||
|
for name, tensor in named_tensors:
|
||||||
|
default_weight_loader(params_dict[name], tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def _unwrap_tensor(tensor, tp_rank):
|
||||||
|
if isinstance(tensor, LocalSerializedTensor):
|
||||||
|
return tensor.get(tp_rank)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LocalSerializedTensor:
|
||||||
|
"""torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
|
||||||
|
The i-th element in the list corresponds to i-th rank's GPU."""
|
||||||
|
|
||||||
|
values: List[bytes]
|
||||||
|
|
||||||
|
def get(self, rank: int):
|
||||||
|
return MultiprocessingSerializer.deserialize(self.values[rank])
|
||||||
|
|||||||
@@ -336,12 +336,6 @@ class GemmaForCausalLM(nn.Module):
|
|||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
loaded_params.add(name)
|
loaded_params.add(name)
|
||||||
unloaded_params = params_dict.keys() - loaded_params
|
|
||||||
if unloaded_params:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Some weights are not initialized from checkpoints: "
|
|
||||||
f"{unloaded_params}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
EntryClass = GemmaForCausalLM
|
EntryClass = GemmaForCausalLM
|
||||||
|
|||||||
@@ -437,12 +437,5 @@ class Gemma2ForCausalLM(nn.Module):
|
|||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
loaded_params.add(name)
|
loaded_params.add(name)
|
||||||
|
|
||||||
unloaded_params = params_dict.keys() - loaded_params
|
|
||||||
if unloaded_params:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Some weights are not initialized from checkpoints: "
|
|
||||||
f"{unloaded_params}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
EntryClass = Gemma2ForCausalLM
|
EntryClass = Gemma2ForCausalLM
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ class ServerArgs:
|
|||||||
dist_timeout: Optional[int] = None # timeout for torch.distributed
|
dist_timeout: Optional[int] = None # timeout for torch.distributed
|
||||||
download_dir: Optional[str] = None
|
download_dir: Optional[str] = None
|
||||||
base_gpu_id: int = 0
|
base_gpu_id: int = 0
|
||||||
|
gpu_id_step: int = 1
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
log_level: str = "info"
|
log_level: str = "info"
|
||||||
@@ -552,6 +553,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.base_gpu_id,
|
default=ServerArgs.base_gpu_id,
|
||||||
help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
|
help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gpu-id-step",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.gpu_id_step,
|
||||||
|
help="The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,...",
|
||||||
|
)
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -957,6 +964,7 @@ class ServerArgs:
|
|||||||
and (self.lora_paths is None or self.disable_radix_cache)
|
and (self.lora_paths is None or self.disable_radix_cache)
|
||||||
), "compatibility of lora and cuda graph and radix attention is in progress"
|
), "compatibility of lora and cuda graph and radix attention is in progress"
|
||||||
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
|
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
|
||||||
|
assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
|
||||||
|
|
||||||
if isinstance(self.lora_paths, list):
|
if isinstance(self.lora_paths, list):
|
||||||
lora_paths = self.lora_paths
|
lora_paths = self.lora_paths
|
||||||
|
|||||||
@@ -1386,7 +1386,6 @@ def get_ip() -> str:
|
|||||||
|
|
||||||
|
|
||||||
def get_open_port() -> int:
|
def get_open_port() -> int:
|
||||||
|
|
||||||
port = os.getenv("SGLANG_PORT")
|
port = os.getenv("SGLANG_PORT")
|
||||||
if port is not None:
|
if port is not None:
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
@@ -21,9 +21,9 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from sglang.srt.entrypoints.engine import Engine
|
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
|
from sglang.srt.server import Engine
|
||||||
|
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
|
||||||
|
|
||||||
DEFAULT_PROMPTS = [
|
DEFAULT_PROMPTS = [
|
||||||
"Apple is red. Banana is Yellow. " * 800 + "Apple is",
|
"Apple is red. Banana is Yellow. " * 800 + "Apple is",
|
||||||
@@ -95,9 +95,11 @@ class HFRunner:
|
|||||||
torch_dtype: torch.dtype,
|
torch_dtype: torch.dtype,
|
||||||
model_type: str = "generation",
|
model_type: str = "generation",
|
||||||
output_str_only: bool = False,
|
output_str_only: bool = False,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.output_str_only = output_str_only
|
self.output_str_only = output_str_only
|
||||||
|
self.trust_remote_code = trust_remote_code
|
||||||
|
|
||||||
self.in_queue = mp.Queue()
|
self.in_queue = mp.Queue()
|
||||||
self.out_queue = mp.Queue()
|
self.out_queue = mp.Queue()
|
||||||
@@ -130,7 +132,7 @@ class HFRunner:
|
|||||||
self.base_model = AutoModelForCausalLM.from_pretrained(
|
self.base_model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_path,
|
model_path,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
trust_remote_code=False,
|
trust_remote_code=self.trust_remote_code,
|
||||||
low_cpu_mem_usage=True,
|
low_cpu_mem_usage=True,
|
||||||
).cuda()
|
).cuda()
|
||||||
elif self.model_type == "embedding":
|
elif self.model_type == "embedding":
|
||||||
@@ -147,7 +149,11 @@ class HFRunner:
|
|||||||
).cuda()
|
).cuda()
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unrecognized model type {self.model_type}")
|
raise Exception(f"Unrecognized model type {self.model_type}")
|
||||||
self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype)
|
self.tokenizer = get_tokenizer(
|
||||||
|
model_path,
|
||||||
|
torch_dtype=torch.dtype,
|
||||||
|
trust_remote_code=self.trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
# Run forward
|
# Run forward
|
||||||
while True:
|
while True:
|
||||||
@@ -157,74 +163,15 @@ class HFRunner:
|
|||||||
|
|
||||||
if prompts is not None:
|
if prompts is not None:
|
||||||
if self.model_type == "generation":
|
if self.model_type == "generation":
|
||||||
output_strs = []
|
|
||||||
top_input_logprobs = []
|
|
||||||
top_output_logprobs = []
|
|
||||||
for i, p in enumerate(prompts):
|
|
||||||
if isinstance(p, str):
|
|
||||||
input_ids = self.tokenizer.encode(
|
|
||||||
p, return_tensors="pt"
|
|
||||||
).cuda()
|
|
||||||
else:
|
|
||||||
input_ids = torch.tensor([p], device="cuda")
|
|
||||||
|
|
||||||
if lora_paths is not None and lora_paths[i] is not None:
|
|
||||||
from peft import PeftModel
|
|
||||||
|
|
||||||
self.model = PeftModel.from_pretrained(
|
|
||||||
self.base_model,
|
|
||||||
lora_paths[i],
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
is_trainable=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.model = self.base_model
|
|
||||||
|
|
||||||
outputs = self.model.generate(
|
|
||||||
input_ids,
|
|
||||||
do_sample=False,
|
|
||||||
temperature=None,
|
|
||||||
top_p=None,
|
|
||||||
max_new_tokens=max_new_tokens,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
output_scores=(not self.output_str_only),
|
|
||||||
)
|
|
||||||
|
|
||||||
text = self.tokenizer.decode(
|
|
||||||
outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
|
|
||||||
)
|
|
||||||
# Check if the text is empty or only whitespace.
|
|
||||||
if not text.strip():
|
|
||||||
raise ValueError(
|
|
||||||
"Received an empty text response. Please verify your input or model configuration."
|
|
||||||
)
|
|
||||||
output_strs.append(text)
|
|
||||||
|
|
||||||
if not self.output_str_only:
|
|
||||||
# outputs.scores: (num_token, 1, vocab_size)
|
|
||||||
top_output_logprobs.append(
|
|
||||||
[
|
|
||||||
get_top_logprobs(
|
|
||||||
logits[0], NUM_TOP_LOGPROBS
|
|
||||||
).tolist()
|
|
||||||
for logits in outputs.scores
|
|
||||||
]
|
|
||||||
)
|
|
||||||
del outputs
|
|
||||||
|
|
||||||
input_logits = self.model.forward(input_ids).logits[0]
|
|
||||||
top_input_logprobs.append(
|
|
||||||
get_top_logprobs(
|
|
||||||
input_logits, NUM_TOP_LOGPROBS
|
|
||||||
).tolist()
|
|
||||||
)
|
|
||||||
del input_logits
|
|
||||||
|
|
||||||
out_queue.put(
|
out_queue.put(
|
||||||
ModelOutput(
|
self.forward_generation_raw(
|
||||||
output_strs=output_strs,
|
prompts=prompts,
|
||||||
top_input_logprobs=top_input_logprobs,
|
max_new_tokens=max_new_tokens,
|
||||||
top_output_logprobs=top_output_logprobs,
|
base_model=self.base_model,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
lora_paths=lora_paths,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
output_str_only=self.output_str_only,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -269,6 +216,79 @@ class HFRunner:
|
|||||||
self.model_proc.terminate()
|
self.model_proc.terminate()
|
||||||
self.in_queue = self.out_queue = None
|
self.in_queue = self.out_queue = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward_generation_raw(
|
||||||
|
prompts: Union[List[str], List[torch.Tensor]],
|
||||||
|
max_new_tokens,
|
||||||
|
base_model,
|
||||||
|
tokenizer,
|
||||||
|
lora_paths,
|
||||||
|
torch_dtype: torch.dtype,
|
||||||
|
output_str_only: bool,
|
||||||
|
) -> ModelOutput:
|
||||||
|
output_strs = []
|
||||||
|
top_input_logprobs = []
|
||||||
|
top_output_logprobs = []
|
||||||
|
for i, p in enumerate(prompts):
|
||||||
|
if isinstance(p, str):
|
||||||
|
input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
|
||||||
|
else:
|
||||||
|
input_ids = torch.tensor([p], device="cuda")
|
||||||
|
|
||||||
|
if lora_paths is not None and lora_paths[i] is not None:
|
||||||
|
from peft import PeftModel
|
||||||
|
|
||||||
|
model = PeftModel.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
lora_paths[i],
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
is_trainable=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model = base_model
|
||||||
|
|
||||||
|
outputs = model.generate(
|
||||||
|
input_ids,
|
||||||
|
do_sample=False,
|
||||||
|
temperature=None,
|
||||||
|
top_p=None,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
output_scores=(not output_str_only),
|
||||||
|
)
|
||||||
|
|
||||||
|
text = tokenizer.decode(
|
||||||
|
outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
|
||||||
|
)
|
||||||
|
# Check if the text is empty or only whitespace.
|
||||||
|
if not text.strip():
|
||||||
|
raise ValueError(
|
||||||
|
"Received an empty text response. Please verify your input or model configuration."
|
||||||
|
)
|
||||||
|
output_strs.append(text)
|
||||||
|
|
||||||
|
if not output_str_only:
|
||||||
|
# outputs.scores: (num_token, 1, vocab_size)
|
||||||
|
top_output_logprobs.append(
|
||||||
|
[
|
||||||
|
get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
|
||||||
|
for logits in outputs.scores
|
||||||
|
]
|
||||||
|
)
|
||||||
|
del outputs
|
||||||
|
|
||||||
|
input_logits = model.forward(input_ids).logits[0]
|
||||||
|
top_input_logprobs.append(
|
||||||
|
get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
|
||||||
|
)
|
||||||
|
del input_logits
|
||||||
|
|
||||||
|
return ModelOutput(
|
||||||
|
output_strs=output_strs,
|
||||||
|
top_input_logprobs=top_input_logprobs,
|
||||||
|
top_output_logprobs=top_output_logprobs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SRTRunner:
|
class SRTRunner:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -284,6 +304,7 @@ class SRTRunner:
|
|||||||
disable_cuda_graph: bool = False,
|
disable_cuda_graph: bool = False,
|
||||||
disable_radix_cache: bool = False,
|
disable_radix_cache: bool = False,
|
||||||
mem_fraction_static: float = 0.65,
|
mem_fraction_static: float = 0.65,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.is_generation = model_type == "generation"
|
self.is_generation = model_type == "generation"
|
||||||
@@ -293,7 +314,7 @@ class SRTRunner:
|
|||||||
dtype=get_dtype_str(torch_dtype),
|
dtype=get_dtype_str(torch_dtype),
|
||||||
port=port,
|
port=port,
|
||||||
mem_fraction_static=mem_fraction_static,
|
mem_fraction_static=mem_fraction_static,
|
||||||
trust_remote_code=False,
|
trust_remote_code=trust_remote_code,
|
||||||
is_embedding=not self.is_generation,
|
is_embedding=not self.is_generation,
|
||||||
lora_paths=lora_paths,
|
lora_paths=lora_paths,
|
||||||
max_loras_per_batch=max_loras_per_batch,
|
max_loras_per_batch=max_loras_per_batch,
|
||||||
@@ -301,7 +322,7 @@ class SRTRunner:
|
|||||||
disable_cuda_graph=disable_cuda_graph,
|
disable_cuda_graph=disable_cuda_graph,
|
||||||
disable_radix_cache=disable_radix_cache,
|
disable_radix_cache=disable_radix_cache,
|
||||||
)
|
)
|
||||||
self.tokenizer = get_tokenizer(model_path)
|
self.tokenizer = get_tokenizer(model_path, trust_remote_code=trust_remote_code)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -310,54 +331,11 @@ class SRTRunner:
|
|||||||
lora_paths=None,
|
lora_paths=None,
|
||||||
):
|
):
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
# the return value contains logprobs from prefill
|
return self.forward_generation_raw(
|
||||||
output_strs = []
|
prompts=prompts,
|
||||||
top_input_logprobs = []
|
max_new_tokens=max_new_tokens,
|
||||||
top_output_logprobs = []
|
lora_paths=lora_paths,
|
||||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
engine=self.engine,
|
||||||
for i, prompt in enumerate(prompts):
|
|
||||||
response = self.engine.generate(
|
|
||||||
prompt,
|
|
||||||
lora_path=lora_paths[i] if lora_paths else None,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
return_logprob=True,
|
|
||||||
logprob_start_len=0,
|
|
||||||
top_logprobs_num=NUM_TOP_LOGPROBS,
|
|
||||||
)
|
|
||||||
text = response["text"]
|
|
||||||
|
|
||||||
# Check if the text is empty or only whitespace.
|
|
||||||
if not text.strip():
|
|
||||||
raise ValueError(
|
|
||||||
"Received an empty text response. Please verify your input or model configuration."
|
|
||||||
)
|
|
||||||
output_strs.append(text)
|
|
||||||
|
|
||||||
top_input_logprobs.append(
|
|
||||||
[
|
|
||||||
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
|
||||||
for x in response["meta_info"]["input_top_logprobs"][1:]
|
|
||||||
]
|
|
||||||
+ [
|
|
||||||
[
|
|
||||||
tup[0]
|
|
||||||
for tup in response["meta_info"]["output_top_logprobs"][0][
|
|
||||||
:NUM_TOP_LOGPROBS
|
|
||||||
]
|
|
||||||
]
|
|
||||||
]
|
|
||||||
)
|
|
||||||
top_output_logprobs.append(
|
|
||||||
[
|
|
||||||
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
|
||||||
for x in response["meta_info"]["output_top_logprobs"]
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
return ModelOutput(
|
|
||||||
output_strs=output_strs,
|
|
||||||
top_input_logprobs=top_input_logprobs,
|
|
||||||
top_output_logprobs=top_output_logprobs,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = self.engine.encode(prompts)
|
response = self.engine.encode(prompts)
|
||||||
@@ -379,18 +357,11 @@ class SRTRunner:
|
|||||||
only return output strings and no logprobs
|
only return output strings and no logprobs
|
||||||
"""
|
"""
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
# the return value contains logprobs from prefill
|
return self.batch_forward_generation_raw(
|
||||||
output_strs = []
|
prompts=prompts,
|
||||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
max_new_tokens=max_new_tokens,
|
||||||
response = self.engine.generate(
|
lora_paths=lora_paths,
|
||||||
prompts,
|
engine=self.engine,
|
||||||
lora_path=lora_paths if lora_paths else None,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
)
|
|
||||||
output_strs = [r["text"] for r in response]
|
|
||||||
|
|
||||||
return ModelOutput(
|
|
||||||
output_strs=output_strs,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = self.engine.encode(prompts)
|
response = self.engine.encode(prompts)
|
||||||
@@ -408,6 +379,84 @@ class SRTRunner:
|
|||||||
self.engine.shutdown()
|
self.engine.shutdown()
|
||||||
del self.engine
|
del self.engine
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward_generation_raw(
|
||||||
|
prompts: Union[List[str], List[torch.Tensor]],
|
||||||
|
max_new_tokens,
|
||||||
|
lora_paths,
|
||||||
|
engine,
|
||||||
|
):
|
||||||
|
# the return value contains logprobs from prefill
|
||||||
|
output_strs = []
|
||||||
|
top_input_logprobs = []
|
||||||
|
top_output_logprobs = []
|
||||||
|
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||||
|
for i, prompt in enumerate(prompts):
|
||||||
|
response = engine.generate(
|
||||||
|
prompt,
|
||||||
|
lora_path=lora_paths[i] if lora_paths else None,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
return_logprob=True,
|
||||||
|
logprob_start_len=0,
|
||||||
|
top_logprobs_num=NUM_TOP_LOGPROBS,
|
||||||
|
)
|
||||||
|
text = response["text"]
|
||||||
|
|
||||||
|
# Check if the text is empty or only whitespace.
|
||||||
|
if not text.strip():
|
||||||
|
raise ValueError(
|
||||||
|
"Received an empty text response. Please verify your input or model configuration."
|
||||||
|
)
|
||||||
|
output_strs.append(text)
|
||||||
|
|
||||||
|
top_input_logprobs.append(
|
||||||
|
[
|
||||||
|
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||||
|
for x in response["meta_info"]["input_top_logprobs"][1:]
|
||||||
|
]
|
||||||
|
+ [
|
||||||
|
[
|
||||||
|
tup[0]
|
||||||
|
for tup in response["meta_info"]["output_top_logprobs"][0][
|
||||||
|
:NUM_TOP_LOGPROBS
|
||||||
|
]
|
||||||
|
]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
top_output_logprobs.append(
|
||||||
|
[
|
||||||
|
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||||
|
for x in response["meta_info"]["output_top_logprobs"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return ModelOutput(
|
||||||
|
output_strs=output_strs,
|
||||||
|
top_input_logprobs=top_input_logprobs,
|
||||||
|
top_output_logprobs=top_output_logprobs,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def batch_forward_generation_raw(
|
||||||
|
prompts: Union[List[str], List[torch.Tensor]],
|
||||||
|
max_new_tokens,
|
||||||
|
lora_paths,
|
||||||
|
engine,
|
||||||
|
):
|
||||||
|
# the return value contains logprobs from prefill
|
||||||
|
output_strs = []
|
||||||
|
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||||
|
response = engine.generate(
|
||||||
|
prompts,
|
||||||
|
lora_path=lora_paths if lora_paths else None,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
)
|
||||||
|
output_strs = [r["text"] for r in response]
|
||||||
|
|
||||||
|
return ModelOutput(
|
||||||
|
output_strs=output_strs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def monkey_patch_gemma2_sdpa():
|
def monkey_patch_gemma2_sdpa():
|
||||||
"""
|
"""
|
||||||
@@ -422,3 +471,52 @@ def monkey_patch_gemma2_sdpa():
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa)
|
setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa)
|
||||||
|
|
||||||
|
|
||||||
|
def check_close_model_outputs(
|
||||||
|
hf_outputs: ModelOutput,
|
||||||
|
srt_outputs: ModelOutput,
|
||||||
|
prefill_tolerance: float,
|
||||||
|
decode_tolerance: float,
|
||||||
|
rouge_l_tolerance: float,
|
||||||
|
debug_text: str = "",
|
||||||
|
check_logprobs: bool = True,
|
||||||
|
):
|
||||||
|
# Compare output strings
|
||||||
|
print(f"{hf_outputs.output_strs=}")
|
||||||
|
print(f"{srt_outputs.output_strs=}")
|
||||||
|
rouge_l_scores = calculate_rouge_l(hf_outputs.output_strs, srt_outputs.output_strs)
|
||||||
|
print(f"{rouge_l_scores=}")
|
||||||
|
assert all(
|
||||||
|
score >= rouge_l_tolerance for score in rouge_l_scores
|
||||||
|
), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"
|
||||||
|
|
||||||
|
if check_logprobs:
|
||||||
|
for i in range(len(hf_outputs.output_strs)):
|
||||||
|
# Compare input logprobs
|
||||||
|
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
|
||||||
|
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
|
||||||
|
input_len = hf_logprobs.shape[0]
|
||||||
|
print(
|
||||||
|
"prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
|
||||||
|
)
|
||||||
|
if input_len <= 100:
|
||||||
|
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
|
||||||
|
f"prefill logprobs are not all close with {debug_text} "
|
||||||
|
f"prefill_tolerance={prefill_tolerance}."
|
||||||
|
f"{hf_logprobs=}, {srt_logprobs=}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compare output logprobs
|
||||||
|
hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
|
||||||
|
srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
|
||||||
|
|
||||||
|
print(
|
||||||
|
"decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
|
||||||
|
)
|
||||||
|
if input_len <= 100:
|
||||||
|
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
|
||||||
|
f"decode logprobs are not all close with {debug_text} "
|
||||||
|
f"decode_tolerance={decode_tolerance}."
|
||||||
|
f"{hf_logprobs=}, {srt_logprobs=}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -536,7 +536,7 @@ def test_hellaswag_select():
|
|||||||
# Compute accuracy
|
# Compute accuracy
|
||||||
accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
|
accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
|
||||||
print(f"{accuracy=}, {accuracy_gen=}")
|
print(f"{accuracy=}, {accuracy_gen=}")
|
||||||
assert np.abs(accuracy_gen - accuracy) < 0.05
|
assert np.abs(accuracy_gen - accuracy) < 0.1
|
||||||
assert np.abs(latency_gen - latency) < 1
|
assert np.abs(latency_gen - latency) < 1
|
||||||
|
|
||||||
return accuracy, latency
|
return accuracy, latency
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ class TestSRTBackend(unittest.TestCase):
|
|||||||
# Run twice to capture more bugs
|
# Run twice to capture more bugs
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
accuracy, latency = test_hellaswag_select()
|
accuracy, latency = test_hellaswag_select()
|
||||||
self.assertGreater(accuracy, 0.69)
|
self.assertGreater(accuracy, 0.65)
|
||||||
|
|
||||||
def test_gen_min_new_tokens(self):
|
def test_gen_min_new_tokens(self):
|
||||||
test_gen_min_new_tokens()
|
test_gen_min_new_tokens()
|
||||||
|
|||||||
@@ -27,8 +27,13 @@ from typing import List
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
from sglang.test.runners import (
|
||||||
from sglang.test.test_utils import calculate_rouge_l, is_in_ci
|
DEFAULT_PROMPTS,
|
||||||
|
HFRunner,
|
||||||
|
SRTRunner,
|
||||||
|
check_close_model_outputs,
|
||||||
|
)
|
||||||
|
from sglang.test.test_utils import is_in_ci
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
@@ -39,6 +44,7 @@ class ModelCase:
|
|||||||
decode_tolerance: float = 5e-2
|
decode_tolerance: float = 5e-2
|
||||||
rouge_l_tolerance: float = 1
|
rouge_l_tolerance: float = 1
|
||||||
skip_long_prompt: bool = False
|
skip_long_prompt: bool = False
|
||||||
|
trust_remote_code: bool = False
|
||||||
|
|
||||||
|
|
||||||
# Popular models that run on the CI
|
# Popular models that run on the CI
|
||||||
@@ -53,7 +59,9 @@ ALL_OTHER_MODELS = [
|
|||||||
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
|
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
|
||||||
ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
|
ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
|
||||||
ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
|
ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
|
||||||
ModelCase("THUDM/glm-4-9b-chat"),
|
ModelCase(
|
||||||
|
"THUDM/glm-4-9b-chat", tp_size=2, trust_remote_code=True, skip_long_prompt=True
|
||||||
|
),
|
||||||
ModelCase("openai-community/gpt2"),
|
ModelCase("openai-community/gpt2"),
|
||||||
ModelCase("microsoft/Phi-3-small-8k-instruct"),
|
ModelCase("microsoft/Phi-3-small-8k-instruct"),
|
||||||
ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
|
ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
|
||||||
@@ -87,6 +95,7 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
model_path,
|
model_path,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
model_type="generation",
|
model_type="generation",
|
||||||
|
trust_remote_code=model_case.trust_remote_code,
|
||||||
) as hf_runner:
|
) as hf_runner:
|
||||||
hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
||||||
|
|
||||||
@@ -95,48 +104,18 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
tp_size=model_case.tp_size,
|
tp_size=model_case.tp_size,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
model_type="generation",
|
model_type="generation",
|
||||||
|
trust_remote_code=model_case.trust_remote_code,
|
||||||
) as srt_runner:
|
) as srt_runner:
|
||||||
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
||||||
|
|
||||||
for i in range(len(prompts)):
|
check_close_model_outputs(
|
||||||
# Compare input logprobs
|
hf_outputs=hf_outputs,
|
||||||
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
|
srt_outputs=srt_outputs,
|
||||||
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
|
prefill_tolerance=model_case.prefill_tolerance,
|
||||||
input_len = hf_logprobs.shape[0]
|
decode_tolerance=model_case.decode_tolerance,
|
||||||
print(
|
rouge_l_tolerance=model_case.rouge_l_tolerance,
|
||||||
"prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
|
debug_text=f"model_path={model_path} prompts={prompts}",
|
||||||
)
|
|
||||||
if input_len <= 100:
|
|
||||||
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
|
|
||||||
f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} "
|
|
||||||
f"prefill_tolerance={prefill_tolerance}."
|
|
||||||
f"{hf_logprobs=}, {srt_logprobs=}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compare output logprobs
|
|
||||||
hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
|
|
||||||
srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
|
|
||||||
|
|
||||||
print(
|
|
||||||
"decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
|
|
||||||
)
|
|
||||||
if input_len <= 100:
|
|
||||||
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
|
|
||||||
f"decode logprobs are not all close with model_path={model_path} prompts={prompts} "
|
|
||||||
f"decode_tolerance={decode_tolerance}."
|
|
||||||
f"{hf_logprobs=}, {srt_logprobs=}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compare output strings
|
|
||||||
print(f"{hf_outputs.output_strs=}")
|
|
||||||
print(f"{srt_outputs.output_strs=}")
|
|
||||||
rouge_l_scores = calculate_rouge_l(
|
|
||||||
hf_outputs.output_strs, srt_outputs.output_strs
|
|
||||||
)
|
)
|
||||||
print(f"{rouge_l_scores=}")
|
|
||||||
assert all(
|
|
||||||
score >= rouge_l_tolerance for score in rouge_l_scores
|
|
||||||
), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"
|
|
||||||
|
|
||||||
def test_ci_models(self):
|
def test_ci_models(self):
|
||||||
for model_case in CI_MODELS:
|
for model_case in CI_MODELS:
|
||||||
|
|||||||
@@ -26,6 +26,34 @@ class TestUpdateWeightsFromTensor(unittest.TestCase):
|
|||||||
|
|
||||||
engine.shutdown()
|
engine.shutdown()
|
||||||
|
|
||||||
|
def test_update_weights_from_tensor_load_format_direct(self):
|
||||||
|
engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||||
|
|
||||||
|
write_param_names = [
|
||||||
|
f"model.layers.{i}.self_attn.qkv_proj.weight" for i in range(6, 16)
|
||||||
|
]
|
||||||
|
read_param_names = [
|
||||||
|
f"model.layers.{i}.self_attn.k_proj.weight" for i in range(6, 16)
|
||||||
|
]
|
||||||
|
|
||||||
|
_check_param(
|
||||||
|
engine, read_param_names[0], [-0.0198, 0.0227, 0.0168, 0.0232, -0.0178]
|
||||||
|
)
|
||||||
|
|
||||||
|
new_tensor = torch.full((3072, 2048), 1.5)
|
||||||
|
engine.update_weights_from_tensor(
|
||||||
|
[
|
||||||
|
(write_param_name, new_tensor.clone())
|
||||||
|
for write_param_name in write_param_names
|
||||||
|
],
|
||||||
|
load_format="direct",
|
||||||
|
)
|
||||||
|
|
||||||
|
for read_param_name in read_param_names[:3]:
|
||||||
|
_check_param(engine, read_param_name, [1.5] * 5)
|
||||||
|
|
||||||
|
engine.shutdown()
|
||||||
|
|
||||||
|
|
||||||
def _check_param(engine, param_name, expect_values):
|
def _check_param(engine, param_name, expect_values):
|
||||||
actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]
|
actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]
|
||||||
|
|||||||
297
test/srt/test_verl_engine.py
Normal file
297
test/srt/test_verl_engine.py
Normal file
@@ -0,0 +1,297 @@
|
|||||||
|
import multiprocessing
|
||||||
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import traceback
|
||||||
|
import unittest
|
||||||
|
from multiprocessing import Process
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.distributed.device_mesh import init_device_mesh
|
||||||
|
from torch.distributed.fsdp import CPUOffload
|
||||||
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||||
|
from torch.distributed.fsdp import MixedPrecision
|
||||||
|
from torch.distributed.fsdp.api import (
|
||||||
|
ShardedStateDictConfig,
|
||||||
|
ShardingStrategy,
|
||||||
|
StateDictType,
|
||||||
|
)
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.verl_engine import VerlEngine
|
||||||
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
|
from sglang.srt.utils import is_port_available
|
||||||
|
from sglang.test.runners import (
|
||||||
|
HFRunner,
|
||||||
|
SRTRunner,
|
||||||
|
check_close_model_outputs,
|
||||||
|
get_dtype_str,
|
||||||
|
)
|
||||||
|
from sglang.test.test_utils import is_in_ci
|
||||||
|
|
||||||
|
_MAX_NEW_TOKENS = 8
|
||||||
|
_PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="]
|
||||||
|
_TORCH_DTYPE = torch.float16
|
||||||
|
|
||||||
|
# Set to false to temporarily debug issues unrelated to weight update
|
||||||
|
_ENABLE_UPDATE_WEIGHTS = True
|
||||||
|
# _ENABLE_UPDATE_WEIGHTS = False
|
||||||
|
|
||||||
|
# TODO maybe we should add more other models? should we keep it in sync with test_generation_models.py?
|
||||||
|
CI_MODELS = [
|
||||||
|
dict(model_path="meta-llama/Llama-3.1-8B-Instruct"),
|
||||||
|
dict(model_path="google/gemma-2-2b"),
|
||||||
|
]
|
||||||
|
ALL_OTHER_MODELS = [
|
||||||
|
dict(model_path="meta-llama/Llama-3.2-1B-Instruct"),
|
||||||
|
dict(model_path="Qwen/Qwen2-1.5B"),
|
||||||
|
dict(
|
||||||
|
model_path="Qwen/Qwen2.5-14B-Instruct",
|
||||||
|
mem_fraction_static=0.4,
|
||||||
|
tp_size=8,
|
||||||
|
tight_memory=True,
|
||||||
|
decode_tolerance=1.3,
|
||||||
|
), # test_generation_models.py same config (qwen + tp=8) gives 1.22 decode error
|
||||||
|
dict(model_path="HuggingFaceTB/SmolLM-135M-Instruct", tp_size=3),
|
||||||
|
dict(model_path="allenai/OLMo-1B-0724-hf"),
|
||||||
|
dict(
|
||||||
|
model_path="THUDM/glm-4-9b-chat",
|
||||||
|
mem_fraction_static=0.1,
|
||||||
|
tp_size=8,
|
||||||
|
tight_memory=True,
|
||||||
|
),
|
||||||
|
dict(model_path="allenai/OLMo-2-1124-7B-Instruct"),
|
||||||
|
dict(
|
||||||
|
model_path="ibm-granite/granite-3.0-2b-instruct",
|
||||||
|
prefill_tolerance=0.22,
|
||||||
|
decode_tolerance=0.22,
|
||||||
|
),
|
||||||
|
# Fail to run these models in test_generation_models.py, need to fix that first
|
||||||
|
# dict(model_path="openai-community/gpt2"),
|
||||||
|
# dict(model_path="microsoft/Phi-3-small-8k-instruct"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestVerlEngine(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
multiprocessing.set_start_method("spawn")
|
||||||
|
|
||||||
|
def assert_fragment_e2e_execution(
|
||||||
|
self,
|
||||||
|
index: int,
|
||||||
|
model_path: str,
|
||||||
|
mem_fraction_static: float = 0.4,
|
||||||
|
tp_size: int = 2,
|
||||||
|
tight_memory: bool = False,
|
||||||
|
prefill_tolerance: float = 0.1,
|
||||||
|
decode_tolerance: float = 0.1,
|
||||||
|
):
|
||||||
|
master_port = find_available_port(23456)
|
||||||
|
|
||||||
|
print(f"assert_fragment_e2e_execution START {index=} {model_path=}")
|
||||||
|
|
||||||
|
processes = []
|
||||||
|
output_reader, output_writer = mp.Pipe(duplex=False)
|
||||||
|
for tp_rank in range(tp_size):
|
||||||
|
p = Process(
|
||||||
|
target=_run_subprocess,
|
||||||
|
kwargs=dict(
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
tp_size=tp_size,
|
||||||
|
master_port=master_port,
|
||||||
|
output_writer=output_writer,
|
||||||
|
model_path=model_path,
|
||||||
|
mem_fraction_static=mem_fraction_static,
|
||||||
|
tight_memory=tight_memory,
|
||||||
|
prefill_tolerance=prefill_tolerance,
|
||||||
|
decode_tolerance=decode_tolerance,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p.start()
|
||||||
|
processes.append(p)
|
||||||
|
|
||||||
|
for _ in range(tp_size):
|
||||||
|
self.assertTrue(
|
||||||
|
output_reader.recv(),
|
||||||
|
f"Subprocess has error, please see logs above. ({index=} {model_path=})",
|
||||||
|
)
|
||||||
|
|
||||||
|
for p in processes:
|
||||||
|
p.join()
|
||||||
|
|
||||||
|
def test_ci_models(self):
|
||||||
|
for index, model_info in enumerate(CI_MODELS):
|
||||||
|
self.assert_fragment_e2e_execution(index=index, **model_info)
|
||||||
|
|
||||||
|
def test_others(self):
|
||||||
|
if is_in_ci():
|
||||||
|
return
|
||||||
|
|
||||||
|
for index, model_info in enumerate(ALL_OTHER_MODELS):
|
||||||
|
self.assert_fragment_e2e_execution(index=index, **model_info)
|
||||||
|
|
||||||
|
# def test_adhoc(self):
|
||||||
|
# self.assert_fragment_e2e_execution(index=0, model_path="meta-llama/Llama-3.2-1B-Instruct")
|
||||||
|
|
||||||
|
|
||||||
|
def _run_subprocess(
|
||||||
|
tp_rank: int,
|
||||||
|
tp_size: int,
|
||||||
|
master_port: int,
|
||||||
|
output_writer,
|
||||||
|
model_path: str,
|
||||||
|
mem_fraction_static: float,
|
||||||
|
tight_memory: bool,
|
||||||
|
prefill_tolerance: float,
|
||||||
|
decode_tolerance: float,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
print(f"subprocess[{tp_rank=}] Start {os.environ.get('CUDA_VISIBLE_DEVICES')=}")
|
||||||
|
|
||||||
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
|
os.environ["MASTER_PORT"] = str(master_port)
|
||||||
|
torch.distributed.init_process_group(rank=tp_rank, world_size=tp_size)
|
||||||
|
torch.cuda.set_device(tp_rank)
|
||||||
|
|
||||||
|
mesh_kwargs = dict(mesh_shape=(tp_size, 1), mesh_dim_names=["tp", "pp"])
|
||||||
|
inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
|
||||||
|
inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
|
||||||
|
print(
|
||||||
|
f"subprocess[{tp_rank=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# hf model is used for comparison
|
||||||
|
hf_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_path, torch_dtype=_TORCH_DTYPE, trust_remote_code=True
|
||||||
|
).cuda()
|
||||||
|
hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True)
|
||||||
|
|
||||||
|
hf_outputs = HFRunner.forward_generation_raw(
|
||||||
|
prompts=_PROMPTS,
|
||||||
|
max_new_tokens=_MAX_NEW_TOKENS,
|
||||||
|
base_model=hf_model,
|
||||||
|
tokenizer=hf_tokenizer,
|
||||||
|
lora_paths=None,
|
||||||
|
torch_dtype=_TORCH_DTYPE,
|
||||||
|
output_str_only=False,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"subprocess[{tp_rank=}] call hf.forward {hf_outputs=}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if _ENABLE_UPDATE_WEIGHTS:
|
||||||
|
if tight_memory:
|
||||||
|
hf_model.cpu()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# test update weights
|
||||||
|
print(f"subprocess[{tp_rank=}] get_fsdp_state_dict", flush=True)
|
||||||
|
fsdp_state_dict = _get_fsdp_state_dict(hf_model=hf_model, tp_size=tp_size)
|
||||||
|
|
||||||
|
engine = VerlEngine(
|
||||||
|
model_path=model_path,
|
||||||
|
load_format="dummy" if _ENABLE_UPDATE_WEIGHTS else "auto",
|
||||||
|
mem_fraction_static=mem_fraction_static,
|
||||||
|
random_seed=42,
|
||||||
|
trust_remote_code=True,
|
||||||
|
dtype=get_dtype_str(_TORCH_DTYPE),
|
||||||
|
device_mesh_cpu=inference_device_mesh_cpu["tp"],
|
||||||
|
)
|
||||||
|
print(f"subprocess[{tp_rank=}] {engine=}", flush=True)
|
||||||
|
|
||||||
|
if _ENABLE_UPDATE_WEIGHTS:
|
||||||
|
print(f"subprocess[{tp_rank=}] call update_weights_from_tensor", flush=True)
|
||||||
|
engine.update_weights_from_tensor(
|
||||||
|
[(k, v) for k, v in fsdp_state_dict.items()]
|
||||||
|
)
|
||||||
|
|
||||||
|
for enable_batch in [False, True]:
|
||||||
|
if enable_batch:
|
||||||
|
fn = SRTRunner.batch_forward_generation_raw
|
||||||
|
else:
|
||||||
|
fn = SRTRunner.forward_generation_raw
|
||||||
|
|
||||||
|
srt_outputs = fn(
|
||||||
|
prompts=_PROMPTS,
|
||||||
|
max_new_tokens=_MAX_NEW_TOKENS,
|
||||||
|
lora_paths=None,
|
||||||
|
engine=engine,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"subprocess[{tp_rank=}] call srt.forward {enable_batch=} {srt_outputs=}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
check_close_model_outputs(
|
||||||
|
hf_outputs=hf_outputs,
|
||||||
|
srt_outputs=srt_outputs,
|
||||||
|
prefill_tolerance=prefill_tolerance,
|
||||||
|
decode_tolerance=decode_tolerance,
|
||||||
|
rouge_l_tolerance=1,
|
||||||
|
check_logprobs=not enable_batch,
|
||||||
|
debug_text=f"{enable_batch=} {tp_rank=}",
|
||||||
|
)
|
||||||
|
|
||||||
|
execution_ok = True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"subprocess[{tp_rank=}] has error: {e}", flush=True)
|
||||||
|
traceback.print_exc()
|
||||||
|
execution_ok = False
|
||||||
|
|
||||||
|
output_writer.send(execution_ok)
|
||||||
|
output_writer.close()
|
||||||
|
|
||||||
|
engine.shutdown()
|
||||||
|
print(f"subprocess[{tp_rank=}] end", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py
|
||||||
|
def _get_fsdp_state_dict(hf_model, tp_size: int):
|
||||||
|
device_mesh = init_device_mesh(
|
||||||
|
"cuda", mesh_shape=(tp_size,), mesh_dim_names=["fsdp"]
|
||||||
|
)
|
||||||
|
|
||||||
|
mixed_precision = MixedPrecision(
|
||||||
|
param_dtype=torch.bfloat16,
|
||||||
|
reduce_dtype=torch.float32,
|
||||||
|
buffer_dtype=torch.float32,
|
||||||
|
)
|
||||||
|
fsdp_model = FSDP(
|
||||||
|
hf_model,
|
||||||
|
use_orig_params=True,
|
||||||
|
auto_wrap_policy=None,
|
||||||
|
device_id=torch.cuda.current_device(),
|
||||||
|
sharding_strategy=ShardingStrategy.FULL_SHARD,
|
||||||
|
mixed_precision=mixed_precision,
|
||||||
|
cpu_offload=CPUOffload(offload_params=False),
|
||||||
|
sync_module_states=False,
|
||||||
|
device_mesh=device_mesh,
|
||||||
|
)
|
||||||
|
print(f"{fsdp_model=}")
|
||||||
|
|
||||||
|
FSDP.set_state_dict_type(
|
||||||
|
fsdp_model,
|
||||||
|
state_dict_type=StateDictType.SHARDED_STATE_DICT,
|
||||||
|
state_dict_config=ShardedStateDictConfig(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return fsdp_model.state_dict()
|
||||||
|
|
||||||
|
|
||||||
|
# TODO Ask: this is extracted from PortArgs.init_new, is it allowed to extract it, i.e. touch that old code
|
||||||
|
def find_available_port(base_port: int):
|
||||||
|
port = base_port + random.randint(100, 1000)
|
||||||
|
while True:
|
||||||
|
if is_port_available(port):
|
||||||
|
return port
|
||||||
|
if port < 60000:
|
||||||
|
port += 42
|
||||||
|
else:
|
||||||
|
port -= 43
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user