[Feature] SPMD for SGLang + Verl (#3852)
This commit is contained in:
@@ -271,10 +271,18 @@ class Engine:
|
||||
self.tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||
)
|
||||
|
||||
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
|
||||
"""Update weights from distributed source."""
|
||||
def update_weights_from_tensor(
|
||||
self,
|
||||
named_tensors: List[Tuple[str, torch.Tensor]],
|
||||
load_format: Optional[str] = None,
|
||||
flush_cache: bool = True,
|
||||
):
|
||||
"""Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
|
||||
to avoid duplicated operations such as clearing cache."""
|
||||
obj = UpdateWeightsFromTensorReqInput(
|
||||
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
|
||||
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors),
|
||||
load_format=load_format,
|
||||
flush_cache=flush_cache,
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(
|
||||
@@ -384,7 +392,10 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
|
||||
)
|
||||
for tp_rank in tp_rank_range:
|
||||
reader, writer = mp.Pipe(duplex=False)
|
||||
gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
|
||||
gpu_id = (
|
||||
server_args.base_gpu_id
|
||||
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
|
||||
)
|
||||
proc = mp.Process(
|
||||
target=run_scheduler_process,
|
||||
args=(server_args, port_args, gpu_id, tp_rank, None, writer),
|
||||
|
||||
145
python/sglang/srt/entrypoints/verl_engine.py
Normal file
145
python/sglang/srt/entrypoints/verl_engine.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.tensor import DeviceMesh, DTensor
|
||||
|
||||
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
|
||||
from sglang.srt.server import Engine
|
||||
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
|
||||
|
||||
|
||||
class VerlEngine:
    """SPMD adapter that drives one sglang ``Engine`` across the ranks of a
    CPU device mesh.

    Only the first rank on each node constructs an ``Engine`` (all other
    ranks keep ``self._engine = None``); every rank still joins the
    collectives below, so generation results and weight updates stay
    consistent across the whole mesh.
    """

    def __init__(
        self,
        device_mesh_cpu: DeviceMesh,
        nnodes: int = 1,
        **kwargs,
    ):
        self._device_mesh_cpu = device_mesh_cpu
        self._tp_rank = device_mesh_cpu.get_local_rank()
        self._tp_size = device_mesh_cpu.size()
        # Mesh ranks are laid out node-major: ranks [k*tp_size_per_node,
        # (k+1)*tp_size_per_node) live on node k.
        tp_size_per_node = self._tp_size // nnodes
        node_rank = self._tp_rank // tp_size_per_node
        first_rank_in_node = self._tp_rank % tp_size_per_node == 0

        if first_rank_in_node:
            # NOTE(review): presumably prevents child processes launched by
            # non-zero ranks from blocking — confirm against Engine's launcher.
            os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
            self._engine = Engine(
                **kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
            )
        else:
            self._engine = None

        # Wait until every node's engine has finished starting up.
        dist.barrier(group=self._device_mesh_cpu.get_group())

    def generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict], Dict]] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        # The image input. It can be a file name, a url, or base64 encoded string.
        # See also python/sglang/srt/utils.py:load_image.
        image_data: Optional[Union[List[str], str]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        custom_logit_processor: Optional[Union[List[str], str]] = None,
    ) -> Dict:
        """
        The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
        Please refer to `GenerateReqInput` for the documentation.

        Only tp rank 0 actually issues the request; the result is then
        broadcast to all ranks so every caller gets the same output dict.
        """
        if self._tp_rank == 0:
            output = self._engine.generate(
                prompt=prompt,
                sampling_params=sampling_params,
                input_ids=input_ids,
                image_data=image_data,
                return_logprob=return_logprob,
                logprob_start_len=logprob_start_len,
                top_logprobs_num=top_logprobs_num,
                lora_path=lora_path,
                custom_logit_processor=custom_logit_processor,
            )
        else:
            output = None

        # Most naive implementation, can extract tensor and send via gloo if too slow
        # Broadcast from the mesh's first rank so every rank returns the output.
        [output] = broadcast_pyobj(
            data=[output],
            rank=self._tp_rank,
            dist_group=self._device_mesh_cpu.get_group(),
            src=self._device_mesh_cpu.mesh[0].item(),
        )

        return output

    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, torch.Tensor]],
        load_format: Optional[str] = None,
    ):
        """Push new weights (one tensor at a time) from all ranks into the engine.

        Each rank serializes its local shard (DTensors are materialized
        first), the shards are gathered on the mesh's first rank, and that
        rank hands the per-rank handles to the engine. The KV cache is
        flushed only after the last tensor to avoid redundant flushes.
        """
        # Most naive implementation, can optimize a lot if it is bottleneck
        for tensor_index, (name, tensor) in enumerate(named_tensors):
            serialized_tensor = MultiprocessingSerializer.serialize(
                _preprocess_tensor_for_update_weights(tensor)
            )

            if self._tp_rank == 0:
                gathered_serialized_tensors = [None for _ in range(self._tp_size)]
            else:
                gathered_serialized_tensors = None
            dist.gather_object(
                obj=serialized_tensor,
                object_gather_list=gathered_serialized_tensors,
                dst=self._device_mesh_cpu.mesh.tolist()[0],
                group=self._device_mesh_cpu.get_group(),
            )

            if self._tp_rank == 0:
                self._engine.update_weights_from_tensor(
                    named_tensors=[
                        (
                            name,
                            LocalSerializedTensor(values=gathered_serialized_tensors),
                        )
                    ],
                    load_format=load_format,
                    # Only flush once, after the final tensor has been applied.
                    flush_cache=tensor_index == len(named_tensors) - 1,
                )

    def release_memory_occupation(self):
        """Ask the engine (rank 0 only) to release its GPU memory."""
        if self._tp_rank == 0:
            self._engine.release_memory_occupation()

    def resume_memory_occupation(self):
        """Ask the engine (rank 0 only) to re-acquire its GPU memory."""
        if self._tp_rank == 0:
            self._engine.resume_memory_occupation()

    def shutdown(self):
        """Shut down the engine on every rank that owns one."""
        if self._engine is not None:
            self._engine.shutdown()
|
||||
|
||||
|
||||
def _preprocess_tensor_for_update_weights(tensor: torch.Tensor):
|
||||
if isinstance(tensor, DTensor):
|
||||
return tensor.full_tensor()
|
||||
return tensor
|
||||
@@ -121,7 +121,7 @@ class DataParallelController:
|
||||
args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
|
||||
)
|
||||
threads.append(thread)
|
||||
base_gpu_id += server_args.tp_size
|
||||
base_gpu_id += server_args.tp_size * server_args.gpu_id_step
|
||||
|
||||
# Free all sockets before starting the threads to launch TP workers
|
||||
for sock in sockets:
|
||||
@@ -177,7 +177,11 @@ class DataParallelController:
|
||||
rank_port_args.nccl_port = port_args.nccl_port
|
||||
|
||||
reader, writer = mp.Pipe(duplex=False)
|
||||
gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
|
||||
gpu_id = (
|
||||
server_args.base_gpu_id
|
||||
+ base_gpu_id
|
||||
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
|
||||
)
|
||||
proc = mp.Process(
|
||||
target=run_scheduler_process,
|
||||
args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
|
||||
|
||||
@@ -449,6 +449,8 @@ class UpdateWeightsFromDistributedReqOutput:
|
||||
@dataclass
|
||||
class UpdateWeightsFromTensorReqInput:
|
||||
serialized_named_tensors: bytes # indeed Dict[str, torch.Tensor]
|
||||
load_format: Optional[str]
|
||||
flush_cache: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1760,8 +1760,9 @@ class Scheduler:
|
||||
success, message = self.tp_worker.update_weights_from_tensor(recv_req)
|
||||
# TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
|
||||
if success:
|
||||
flash_cache_success = self.flush_cache()
|
||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||
if recv_req.flush_cache:
|
||||
flash_cache_success = self.flush_cache()
|
||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||
else:
|
||||
logger.error(message)
|
||||
return UpdateWeightsFromTensorReqOutput(success, message)
|
||||
|
||||
@@ -205,7 +205,10 @@ class TpModelWorker:
|
||||
|
||||
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
||||
success, message = self.model_runner.update_weights_from_tensor(
|
||||
MultiprocessingSerializer.deserialize(recv_req.serialized_named_tensors)
|
||||
named_tensors=MultiprocessingSerializer.deserialize(
|
||||
recv_req.serialized_named_tensors
|
||||
),
|
||||
load_format=recv_req.load_format,
|
||||
)
|
||||
return success, message
|
||||
|
||||
|
||||
@@ -17,7 +17,8 @@ import gc
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -56,10 +57,12 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader import get_model
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
from sglang.srt.utils import (
|
||||
MultiprocessingSerializer,
|
||||
enable_show_time_cost,
|
||||
get_available_gpu_memory,
|
||||
init_custom_process_group,
|
||||
@@ -514,8 +517,21 @@ class ModelRunner:
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
|
||||
self.model.load_weights(named_tensors)
|
||||
def update_weights_from_tensor(
|
||||
self,
|
||||
named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
|
||||
load_format: Optional[str] = None,
|
||||
):
|
||||
named_tensors = [
|
||||
(name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
|
||||
for name, tensor in named_tensors
|
||||
]
|
||||
if load_format == "direct":
|
||||
_model_load_weights_direct(self.model, named_tensors)
|
||||
elif load_format is None:
|
||||
self.model.load_weights(named_tensors)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown load_format={load_format}")
|
||||
return True, "Success"
|
||||
|
||||
def get_weights_by_name(
|
||||
@@ -836,3 +852,26 @@ class ModelRunner:
|
||||
if rope_scaling is None:
|
||||
return False
|
||||
return rope_scaling.get("type", None) == "mrope"
|
||||
|
||||
|
||||
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(model.named_parameters())
|
||||
for name, tensor in named_tensors:
|
||||
default_weight_loader(params_dict[name], tensor)
|
||||
|
||||
|
||||
def _unwrap_tensor(tensor, tp_rank):
|
||||
if isinstance(tensor, LocalSerializedTensor):
|
||||
return tensor.get(tp_rank)
|
||||
return tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class LocalSerializedTensor:
|
||||
"""torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
|
||||
The i-th element in the list corresponds to i-th rank's GPU."""
|
||||
|
||||
values: List[bytes]
|
||||
|
||||
def get(self, rank: int):
|
||||
return MultiprocessingSerializer.deserialize(self.values[rank])
|
||||
|
||||
@@ -336,12 +336,6 @@ class GemmaForCausalLM(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
unloaded_params = params_dict.keys() - loaded_params
|
||||
if unloaded_params:
|
||||
raise RuntimeError(
|
||||
"Some weights are not initialized from checkpoints: "
|
||||
f"{unloaded_params}"
|
||||
)
|
||||
|
||||
|
||||
EntryClass = GemmaForCausalLM
|
||||
|
||||
@@ -437,12 +437,5 @@ class Gemma2ForCausalLM(nn.Module):
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
|
||||
unloaded_params = params_dict.keys() - loaded_params
|
||||
if unloaded_params:
|
||||
raise RuntimeError(
|
||||
"Some weights are not initialized from checkpoints: "
|
||||
f"{unloaded_params}"
|
||||
)
|
||||
|
||||
|
||||
EntryClass = Gemma2ForCausalLM
|
||||
|
||||
@@ -82,6 +82,7 @@ class ServerArgs:
|
||||
dist_timeout: Optional[int] = None # timeout for torch.distributed
|
||||
download_dir: Optional[str] = None
|
||||
base_gpu_id: int = 0
|
||||
gpu_id_step: int = 1
|
||||
|
||||
# Logging
|
||||
log_level: str = "info"
|
||||
@@ -552,6 +553,12 @@ class ServerArgs:
|
||||
default=ServerArgs.base_gpu_id,
|
||||
help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-id-step",
|
||||
type=int,
|
||||
default=ServerArgs.gpu_id_step,
|
||||
help="The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,...",
|
||||
)
|
||||
|
||||
# Logging
|
||||
parser.add_argument(
|
||||
@@ -957,6 +964,7 @@ class ServerArgs:
|
||||
and (self.lora_paths is None or self.disable_radix_cache)
|
||||
), "compatibility of lora and cuda graph and radix attention is in progress"
|
||||
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
|
||||
assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
|
||||
|
||||
if isinstance(self.lora_paths, list):
|
||||
lora_paths = self.lora_paths
|
||||
|
||||
@@ -1386,7 +1386,6 @@ def get_ip() -> str:
|
||||
|
||||
|
||||
def get_open_port() -> int:
|
||||
|
||||
port = os.getenv("SGLANG_PORT")
|
||||
if port is not None:
|
||||
while True:
|
||||
|
||||
@@ -21,9 +21,9 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from sglang.srt.entrypoints.engine import Engine
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
|
||||
from sglang.srt.server import Engine
|
||||
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
|
||||
|
||||
DEFAULT_PROMPTS = [
|
||||
"Apple is red. Banana is Yellow. " * 800 + "Apple is",
|
||||
@@ -95,9 +95,11 @@ class HFRunner:
|
||||
torch_dtype: torch.dtype,
|
||||
model_type: str = "generation",
|
||||
output_str_only: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.model_type = model_type
|
||||
self.output_str_only = output_str_only
|
||||
self.trust_remote_code = trust_remote_code
|
||||
|
||||
self.in_queue = mp.Queue()
|
||||
self.out_queue = mp.Queue()
|
||||
@@ -130,7 +132,7 @@ class HFRunner:
|
||||
self.base_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=False,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
low_cpu_mem_usage=True,
|
||||
).cuda()
|
||||
elif self.model_type == "embedding":
|
||||
@@ -147,7 +149,11 @@ class HFRunner:
|
||||
).cuda()
|
||||
else:
|
||||
raise Exception(f"Unrecognized model type {self.model_type}")
|
||||
self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype)
|
||||
self.tokenizer = get_tokenizer(
|
||||
model_path,
|
||||
torch_dtype=torch.dtype,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
)
|
||||
|
||||
# Run forward
|
||||
while True:
|
||||
@@ -157,74 +163,15 @@ class HFRunner:
|
||||
|
||||
if prompts is not None:
|
||||
if self.model_type == "generation":
|
||||
output_strs = []
|
||||
top_input_logprobs = []
|
||||
top_output_logprobs = []
|
||||
for i, p in enumerate(prompts):
|
||||
if isinstance(p, str):
|
||||
input_ids = self.tokenizer.encode(
|
||||
p, return_tensors="pt"
|
||||
).cuda()
|
||||
else:
|
||||
input_ids = torch.tensor([p], device="cuda")
|
||||
|
||||
if lora_paths is not None and lora_paths[i] is not None:
|
||||
from peft import PeftModel
|
||||
|
||||
self.model = PeftModel.from_pretrained(
|
||||
self.base_model,
|
||||
lora_paths[i],
|
||||
torch_dtype=torch_dtype,
|
||||
is_trainable=False,
|
||||
)
|
||||
else:
|
||||
self.model = self.base_model
|
||||
|
||||
outputs = self.model.generate(
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
temperature=None,
|
||||
top_p=None,
|
||||
max_new_tokens=max_new_tokens,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=(not self.output_str_only),
|
||||
)
|
||||
|
||||
text = self.tokenizer.decode(
|
||||
outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
|
||||
)
|
||||
# Check if the text is empty or only whitespace.
|
||||
if not text.strip():
|
||||
raise ValueError(
|
||||
"Received an empty text response. Please verify your input or model configuration."
|
||||
)
|
||||
output_strs.append(text)
|
||||
|
||||
if not self.output_str_only:
|
||||
# outputs.scores: (num_token, 1, vocab_size)
|
||||
top_output_logprobs.append(
|
||||
[
|
||||
get_top_logprobs(
|
||||
logits[0], NUM_TOP_LOGPROBS
|
||||
).tolist()
|
||||
for logits in outputs.scores
|
||||
]
|
||||
)
|
||||
del outputs
|
||||
|
||||
input_logits = self.model.forward(input_ids).logits[0]
|
||||
top_input_logprobs.append(
|
||||
get_top_logprobs(
|
||||
input_logits, NUM_TOP_LOGPROBS
|
||||
).tolist()
|
||||
)
|
||||
del input_logits
|
||||
|
||||
out_queue.put(
|
||||
ModelOutput(
|
||||
output_strs=output_strs,
|
||||
top_input_logprobs=top_input_logprobs,
|
||||
top_output_logprobs=top_output_logprobs,
|
||||
self.forward_generation_raw(
|
||||
prompts=prompts,
|
||||
max_new_tokens=max_new_tokens,
|
||||
base_model=self.base_model,
|
||||
tokenizer=self.tokenizer,
|
||||
lora_paths=lora_paths,
|
||||
torch_dtype=torch_dtype,
|
||||
output_str_only=self.output_str_only,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -269,6 +216,79 @@ class HFRunner:
|
||||
self.model_proc.terminate()
|
||||
self.in_queue = self.out_queue = None
|
||||
|
||||
@staticmethod
def forward_generation_raw(
    prompts: Union[List[str], List[torch.Tensor]],
    max_new_tokens,
    base_model,
    tokenizer,
    lora_paths,
    torch_dtype: torch.dtype,
    output_str_only: bool,
) -> ModelOutput:
    """Greedy-decode each prompt with a HuggingFace model and collect results.

    For every prompt: optionally wrap ``base_model`` with the matching LoRA
    adapter, generate up to ``max_new_tokens`` tokens deterministically
    (``do_sample=False``), and record the decoded text. Unless
    ``output_str_only`` is set, also record top-``NUM_TOP_LOGPROBS`` logprobs
    for the prompt tokens (via an extra forward pass) and for each generated
    token (from ``outputs.scores``).

    Raises:
        ValueError: if the model produces an empty/whitespace-only text.
    """
    output_strs = []
    top_input_logprobs = []
    top_output_logprobs = []
    for i, p in enumerate(prompts):
        # A prompt is either a raw string or a pre-tokenized id sequence.
        if isinstance(p, str):
            input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
        else:
            input_ids = torch.tensor([p], device="cuda")

        if lora_paths is not None and lora_paths[i] is not None:
            from peft import PeftModel

            model = PeftModel.from_pretrained(
                base_model,
                lora_paths[i],
                torch_dtype=torch_dtype,
                is_trainable=False,
            )
        else:
            model = base_model

        outputs = model.generate(
            input_ids,
            do_sample=False,
            temperature=None,
            top_p=None,
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=True,
            output_scores=(not output_str_only),
        )

        # Decode only the newly generated suffix (skip the prompt tokens).
        text = tokenizer.decode(
            outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
        )
        # Check if the text is empty or only whitespace.
        if not text.strip():
            raise ValueError(
                "Received an empty text response. Please verify your input or model configuration."
            )
        output_strs.append(text)

        if not output_str_only:
            # outputs.scores: (num_token, 1, vocab_size)
            top_output_logprobs.append(
                [
                    get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
                    for logits in outputs.scores
                ]
            )
            del outputs

            # Separate forward pass over the prompt to get prefill logprobs.
            input_logits = model.forward(input_ids).logits[0]
            top_input_logprobs.append(
                get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
            )
            del input_logits

    return ModelOutput(
        output_strs=output_strs,
        top_input_logprobs=top_input_logprobs,
        top_output_logprobs=top_output_logprobs,
    )
|
||||
|
||||
|
||||
class SRTRunner:
|
||||
def __init__(
|
||||
@@ -284,6 +304,7 @@ class SRTRunner:
|
||||
disable_cuda_graph: bool = False,
|
||||
disable_radix_cache: bool = False,
|
||||
mem_fraction_static: float = 0.65,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.model_type = model_type
|
||||
self.is_generation = model_type == "generation"
|
||||
@@ -293,7 +314,7 @@ class SRTRunner:
|
||||
dtype=get_dtype_str(torch_dtype),
|
||||
port=port,
|
||||
mem_fraction_static=mem_fraction_static,
|
||||
trust_remote_code=False,
|
||||
trust_remote_code=trust_remote_code,
|
||||
is_embedding=not self.is_generation,
|
||||
lora_paths=lora_paths,
|
||||
max_loras_per_batch=max_loras_per_batch,
|
||||
@@ -301,7 +322,7 @@ class SRTRunner:
|
||||
disable_cuda_graph=disable_cuda_graph,
|
||||
disable_radix_cache=disable_radix_cache,
|
||||
)
|
||||
self.tokenizer = get_tokenizer(model_path)
|
||||
self.tokenizer = get_tokenizer(model_path, trust_remote_code=trust_remote_code)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -310,54 +331,11 @@ class SRTRunner:
|
||||
lora_paths=None,
|
||||
):
|
||||
if self.is_generation:
|
||||
# the return value contains logprobs from prefill
|
||||
output_strs = []
|
||||
top_input_logprobs = []
|
||||
top_output_logprobs = []
|
||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||
for i, prompt in enumerate(prompts):
|
||||
response = self.engine.generate(
|
||||
prompt,
|
||||
lora_path=lora_paths[i] if lora_paths else None,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=True,
|
||||
logprob_start_len=0,
|
||||
top_logprobs_num=NUM_TOP_LOGPROBS,
|
||||
)
|
||||
text = response["text"]
|
||||
|
||||
# Check if the text is empty or only whitespace.
|
||||
if not text.strip():
|
||||
raise ValueError(
|
||||
"Received an empty text response. Please verify your input or model configuration."
|
||||
)
|
||||
output_strs.append(text)
|
||||
|
||||
top_input_logprobs.append(
|
||||
[
|
||||
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||
for x in response["meta_info"]["input_top_logprobs"][1:]
|
||||
]
|
||||
+ [
|
||||
[
|
||||
tup[0]
|
||||
for tup in response["meta_info"]["output_top_logprobs"][0][
|
||||
:NUM_TOP_LOGPROBS
|
||||
]
|
||||
]
|
||||
]
|
||||
)
|
||||
top_output_logprobs.append(
|
||||
[
|
||||
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||
for x in response["meta_info"]["output_top_logprobs"]
|
||||
]
|
||||
)
|
||||
|
||||
return ModelOutput(
|
||||
output_strs=output_strs,
|
||||
top_input_logprobs=top_input_logprobs,
|
||||
top_output_logprobs=top_output_logprobs,
|
||||
return self.forward_generation_raw(
|
||||
prompts=prompts,
|
||||
max_new_tokens=max_new_tokens,
|
||||
lora_paths=lora_paths,
|
||||
engine=self.engine,
|
||||
)
|
||||
else:
|
||||
response = self.engine.encode(prompts)
|
||||
@@ -379,18 +357,11 @@ class SRTRunner:
|
||||
only return output strings and no logprobs
|
||||
"""
|
||||
if self.is_generation:
|
||||
# the return value contains logprobs from prefill
|
||||
output_strs = []
|
||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||
response = self.engine.generate(
|
||||
prompts,
|
||||
lora_path=lora_paths if lora_paths else None,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
output_strs = [r["text"] for r in response]
|
||||
|
||||
return ModelOutput(
|
||||
output_strs=output_strs,
|
||||
return self.batch_forward_generation_raw(
|
||||
prompts=prompts,
|
||||
max_new_tokens=max_new_tokens,
|
||||
lora_paths=lora_paths,
|
||||
engine=self.engine,
|
||||
)
|
||||
else:
|
||||
response = self.engine.encode(prompts)
|
||||
@@ -408,6 +379,84 @@ class SRTRunner:
|
||||
self.engine.shutdown()
|
||||
del self.engine
|
||||
|
||||
@staticmethod
def forward_generation_raw(
    prompts: Union[List[str], List[torch.Tensor]],
    max_new_tokens,
    lora_paths,
    engine,
):
    """Generate per-prompt with an SRT engine and collect text + top logprobs.

    Each prompt is sent individually with ``return_logprob=True``; the
    response's ``meta_info`` supplies the top-``NUM_TOP_LOGPROBS`` logprobs
    for both the prompt (prefill) and the generated (decode) tokens.

    Raises:
        ValueError: if the engine returns an empty/whitespace-only text.
    """
    # the return value contains logprobs from prefill
    output_strs = []
    top_input_logprobs = []
    top_output_logprobs = []
    sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
    for i, prompt in enumerate(prompts):
        response = engine.generate(
            prompt,
            lora_path=lora_paths[i] if lora_paths else None,
            sampling_params=sampling_params,
            return_logprob=True,
            logprob_start_len=0,
            top_logprobs_num=NUM_TOP_LOGPROBS,
        )
        text = response["text"]

        # Check if the text is empty or only whitespace.
        if not text.strip():
            raise ValueError(
                "Received an empty text response. Please verify your input or model configuration."
            )
        output_strs.append(text)

        # Prefill logprobs: drop the first position (no logprob for the very
        # first token), then append the first decode step so lengths line up
        # with the HF reference. Each entry keeps only the logprob values.
        top_input_logprobs.append(
            [
                [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
                for x in response["meta_info"]["input_top_logprobs"][1:]
            ]
            + [
                [
                    tup[0]
                    for tup in response["meta_info"]["output_top_logprobs"][0][
                        :NUM_TOP_LOGPROBS
                    ]
                ]
            ]
        )
        top_output_logprobs.append(
            [
                [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
                for x in response["meta_info"]["output_top_logprobs"]
            ]
        )

    return ModelOutput(
        output_strs=output_strs,
        top_input_logprobs=top_input_logprobs,
        top_output_logprobs=top_output_logprobs,
    )
|
||||
|
||||
@staticmethod
def batch_forward_generation_raw(
    prompts: Union[List[str], List[torch.Tensor]],
    max_new_tokens,
    lora_paths,
    engine,
):
    """Run one batched greedy generation and return only the output strings."""
    sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
    response = engine.generate(
        prompts,
        lora_path=lora_paths if lora_paths else None,
        sampling_params=sampling_params,
    )
    texts = [item["text"] for item in response]
    return ModelOutput(output_strs=texts)
|
||||
|
||||
|
||||
def monkey_patch_gemma2_sdpa():
|
||||
"""
|
||||
@@ -422,3 +471,52 @@ def monkey_patch_gemma2_sdpa():
|
||||
return config
|
||||
|
||||
setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa)
|
||||
|
||||
|
||||
def check_close_model_outputs(
    hf_outputs: ModelOutput,
    srt_outputs: ModelOutput,
    prefill_tolerance: float,
    decode_tolerance: float,
    rouge_l_tolerance: float,
    debug_text: str = "",
    check_logprobs: bool = True,
):
    """Assert that SRT outputs are close to the HF reference outputs.

    Text similarity is checked per-sample via ROUGE-L against
    ``rouge_l_tolerance``; if ``check_logprobs`` is set, prefill and decode
    top-logprobs are compared element-wise against ``prefill_tolerance`` /
    ``decode_tolerance``.
    """
    # Compare output strings
    print(f"{hf_outputs.output_strs=}")
    print(f"{srt_outputs.output_strs=}")
    rouge_l_scores = calculate_rouge_l(hf_outputs.output_strs, srt_outputs.output_strs)
    print(f"{rouge_l_scores=}")
    assert all(
        score >= rouge_l_tolerance for score in rouge_l_scores
    ), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"

    if check_logprobs:
        for i in range(len(hf_outputs.output_strs)):
            # Compare input logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
            input_len = hf_logprobs.shape[0]
            print(
                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
            # Strict comparison only for short prompts; long prompts
            # accumulate too much numeric drift to assert element-wise.
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
                    f"prefill logprobs are not all close with {debug_text} "
                    f"prefill_tolerance={prefill_tolerance}."
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )

            # Compare output logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])

            print(
                "decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
            # NOTE(review): decode comparison is also gated on the *prompt*
            # length (input_len), not the decode length — confirm intended.
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
                    f"decode logprobs are not all close with {debug_text} "
                    f"decode_tolerance={decode_tolerance}."
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )
|
||||
|
||||
@@ -536,7 +536,7 @@ def test_hellaswag_select():
|
||||
# Compute accuracy
|
||||
accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
|
||||
print(f"{accuracy=}, {accuracy_gen=}")
|
||||
assert np.abs(accuracy_gen - accuracy) < 0.05
|
||||
assert np.abs(accuracy_gen - accuracy) < 0.1
|
||||
assert np.abs(latency_gen - latency) < 1
|
||||
|
||||
return accuracy, latency
|
||||
|
||||
Reference in New Issue
Block a user