[Feature] SPMD for SGLang + Verl (#3852)
This commit is contained in:
@@ -21,9 +21,9 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from sglang.srt.entrypoints.engine import Engine
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
|
||||
from sglang.srt.server import Engine
|
||||
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
|
||||
|
||||
DEFAULT_PROMPTS = [
|
||||
"Apple is red. Banana is Yellow. " * 800 + "Apple is",
|
||||
@@ -95,9 +95,11 @@ class HFRunner:
|
||||
torch_dtype: torch.dtype,
|
||||
model_type: str = "generation",
|
||||
output_str_only: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.model_type = model_type
|
||||
self.output_str_only = output_str_only
|
||||
self.trust_remote_code = trust_remote_code
|
||||
|
||||
self.in_queue = mp.Queue()
|
||||
self.out_queue = mp.Queue()
|
||||
@@ -130,7 +132,7 @@ class HFRunner:
|
||||
self.base_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=False,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
low_cpu_mem_usage=True,
|
||||
).cuda()
|
||||
elif self.model_type == "embedding":
|
||||
@@ -147,7 +149,11 @@ class HFRunner:
|
||||
).cuda()
|
||||
else:
|
||||
raise Exception(f"Unrecognized model type {self.model_type}")
|
||||
self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype)
|
||||
self.tokenizer = get_tokenizer(
|
||||
model_path,
|
||||
torch_dtype=torch.dtype,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
)
|
||||
|
||||
# Run forward
|
||||
while True:
|
||||
@@ -157,74 +163,15 @@ class HFRunner:
|
||||
|
||||
if prompts is not None:
|
||||
if self.model_type == "generation":
|
||||
output_strs = []
|
||||
top_input_logprobs = []
|
||||
top_output_logprobs = []
|
||||
for i, p in enumerate(prompts):
|
||||
if isinstance(p, str):
|
||||
input_ids = self.tokenizer.encode(
|
||||
p, return_tensors="pt"
|
||||
).cuda()
|
||||
else:
|
||||
input_ids = torch.tensor([p], device="cuda")
|
||||
|
||||
if lora_paths is not None and lora_paths[i] is not None:
|
||||
from peft import PeftModel
|
||||
|
||||
self.model = PeftModel.from_pretrained(
|
||||
self.base_model,
|
||||
lora_paths[i],
|
||||
torch_dtype=torch_dtype,
|
||||
is_trainable=False,
|
||||
)
|
||||
else:
|
||||
self.model = self.base_model
|
||||
|
||||
outputs = self.model.generate(
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
temperature=None,
|
||||
top_p=None,
|
||||
max_new_tokens=max_new_tokens,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=(not self.output_str_only),
|
||||
)
|
||||
|
||||
text = self.tokenizer.decode(
|
||||
outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
|
||||
)
|
||||
# Check if the text is empty or only whitespace.
|
||||
if not text.strip():
|
||||
raise ValueError(
|
||||
"Received an empty text response. Please verify your input or model configuration."
|
||||
)
|
||||
output_strs.append(text)
|
||||
|
||||
if not self.output_str_only:
|
||||
# outputs.scores: (num_token, 1, vocab_size)
|
||||
top_output_logprobs.append(
|
||||
[
|
||||
get_top_logprobs(
|
||||
logits[0], NUM_TOP_LOGPROBS
|
||||
).tolist()
|
||||
for logits in outputs.scores
|
||||
]
|
||||
)
|
||||
del outputs
|
||||
|
||||
input_logits = self.model.forward(input_ids).logits[0]
|
||||
top_input_logprobs.append(
|
||||
get_top_logprobs(
|
||||
input_logits, NUM_TOP_LOGPROBS
|
||||
).tolist()
|
||||
)
|
||||
del input_logits
|
||||
|
||||
out_queue.put(
|
||||
ModelOutput(
|
||||
output_strs=output_strs,
|
||||
top_input_logprobs=top_input_logprobs,
|
||||
top_output_logprobs=top_output_logprobs,
|
||||
self.forward_generation_raw(
|
||||
prompts=prompts,
|
||||
max_new_tokens=max_new_tokens,
|
||||
base_model=self.base_model,
|
||||
tokenizer=self.tokenizer,
|
||||
lora_paths=lora_paths,
|
||||
torch_dtype=torch_dtype,
|
||||
output_str_only=self.output_str_only,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -269,6 +216,79 @@ class HFRunner:
|
||||
self.model_proc.terminate()
|
||||
self.in_queue = self.out_queue = None
|
||||
|
||||
@staticmethod
|
||||
def forward_generation_raw(
|
||||
prompts: Union[List[str], List[torch.Tensor]],
|
||||
max_new_tokens,
|
||||
base_model,
|
||||
tokenizer,
|
||||
lora_paths,
|
||||
torch_dtype: torch.dtype,
|
||||
output_str_only: bool,
|
||||
) -> ModelOutput:
|
||||
output_strs = []
|
||||
top_input_logprobs = []
|
||||
top_output_logprobs = []
|
||||
for i, p in enumerate(prompts):
|
||||
if isinstance(p, str):
|
||||
input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
|
||||
else:
|
||||
input_ids = torch.tensor([p], device="cuda")
|
||||
|
||||
if lora_paths is not None and lora_paths[i] is not None:
|
||||
from peft import PeftModel
|
||||
|
||||
model = PeftModel.from_pretrained(
|
||||
base_model,
|
||||
lora_paths[i],
|
||||
torch_dtype=torch_dtype,
|
||||
is_trainable=False,
|
||||
)
|
||||
else:
|
||||
model = base_model
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
temperature=None,
|
||||
top_p=None,
|
||||
max_new_tokens=max_new_tokens,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=(not output_str_only),
|
||||
)
|
||||
|
||||
text = tokenizer.decode(
|
||||
outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
|
||||
)
|
||||
# Check if the text is empty or only whitespace.
|
||||
if not text.strip():
|
||||
raise ValueError(
|
||||
"Received an empty text response. Please verify your input or model configuration."
|
||||
)
|
||||
output_strs.append(text)
|
||||
|
||||
if not output_str_only:
|
||||
# outputs.scores: (num_token, 1, vocab_size)
|
||||
top_output_logprobs.append(
|
||||
[
|
||||
get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
|
||||
for logits in outputs.scores
|
||||
]
|
||||
)
|
||||
del outputs
|
||||
|
||||
input_logits = model.forward(input_ids).logits[0]
|
||||
top_input_logprobs.append(
|
||||
get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
|
||||
)
|
||||
del input_logits
|
||||
|
||||
return ModelOutput(
|
||||
output_strs=output_strs,
|
||||
top_input_logprobs=top_input_logprobs,
|
||||
top_output_logprobs=top_output_logprobs,
|
||||
)
|
||||
|
||||
|
||||
class SRTRunner:
|
||||
def __init__(
|
||||
@@ -284,6 +304,7 @@ class SRTRunner:
|
||||
disable_cuda_graph: bool = False,
|
||||
disable_radix_cache: bool = False,
|
||||
mem_fraction_static: float = 0.65,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.model_type = model_type
|
||||
self.is_generation = model_type == "generation"
|
||||
@@ -293,7 +314,7 @@ class SRTRunner:
|
||||
dtype=get_dtype_str(torch_dtype),
|
||||
port=port,
|
||||
mem_fraction_static=mem_fraction_static,
|
||||
trust_remote_code=False,
|
||||
trust_remote_code=trust_remote_code,
|
||||
is_embedding=not self.is_generation,
|
||||
lora_paths=lora_paths,
|
||||
max_loras_per_batch=max_loras_per_batch,
|
||||
@@ -301,7 +322,7 @@ class SRTRunner:
|
||||
disable_cuda_graph=disable_cuda_graph,
|
||||
disable_radix_cache=disable_radix_cache,
|
||||
)
|
||||
self.tokenizer = get_tokenizer(model_path)
|
||||
self.tokenizer = get_tokenizer(model_path, trust_remote_code=trust_remote_code)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -310,54 +331,11 @@ class SRTRunner:
|
||||
lora_paths=None,
|
||||
):
|
||||
if self.is_generation:
|
||||
# the return value contains logprobs from prefill
|
||||
output_strs = []
|
||||
top_input_logprobs = []
|
||||
top_output_logprobs = []
|
||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||
for i, prompt in enumerate(prompts):
|
||||
response = self.engine.generate(
|
||||
prompt,
|
||||
lora_path=lora_paths[i] if lora_paths else None,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=True,
|
||||
logprob_start_len=0,
|
||||
top_logprobs_num=NUM_TOP_LOGPROBS,
|
||||
)
|
||||
text = response["text"]
|
||||
|
||||
# Check if the text is empty or only whitespace.
|
||||
if not text.strip():
|
||||
raise ValueError(
|
||||
"Received an empty text response. Please verify your input or model configuration."
|
||||
)
|
||||
output_strs.append(text)
|
||||
|
||||
top_input_logprobs.append(
|
||||
[
|
||||
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||
for x in response["meta_info"]["input_top_logprobs"][1:]
|
||||
]
|
||||
+ [
|
||||
[
|
||||
tup[0]
|
||||
for tup in response["meta_info"]["output_top_logprobs"][0][
|
||||
:NUM_TOP_LOGPROBS
|
||||
]
|
||||
]
|
||||
]
|
||||
)
|
||||
top_output_logprobs.append(
|
||||
[
|
||||
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||
for x in response["meta_info"]["output_top_logprobs"]
|
||||
]
|
||||
)
|
||||
|
||||
return ModelOutput(
|
||||
output_strs=output_strs,
|
||||
top_input_logprobs=top_input_logprobs,
|
||||
top_output_logprobs=top_output_logprobs,
|
||||
return self.forward_generation_raw(
|
||||
prompts=prompts,
|
||||
max_new_tokens=max_new_tokens,
|
||||
lora_paths=lora_paths,
|
||||
engine=self.engine,
|
||||
)
|
||||
else:
|
||||
response = self.engine.encode(prompts)
|
||||
@@ -379,18 +357,11 @@ class SRTRunner:
|
||||
only return output strings and no logprobs
|
||||
"""
|
||||
if self.is_generation:
|
||||
# the return value contains logprobs from prefill
|
||||
output_strs = []
|
||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||
response = self.engine.generate(
|
||||
prompts,
|
||||
lora_path=lora_paths if lora_paths else None,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
output_strs = [r["text"] for r in response]
|
||||
|
||||
return ModelOutput(
|
||||
output_strs=output_strs,
|
||||
return self.batch_forward_generation_raw(
|
||||
prompts=prompts,
|
||||
max_new_tokens=max_new_tokens,
|
||||
lora_paths=lora_paths,
|
||||
engine=self.engine,
|
||||
)
|
||||
else:
|
||||
response = self.engine.encode(prompts)
|
||||
@@ -408,6 +379,84 @@ class SRTRunner:
|
||||
self.engine.shutdown()
|
||||
del self.engine
|
||||
|
||||
@staticmethod
|
||||
def forward_generation_raw(
|
||||
prompts: Union[List[str], List[torch.Tensor]],
|
||||
max_new_tokens,
|
||||
lora_paths,
|
||||
engine,
|
||||
):
|
||||
# the return value contains logprobs from prefill
|
||||
output_strs = []
|
||||
top_input_logprobs = []
|
||||
top_output_logprobs = []
|
||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||
for i, prompt in enumerate(prompts):
|
||||
response = engine.generate(
|
||||
prompt,
|
||||
lora_path=lora_paths[i] if lora_paths else None,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=True,
|
||||
logprob_start_len=0,
|
||||
top_logprobs_num=NUM_TOP_LOGPROBS,
|
||||
)
|
||||
text = response["text"]
|
||||
|
||||
# Check if the text is empty or only whitespace.
|
||||
if not text.strip():
|
||||
raise ValueError(
|
||||
"Received an empty text response. Please verify your input or model configuration."
|
||||
)
|
||||
output_strs.append(text)
|
||||
|
||||
top_input_logprobs.append(
|
||||
[
|
||||
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||
for x in response["meta_info"]["input_top_logprobs"][1:]
|
||||
]
|
||||
+ [
|
||||
[
|
||||
tup[0]
|
||||
for tup in response["meta_info"]["output_top_logprobs"][0][
|
||||
:NUM_TOP_LOGPROBS
|
||||
]
|
||||
]
|
||||
]
|
||||
)
|
||||
top_output_logprobs.append(
|
||||
[
|
||||
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||
for x in response["meta_info"]["output_top_logprobs"]
|
||||
]
|
||||
)
|
||||
|
||||
return ModelOutput(
|
||||
output_strs=output_strs,
|
||||
top_input_logprobs=top_input_logprobs,
|
||||
top_output_logprobs=top_output_logprobs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def batch_forward_generation_raw(
|
||||
prompts: Union[List[str], List[torch.Tensor]],
|
||||
max_new_tokens,
|
||||
lora_paths,
|
||||
engine,
|
||||
):
|
||||
# the return value contains logprobs from prefill
|
||||
output_strs = []
|
||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||
response = engine.generate(
|
||||
prompts,
|
||||
lora_path=lora_paths if lora_paths else None,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
output_strs = [r["text"] for r in response]
|
||||
|
||||
return ModelOutput(
|
||||
output_strs=output_strs,
|
||||
)
|
||||
|
||||
|
||||
def monkey_patch_gemma2_sdpa():
|
||||
"""
|
||||
@@ -422,3 +471,52 @@ def monkey_patch_gemma2_sdpa():
|
||||
return config
|
||||
|
||||
setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa)
|
||||
|
||||
|
||||
def check_close_model_outputs(
|
||||
hf_outputs: ModelOutput,
|
||||
srt_outputs: ModelOutput,
|
||||
prefill_tolerance: float,
|
||||
decode_tolerance: float,
|
||||
rouge_l_tolerance: float,
|
||||
debug_text: str = "",
|
||||
check_logprobs: bool = True,
|
||||
):
|
||||
# Compare output strings
|
||||
print(f"{hf_outputs.output_strs=}")
|
||||
print(f"{srt_outputs.output_strs=}")
|
||||
rouge_l_scores = calculate_rouge_l(hf_outputs.output_strs, srt_outputs.output_strs)
|
||||
print(f"{rouge_l_scores=}")
|
||||
assert all(
|
||||
score >= rouge_l_tolerance for score in rouge_l_scores
|
||||
), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"
|
||||
|
||||
if check_logprobs:
|
||||
for i in range(len(hf_outputs.output_strs)):
|
||||
# Compare input logprobs
|
||||
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
|
||||
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
|
||||
input_len = hf_logprobs.shape[0]
|
||||
print(
|
||||
"prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
|
||||
)
|
||||
if input_len <= 100:
|
||||
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
|
||||
f"prefill logprobs are not all close with {debug_text} "
|
||||
f"prefill_tolerance={prefill_tolerance}."
|
||||
f"{hf_logprobs=}, {srt_logprobs=}"
|
||||
)
|
||||
|
||||
# Compare output logprobs
|
||||
hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
|
||||
srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
|
||||
|
||||
print(
|
||||
"decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
|
||||
)
|
||||
if input_len <= 100:
|
||||
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
|
||||
f"decode logprobs are not all close with {debug_text} "
|
||||
f"decode_tolerance={decode_tolerance}."
|
||||
f"{hf_logprobs=}, {srt_logprobs=}"
|
||||
)
|
||||
|
||||
@@ -536,7 +536,7 @@ def test_hellaswag_select():
|
||||
# Compute accuracy
|
||||
accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
|
||||
print(f"{accuracy=}, {accuracy_gen=}")
|
||||
assert np.abs(accuracy_gen - accuracy) < 0.05
|
||||
assert np.abs(accuracy_gen - accuracy) < 0.1
|
||||
assert np.abs(latency_gen - latency) < 1
|
||||
|
||||
return accuracy, latency
|
||||
|
||||
Reference in New Issue
Block a user