[Feature] SPMD for SGLang + Verl (#3852)

fzyzcjy
2025-03-01 01:53:10 +08:00
committed by GitHub
parent bac414ab53
commit e3e0bc50a9
19 changed files with 890 additions and 202 deletions

View File

@@ -271,10 +271,18 @@ class Engine:
self.tokenizer_manager.update_weights_from_distributed(obj, None)
)
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
"""Update weights from distributed source."""
def update_weights_from_tensor(
self,
named_tensors: List[Tuple[str, torch.Tensor]],
load_format: Optional[str] = None,
flush_cache: bool = True,
):
"""Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
to avoid duplicated operations such as clearing cache."""
obj = UpdateWeightsFromTensorReqInput(
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors),
load_format=load_format,
flush_cache=flush_cache,
)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
@@ -384,7 +392,10 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
)
for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False)
gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
gpu_id = (
server_args.base_gpu_id
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
)
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, None, writer),
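
The new `flush_cache` flag lets a caller batch many small weight updates and pay the cache flush only once. A minimal usage sketch, not part of the diff; the model path and tensor names/shapes are placeholders:

```python
# Sketch: push several tensors, flushing the KV cache only after the last one.
import torch

from sglang.srt.server import Engine

engine = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder

updates = [
    ("model.embed_tokens.weight", torch.randn(128256, 4096)),
    ("lm_head.weight", torch.randn(128256, 4096)),
]
for i, (name, tensor) in enumerate(updates):
    engine.update_weights_from_tensor(
        named_tensors=[(name, tensor)],
        flush_cache=(i == len(updates) - 1),  # flush once, on the final update
    )
engine.shutdown()
```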

View File

@@ -0,0 +1,145 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh, DTensor
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.server import Engine
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
class VerlEngine:
def __init__(
self,
device_mesh_cpu: DeviceMesh,
nnodes: int = 1,
**kwargs,
):
self._device_mesh_cpu = device_mesh_cpu
self._tp_rank = device_mesh_cpu.get_local_rank()
self._tp_size = device_mesh_cpu.size()
tp_size_per_node = self._tp_size // nnodes
node_rank = self._tp_rank // tp_size_per_node
first_rank_in_node = self._tp_rank % tp_size_per_node == 0
if first_rank_in_node:
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
self._engine = Engine(
**kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
)
else:
self._engine = None
dist.barrier(group=self._device_mesh_cpu.get_group())
def generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[str], str]] = None,
) -> Dict:
"""
The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
Please refer to `GenerateReqInput` for the documentation.
"""
if self._tp_rank == 0:
output = self._engine.generate(
prompt=prompt,
sampling_params=sampling_params,
input_ids=input_ids,
image_data=image_data,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
lora_path=lora_path,
custom_logit_processor=custom_logit_processor,
)
else:
output = None
# Most naive implementation, can extract tensor and send via gloo if too slow
[output] = broadcast_pyobj(
data=[output],
rank=self._tp_rank,
dist_group=self._device_mesh_cpu.get_group(),
src=self._device_mesh_cpu.mesh[0].item(),
)
return output
def update_weights_from_tensor(
self,
named_tensors: List[Tuple[str, torch.Tensor]],
load_format: Optional[str] = None,
):
# Most naive implementation, can optimize a lot if it is bottleneck
for tensor_index, (name, tensor) in enumerate(named_tensors):
serialized_tensor = MultiprocessingSerializer.serialize(
_preprocess_tensor_for_update_weights(tensor)
)
if self._tp_rank == 0:
gathered_serialized_tensors = [None for _ in range(self._tp_size)]
else:
gathered_serialized_tensors = None
dist.gather_object(
obj=serialized_tensor,
object_gather_list=gathered_serialized_tensors,
dst=self._device_mesh_cpu.mesh.tolist()[0],
group=self._device_mesh_cpu.get_group(),
)
if self._tp_rank == 0:
self._engine.update_weights_from_tensor(
named_tensors=[
(
name,
LocalSerializedTensor(values=gathered_serialized_tensors),
)
],
load_format=load_format,
flush_cache=tensor_index == len(named_tensors) - 1,
)
def release_memory_occupation(self):
if self._tp_rank == 0:
self._engine.release_memory_occupation()
def resume_memory_occupation(self):
if self._tp_rank == 0:
self._engine.resume_memory_occupation()
def shutdown(self):
if self._engine is not None:
self._engine.shutdown()
def _preprocess_tensor_for_update_weights(tensor: torch.Tensor):
if isinstance(tensor, DTensor):
return tensor.full_tensor()
return tensor
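
For context, a hedged sketch of how `VerlEngine` is meant to be driven: one process per TP rank (e.g. under `torchrun --nproc-per-node=<tp_size>`), each constructing the engine over a shared CPU device mesh. The import path for `VerlEngine` and the model path are assumptions:

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.entrypoints.verl_engine import VerlEngine  # path assumed

dist.init_process_group(backend="gloo")  # CPU group backing the device mesh
mesh = init_device_mesh("cpu", mesh_shape=(dist.get_world_size(),))

engine = VerlEngine(
    device_mesh_cpu=mesh,
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder
    mem_fraction_static=0.5,
)
# Every rank calls generate(); non-zero ranks receive the result via broadcast.
out = engine.generate(prompt="Hello", sampling_params={"max_new_tokens": 8})
if mesh.get_local_rank() == 0:
    print(out["text"])
engine.shutdown()
```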

View File

@@ -121,7 +121,7 @@ class DataParallelController:
args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
)
threads.append(thread)
base_gpu_id += server_args.tp_size
base_gpu_id += server_args.tp_size * server_args.gpu_id_step
# Free all sockets before starting the threads to launch TP workers
for sock in sockets:
@@ -177,7 +177,11 @@ class DataParallelController:
rank_port_args.nccl_port = port_args.nccl_port
reader, writer = mp.Pipe(duplex=False)
gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
gpu_id = (
server_args.base_gpu_id
+ base_gpu_id
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
)
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
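
A worked illustration (not from the diff) of the GPU-ID arithmetic above: with `base_gpu_id=0`, `tp_size_per_node=4`, and `gpu_id_step=2`, the four TP ranks land on GPUs 0, 2, 4, 6, leaving the odd-numbered devices free for a co-located process such as a trainer.

```python
base_gpu_id, tp_size_per_node, gpu_id_step = 0, 4, 2
for tp_rank in range(tp_size_per_node):
    gpu_id = base_gpu_id + (tp_rank % tp_size_per_node) * gpu_id_step
    print(f"tp_rank={tp_rank} -> gpu_id={gpu_id}")  # 0, 2, 4, 6
```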

View File

@@ -449,6 +449,8 @@ class UpdateWeightsFromDistributedReqOutput:
@dataclass
class UpdateWeightsFromTensorReqInput:
serialized_named_tensors: bytes  # a serialized List[Tuple[str, torch.Tensor]]
load_format: Optional[str]
flush_cache: bool
@dataclass

View File

@@ -1760,8 +1760,9 @@ class Scheduler:
success, message = self.tp_worker.update_weights_from_tensor(recv_req)
# TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
if success:
flash_cache_success = self.flush_cache()
assert flash_cache_success, "Cache flush failed after updating weights"
if recv_req.flush_cache:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return UpdateWeightsFromTensorReqOutput(success, message)

View File

@@ -205,7 +205,10 @@ class TpModelWorker:
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
success, message = self.model_runner.update_weights_from_tensor(
MultiprocessingSerializer.deserialize(recv_req.serialized_named_tensors)
named_tensors=MultiprocessingSerializer.deserialize(
recv_req.serialized_named_tensors
),
load_format=recv_req.load_format,
)
return success, message

View File

@@ -17,7 +17,8 @@ import gc
import json
import logging
import time
from typing import List, Optional, Tuple
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
@@ -56,10 +57,12 @@ from sglang.srt.mem_cache.memory_pool import (
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
MultiprocessingSerializer,
enable_show_time_cost,
get_available_gpu_memory,
init_custom_process_group,
@@ -514,8 +517,21 @@ class ModelRunner:
logger.error(error_msg)
return False, error_msg
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
self.model.load_weights(named_tensors)
def update_weights_from_tensor(
self,
named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
load_format: Optional[str] = None,
):
named_tensors = [
(name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
for name, tensor in named_tensors
]
if load_format == "direct":
_model_load_weights_direct(self.model, named_tensors)
elif load_format is None:
self.model.load_weights(named_tensors)
else:
raise NotImplementedError(f"Unknown load_format={load_format}")
return True, "Success"
def get_weights_by_name(
@@ -836,3 +852,26 @@ class ModelRunner:
if rope_scaling is None:
return False
return rope_scaling.get("type", None) == "mrope"
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
params_dict = dict(model.named_parameters())
for name, tensor in named_tensors:
default_weight_loader(params_dict[name], tensor)
def _unwrap_tensor(tensor, tp_rank):
if isinstance(tensor, LocalSerializedTensor):
return tensor.get(tp_rank)
return tensor
@dataclass
class LocalSerializedTensor:
"""torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
The i-th element in the list corresponds to i-th rank's GPU."""
values: List[bytes]
def get(self, rank: int):
return MultiprocessingSerializer.deserialize(self.values[rank])
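
A small sketch of the `LocalSerializedTensor` round trip that `update_weights_from_tensor` now relies on. Assumptions: two TP ranks, and CPU tensors standing in for the real per-rank CUDA shards:

```python
# Rank r serializes its local shard; rank 0 gathers the raw bytes into a
# LocalSerializedTensor; each scheduler process later unwraps only its own entry.
import torch

from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.utils import MultiprocessingSerializer

shards = [torch.zeros(2), torch.ones(2)]  # one stand-in shard per TP rank
packed = LocalSerializedTensor(
    values=[MultiprocessingSerializer.serialize(t) for t in shards]
)
# _unwrap_tensor(packed, tp_rank=1) in ModelRunner resolves to rank 1's shard:
assert torch.equal(packed.get(1), shards[1])
```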

View File

@@ -336,12 +336,6 @@ class GemmaForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
"Some weights are not initialized from checkpoints: "
f"{unloaded_params}"
)
EntryClass = GemmaForCausalLM

View File

@@ -437,12 +437,5 @@ class Gemma2ForCausalLM(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
"Some weights are not initialized from checkpoints: "
f"{unloaded_params}"
)
EntryClass = Gemma2ForCausalLM

View File

@@ -82,6 +82,7 @@ class ServerArgs:
dist_timeout: Optional[int] = None # timeout for torch.distributed
download_dir: Optional[str] = None
base_gpu_id: int = 0
gpu_id_step: int = 1
# Logging
log_level: str = "info"
@@ -552,6 +553,12 @@ class ServerArgs:
default=ServerArgs.base_gpu_id,
help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
)
parser.add_argument(
"--gpu-id-step",
type=int,
default=ServerArgs.gpu_id_step,
help="The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,...",
)
# Logging
parser.add_argument(
@@ -957,6 +964,7 @@ class ServerArgs:
and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and cuda graph and radix attention is in progress"
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths

View File

@@ -1386,7 +1386,6 @@ def get_ip() -> str:
def get_open_port() -> int:
port = os.getenv("SGLANG_PORT")
if port is not None:
while True:

View File

@@ -21,9 +21,9 @@ import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
from sglang.srt.server import Engine
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
DEFAULT_PROMPTS = [
"Apple is red. Banana is Yellow. " * 800 + "Apple is",
@@ -95,9 +95,11 @@ class HFRunner:
torch_dtype: torch.dtype,
model_type: str = "generation",
output_str_only: bool = False,
trust_remote_code: bool = False,
):
self.model_type = model_type
self.output_str_only = output_str_only
self.trust_remote_code = trust_remote_code
self.in_queue = mp.Queue()
self.out_queue = mp.Queue()
@@ -130,7 +132,7 @@ class HFRunner:
self.base_model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
trust_remote_code=False,
trust_remote_code=self.trust_remote_code,
low_cpu_mem_usage=True,
).cuda()
elif self.model_type == "embedding":
@@ -147,7 +149,11 @@ class HFRunner:
).cuda()
else:
raise Exception(f"Unrecognized model type {self.model_type}")
self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype)
self.tokenizer = get_tokenizer(
model_path,
torch_dtype=torch_dtype,
trust_remote_code=self.trust_remote_code,
)
# Run forward
while True:
@@ -157,74 +163,15 @@ class HFRunner:
if prompts is not None:
if self.model_type == "generation":
output_strs = []
top_input_logprobs = []
top_output_logprobs = []
for i, p in enumerate(prompts):
if isinstance(p, str):
input_ids = self.tokenizer.encode(
p, return_tensors="pt"
).cuda()
else:
input_ids = torch.tensor([p], device="cuda")
if lora_paths is not None and lora_paths[i] is not None:
from peft import PeftModel
self.model = PeftModel.from_pretrained(
self.base_model,
lora_paths[i],
torch_dtype=torch_dtype,
is_trainable=False,
)
else:
self.model = self.base_model
outputs = self.model.generate(
input_ids,
do_sample=False,
temperature=None,
top_p=None,
max_new_tokens=max_new_tokens,
return_dict_in_generate=True,
output_scores=(not self.output_str_only),
)
text = self.tokenizer.decode(
outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
)
# Check if the text is empty or only whitespace.
if not text.strip():
raise ValueError(
"Received an empty text response. Please verify your input or model configuration."
)
output_strs.append(text)
if not self.output_str_only:
# outputs.scores: (num_token, 1, vocab_size)
top_output_logprobs.append(
[
get_top_logprobs(
logits[0], NUM_TOP_LOGPROBS
).tolist()
for logits in outputs.scores
]
)
del outputs
input_logits = self.model.forward(input_ids).logits[0]
top_input_logprobs.append(
get_top_logprobs(
input_logits, NUM_TOP_LOGPROBS
).tolist()
)
del input_logits
out_queue.put(
ModelOutput(
output_strs=output_strs,
top_input_logprobs=top_input_logprobs,
top_output_logprobs=top_output_logprobs,
self.forward_generation_raw(
prompts=prompts,
max_new_tokens=max_new_tokens,
base_model=self.base_model,
tokenizer=self.tokenizer,
lora_paths=lora_paths,
torch_dtype=torch_dtype,
output_str_only=self.output_str_only,
)
)
@@ -269,6 +216,79 @@ class HFRunner:
self.model_proc.terminate()
self.in_queue = self.out_queue = None
@staticmethod
def forward_generation_raw(
prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens,
base_model,
tokenizer,
lora_paths,
torch_dtype: torch.dtype,
output_str_only: bool,
) -> ModelOutput:
output_strs = []
top_input_logprobs = []
top_output_logprobs = []
for i, p in enumerate(prompts):
if isinstance(p, str):
input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
else:
input_ids = torch.tensor([p], device="cuda")
if lora_paths is not None and lora_paths[i] is not None:
from peft import PeftModel
model = PeftModel.from_pretrained(
base_model,
lora_paths[i],
torch_dtype=torch_dtype,
is_trainable=False,
)
else:
model = base_model
outputs = model.generate(
input_ids,
do_sample=False,
temperature=None,
top_p=None,
max_new_tokens=max_new_tokens,
return_dict_in_generate=True,
output_scores=(not output_str_only),
)
text = tokenizer.decode(
outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
)
# Check if the text is empty or only whitespace.
if not text.strip():
raise ValueError(
"Received an empty text response. Please verify your input or model configuration."
)
output_strs.append(text)
if not output_str_only:
# outputs.scores: (num_token, 1, vocab_size)
top_output_logprobs.append(
[
get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
for logits in outputs.scores
]
)
del outputs
input_logits = model.forward(input_ids).logits[0]
top_input_logprobs.append(
get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
)
del input_logits
return ModelOutput(
output_strs=output_strs,
top_input_logprobs=top_input_logprobs,
top_output_logprobs=top_output_logprobs,
)
class SRTRunner:
def __init__(
@@ -284,6 +304,7 @@ class SRTRunner:
disable_cuda_graph: bool = False,
disable_radix_cache: bool = False,
mem_fraction_static: float = 0.65,
trust_remote_code: bool = False,
):
self.model_type = model_type
self.is_generation = model_type == "generation"
@@ -293,7 +314,7 @@ class SRTRunner:
dtype=get_dtype_str(torch_dtype),
port=port,
mem_fraction_static=mem_fraction_static,
trust_remote_code=False,
trust_remote_code=trust_remote_code,
is_embedding=not self.is_generation,
lora_paths=lora_paths,
max_loras_per_batch=max_loras_per_batch,
@@ -301,7 +322,7 @@ class SRTRunner:
disable_cuda_graph=disable_cuda_graph,
disable_radix_cache=disable_radix_cache,
)
self.tokenizer = get_tokenizer(model_path)
self.tokenizer = get_tokenizer(model_path, trust_remote_code=trust_remote_code)
def forward(
self,
@@ -310,54 +331,11 @@ class SRTRunner:
lora_paths=None,
):
if self.is_generation:
# the return value contains logprobs from prefill
output_strs = []
top_input_logprobs = []
top_output_logprobs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
for i, prompt in enumerate(prompts):
response = self.engine.generate(
prompt,
lora_path=lora_paths[i] if lora_paths else None,
sampling_params=sampling_params,
return_logprob=True,
logprob_start_len=0,
top_logprobs_num=NUM_TOP_LOGPROBS,
)
text = response["text"]
# Check if the text is empty or only whitespace.
if not text.strip():
raise ValueError(
"Received an empty text response. Please verify your input or model configuration."
)
output_strs.append(text)
top_input_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["input_top_logprobs"][1:]
]
+ [
[
tup[0]
for tup in response["meta_info"]["output_top_logprobs"][0][
:NUM_TOP_LOGPROBS
]
]
]
)
top_output_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["output_top_logprobs"]
]
)
return ModelOutput(
output_strs=output_strs,
top_input_logprobs=top_input_logprobs,
top_output_logprobs=top_output_logprobs,
return self.forward_generation_raw(
prompts=prompts,
max_new_tokens=max_new_tokens,
lora_paths=lora_paths,
engine=self.engine,
)
else:
response = self.engine.encode(prompts)
@@ -379,18 +357,11 @@ class SRTRunner:
only return output strings and no logprobs
"""
if self.is_generation:
# the return value contains logprobs from prefill
output_strs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
response = self.engine.generate(
prompts,
lora_path=lora_paths if lora_paths else None,
sampling_params=sampling_params,
)
output_strs = [r["text"] for r in response]
return ModelOutput(
output_strs=output_strs,
return self.batch_forward_generation_raw(
prompts=prompts,
max_new_tokens=max_new_tokens,
lora_paths=lora_paths,
engine=self.engine,
)
else:
response = self.engine.encode(prompts)
@@ -408,6 +379,84 @@ class SRTRunner:
self.engine.shutdown()
del self.engine
@staticmethod
def forward_generation_raw(
prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens,
lora_paths,
engine,
):
# the return value contains logprobs from prefill
output_strs = []
top_input_logprobs = []
top_output_logprobs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
for i, prompt in enumerate(prompts):
response = engine.generate(
prompt,
lora_path=lora_paths[i] if lora_paths else None,
sampling_params=sampling_params,
return_logprob=True,
logprob_start_len=0,
top_logprobs_num=NUM_TOP_LOGPROBS,
)
text = response["text"]
# Check if the text is empty or only whitespace.
if not text.strip():
raise ValueError(
"Received an empty text response. Please verify your input or model configuration."
)
output_strs.append(text)
top_input_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["input_top_logprobs"][1:]
]
+ [
[
tup[0]
for tup in response["meta_info"]["output_top_logprobs"][0][
:NUM_TOP_LOGPROBS
]
]
]
)
top_output_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["output_top_logprobs"]
]
)
return ModelOutput(
output_strs=output_strs,
top_input_logprobs=top_input_logprobs,
top_output_logprobs=top_output_logprobs,
)
@staticmethod
def batch_forward_generation_raw(
prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens,
lora_paths,
engine,
):
# the return value contains logprobs from prefill
output_strs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
response = engine.generate(
prompts,
lora_path=lora_paths if lora_paths else None,
sampling_params=sampling_params,
)
output_strs = [r["text"] for r in response]
return ModelOutput(
output_strs=output_strs,
)
def monkey_patch_gemma2_sdpa():
"""
@@ -422,3 +471,52 @@ def monkey_patch_gemma2_sdpa():
return config
setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa)
def check_close_model_outputs(
hf_outputs: ModelOutput,
srt_outputs: ModelOutput,
prefill_tolerance: float,
decode_tolerance: float,
rouge_l_tolerance: float,
debug_text: str = "",
check_logprobs: bool = True,
):
# Compare output strings
print(f"{hf_outputs.output_strs=}")
print(f"{srt_outputs.output_strs=}")
rouge_l_scores = calculate_rouge_l(hf_outputs.output_strs, srt_outputs.output_strs)
print(f"{rouge_l_scores=}")
assert all(
score >= rouge_l_tolerance for score in rouge_l_scores
), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"
if check_logprobs:
for i in range(len(hf_outputs.output_strs)):
# Compare input logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
input_len = hf_logprobs.shape[0]
print(
"prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
)
if input_len <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
f"prefill logprobs are not all close with {debug_text} "
f"prefill_tolerance={prefill_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
# Compare output logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
print(
"decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
)
if input_len <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
f"decode logprobs are not all close with {debug_text} "
f"decode_tolerance={decode_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
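
The refactor above extracts the shared HF/SRT generation paths into static helpers and adds `check_close_model_outputs`. A hedged sketch of how a test might compose them, assuming both runners support use as context managers and that they live in `sglang.test.runners`; the model name and tolerances are illustrative:

```python
import torch

from sglang.test.runners import HFRunner, SRTRunner, check_close_model_outputs

prompts = ["The capital of France is"]
with HFRunner(model_path="Qwen/Qwen2.5-0.5B-Instruct",
              torch_dtype=torch.float16) as hf_runner:
    hf_out = hf_runner.forward(prompts, max_new_tokens=8)
with SRTRunner(model_path="Qwen/Qwen2.5-0.5B-Instruct",
               torch_dtype=torch.float16,
               model_type="generation") as srt_runner:
    srt_out = srt_runner.forward(prompts, max_new_tokens=8)

# Compares output strings via ROUGE-L and prefill/decode top logprobs elementwise.
check_close_model_outputs(
    hf_outputs=hf_out,
    srt_outputs=srt_out,
    prefill_tolerance=0.2,
    decode_tolerance=0.2,
    rouge_l_tolerance=1.0,
)
```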

View File

@@ -536,7 +536,7 @@ def test_hellaswag_select():
# Compute accuracy
accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
print(f"{accuracy=}, {accuracy_gen=}")
assert np.abs(accuracy_gen - accuracy) < 0.05
assert np.abs(accuracy_gen - accuracy) < 0.1
assert np.abs(latency_gen - latency) < 1
return accuracy, latency