[CI] Return output logprobs in unit test (#1361)
This commit is contained in:
@@ -50,6 +50,12 @@ def get_dtype_str(torch_dtype):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
def get_top_logprobs(logits, k):
|
||||||
|
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
|
||||||
|
logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1)
|
||||||
|
return logprobs
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelOutput:
|
class ModelOutput:
|
||||||
output_strs: List[str] = None
|
output_strs: List[str] = None
|
||||||
@@ -108,7 +114,8 @@ class HFRunner:
|
|||||||
if prompts is not None:
|
if prompts is not None:
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
output_strs = []
|
output_strs = []
|
||||||
prefill_logprobs = []
|
top_input_logprobs = []
|
||||||
|
top_output_logprobs = []
|
||||||
for p in prompts:
|
for p in prompts:
|
||||||
if isinstance(p, str):
|
if isinstance(p, str):
|
||||||
input_ids = self.tokenizer.encode(
|
input_ids = self.tokenizer.encode(
|
||||||
@@ -117,32 +124,43 @@ class HFRunner:
|
|||||||
else:
|
else:
|
||||||
input_ids = torch.tensor([p], device="cuda")
|
input_ids = torch.tensor([p], device="cuda")
|
||||||
|
|
||||||
output_ids = self.model.generate(
|
outputs = self.model.generate(
|
||||||
input_ids, do_sample=False, max_new_tokens=max_new_tokens
|
input_ids,
|
||||||
|
do_sample=False,
|
||||||
|
temperature=None,
|
||||||
|
top_p=None,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
output_scores=True,
|
||||||
)
|
)
|
||||||
output_strs.append(
|
output_strs.append(
|
||||||
self.tokenizer.decode(output_ids[0][len(input_ids[0]) :])
|
self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
|
||||||
)
|
)
|
||||||
|
# outputs.scores: (num_token, 1, vocab_size)
|
||||||
|
top_output_logprobs.append(
|
||||||
|
[
|
||||||
|
get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
|
||||||
|
for logits in outputs.scores
|
||||||
|
]
|
||||||
|
)
|
||||||
|
del outputs
|
||||||
|
|
||||||
logits = self.model.forward(input_ids).logits[0]
|
input_logits = self.model.forward(input_ids).logits[0]
|
||||||
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
|
top_input_logprobs.append(
|
||||||
logprobs, top_indices = torch.topk(
|
get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
|
||||||
logprobs, k=NUM_TOP_LOGPROBS, dim=-1
|
|
||||||
)
|
)
|
||||||
# print("index", top_indices)
|
del input_logits
|
||||||
prefill_logprobs.append(logprobs.tolist())
|
|
||||||
del logits
|
|
||||||
del logprobs
|
|
||||||
|
|
||||||
out_queue.put(
|
out_queue.put(
|
||||||
ModelOutput(
|
ModelOutput(
|
||||||
output_strs=output_strs, top_input_logprobs=prefill_logprobs
|
output_strs=output_strs,
|
||||||
|
top_input_logprobs=top_input_logprobs,
|
||||||
|
top_output_logprobs=top_output_logprobs,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logits = self.model.encode(prompts).tolist()
|
logits = self.model.encode(prompts).tolist()
|
||||||
|
|
||||||
out_queue.put(ModelOutput(embed_logits=logits))
|
out_queue.put(ModelOutput(embed_logits=logits))
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -194,6 +212,7 @@ class SRTRunner:
|
|||||||
# the return value contains logprobs from prefill
|
# the return value contains logprobs from prefill
|
||||||
output_strs = []
|
output_strs = []
|
||||||
top_input_logprobs = []
|
top_input_logprobs = []
|
||||||
|
top_output_logprobs = []
|
||||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||||
for prompt in prompts:
|
for prompt in prompts:
|
||||||
response = self.runtime.generate(
|
response = self.runtime.generate(
|
||||||
@@ -219,9 +238,17 @@ class SRTRunner:
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
top_output_logprobs.append(
|
||||||
|
[
|
||||||
|
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||||
|
for x in response["meta_info"]["output_top_logprobs"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
return ModelOutput(
|
return ModelOutput(
|
||||||
output_strs=output_strs, top_input_logprobs=top_input_logprobs
|
output_strs=output_strs,
|
||||||
|
top_input_logprobs=top_input_logprobs,
|
||||||
|
top_output_logprobs=top_output_logprobs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = self.runtime.encode(prompts)
|
response = self.runtime.encode(prompts)
|
||||||
|
|||||||
@@ -21,9 +21,9 @@ import torch
|
|||||||
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 1),
|
("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 4e-2, 1),
|
||||||
("google/gemma-2-2b", 1, 3, 3e-2, 1),
|
("google/gemma-2-2b", 1, 3, 3e-2, 5e-2, 1),
|
||||||
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 1),
|
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 4e-2, 1),
|
||||||
]
|
]
|
||||||
TORCH_DTYPES = [torch.float16]
|
TORCH_DTYPES = [torch.float16]
|
||||||
|
|
||||||
@@ -70,6 +70,7 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
torch_dtype,
|
torch_dtype,
|
||||||
max_new_tokens,
|
max_new_tokens,
|
||||||
prefill_tolerance,
|
prefill_tolerance,
|
||||||
|
output_tolerance,
|
||||||
rouge_threshold,
|
rouge_threshold,
|
||||||
long_context_tolerance,
|
long_context_tolerance,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -89,15 +90,37 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
||||||
|
|
||||||
for i in range(len(prompts)):
|
for i in range(len(prompts)):
|
||||||
|
# input logprobs comparison
|
||||||
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
|
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
|
||||||
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
|
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
|
||||||
|
input_len = hf_logprobs.shape[0]
|
||||||
print("max_diff", torch.max(abs(hf_logprobs - srt_logprobs)))
|
print(
|
||||||
if hf_logprobs.shape[0] <= 100:
|
"prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
|
||||||
|
)
|
||||||
|
if input_len <= 100:
|
||||||
assert torch.all(
|
assert torch.all(
|
||||||
abs(hf_logprobs - srt_logprobs) < prefill_tolerance
|
abs(hf_logprobs - srt_logprobs) < prefill_tolerance
|
||||||
), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}"
|
), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}"
|
||||||
|
|
||||||
|
# output logprobs comparison
|
||||||
|
hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
|
||||||
|
srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
|
||||||
|
# print(
|
||||||
|
# "output logprobs diff",
|
||||||
|
# [
|
||||||
|
# float(torch.max(abs(hf_logprobs[j] - srt_logprobs[j])))
|
||||||
|
# for j in range(max_new_tokens)
|
||||||
|
# ],
|
||||||
|
# )
|
||||||
|
print(
|
||||||
|
"output logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
|
||||||
|
)
|
||||||
|
if input_len <= 100:
|
||||||
|
assert torch.all(
|
||||||
|
abs(hf_logprobs - srt_logprobs) < output_tolerance
|
||||||
|
), f"output logprobs are not all close with model_path={model_path} prompts={prompts}... output_tolerance={output_tolerance}"
|
||||||
|
|
||||||
|
# output strings comparison
|
||||||
print(f"hf_outputs.output_strs={hf_outputs.output_strs}")
|
print(f"hf_outputs.output_strs={hf_outputs.output_strs}")
|
||||||
print(f"srt_outputs.output_strs={srt_outputs.output_strs}")
|
print(f"srt_outputs.output_strs={srt_outputs.output_strs}")
|
||||||
rouge_l_scores = calculate_rouge_l(
|
rouge_l_scores = calculate_rouge_l(
|
||||||
@@ -114,6 +137,7 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
tp_size,
|
tp_size,
|
||||||
long_context_tolerance,
|
long_context_tolerance,
|
||||||
prefill_tolerance,
|
prefill_tolerance,
|
||||||
|
output_tolerance,
|
||||||
rouge_threshold,
|
rouge_threshold,
|
||||||
) in MODELS:
|
) in MODELS:
|
||||||
for torch_dtype in TORCH_DTYPES:
|
for torch_dtype in TORCH_DTYPES:
|
||||||
@@ -125,6 +149,7 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
torch_dtype,
|
torch_dtype,
|
||||||
max_new_tokens,
|
max_new_tokens,
|
||||||
prefill_tolerance=prefill_tolerance,
|
prefill_tolerance=prefill_tolerance,
|
||||||
|
output_tolerance=output_tolerance,
|
||||||
rouge_threshold=rouge_threshold,
|
rouge_threshold=rouge_threshold,
|
||||||
long_context_tolerance=long_context_tolerance,
|
long_context_tolerance=long_context_tolerance,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user