[Fix] Fix accuracy bug and refactor codes for lora (#3413)

This commit is contained in:
Baizhou Zhang
2025-02-09 21:29:00 -08:00
committed by GitHub
parent 27c4c9cf52
commit c45cab1c00
15 changed files with 1136 additions and 630 deletions

View File

@@ -22,7 +22,11 @@ from sglang.test.test_utils import calculate_rouge_l
LORA_SETS = [
{"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]},
# {"base": "meta-llama/Llama-2-7b-hf", "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"]}
{
"base": "meta-llama/Llama-3.1-8B-Instruct",
"loras": ["reissbaker/llama-3.1-8b-abliterated-lora"],
"decode_tolerance": 8e-2,
},
]
TORCH_DTYPES = [torch.float16]
@@ -128,7 +132,8 @@ class TestLoRABackend(unittest.TestCase):
torch.max(abs(hf_logprobs - hf_no_lora_logprobs)),
)
if hf_logprobs.shape[0] <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
tol = lora_set.get("prefill_tolerance", prefill_tolerance)
assert torch.all(abs(hf_logprobs - srt_logprobs) < tol), (
f"prefill logprobs are not all close with model_path={base_path},"
f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
f"prefill_tolerance={prefill_tolerance}."
@@ -144,7 +149,8 @@ class TestLoRABackend(unittest.TestCase):
"\n",
)
if hf_logprobs.shape[0] <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
tol = lora_set.get("decode_tolerance", decode_tolerance)
assert torch.all(abs(hf_logprobs - srt_logprobs) < tol), (
f"decode logprobs are not all close with model_path={base_path},"
f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
f"decode_tolerance={decode_tolerance}."
@@ -153,7 +159,7 @@ class TestLoRABackend(unittest.TestCase):
# compare output strings
srt_output_str = srt_outputs.output_strs[i].strip(" ")
hf_output_str = hf_outputs.output_strs[i]
hf_output_str = hf_outputs.output_strs[i].strip(" ")
print(f"srt_output_str={srt_output_str}")
print(f"hf_output_str={hf_output_str}")
rouge_l_scores = calculate_rouge_l([srt_output_str], [hf_output_str])