[Fix] Fix accuracy bug and refactor codes for lora (#3413)
This commit is contained in:
@@ -22,7 +22,11 @@ from sglang.test.test_utils import calculate_rouge_l
|
||||
|
||||
LORA_SETS = [
|
||||
{"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]},
|
||||
# {"base": "meta-llama/Llama-2-7b-hf", "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"]}
|
||||
{
|
||||
"base": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"loras": ["reissbaker/llama-3.1-8b-abliterated-lora"],
|
||||
"decode_tolerance": 8e-2,
|
||||
},
|
||||
]
|
||||
TORCH_DTYPES = [torch.float16]
|
||||
|
||||
@@ -128,7 +132,8 @@ class TestLoRABackend(unittest.TestCase):
|
||||
torch.max(abs(hf_logprobs - hf_no_lora_logprobs)),
|
||||
)
|
||||
if hf_logprobs.shape[0] <= 100:
|
||||
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
|
||||
tol = lora_set.get("prefill_tolerance", prefill_tolerance)
|
||||
assert torch.all(abs(hf_logprobs - srt_logprobs) < tol), (
|
||||
f"prefill logprobs are not all close with model_path={base_path},"
|
||||
f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
|
||||
f"prefill_tolerance={prefill_tolerance}."
|
||||
@@ -144,7 +149,8 @@ class TestLoRABackend(unittest.TestCase):
|
||||
"\n",
|
||||
)
|
||||
if hf_logprobs.shape[0] <= 100:
|
||||
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
|
||||
tol = lora_set.get("decode_tolerance", decode_tolerance)
|
||||
assert torch.all(abs(hf_logprobs - srt_logprobs) < tol), (
|
||||
f"decode logprobs are not all close with model_path={base_path},"
|
||||
f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
|
||||
f"decode_tolerance={decode_tolerance}."
|
||||
@@ -153,7 +159,7 @@ class TestLoRABackend(unittest.TestCase):
|
||||
|
||||
# compare output strings
|
||||
srt_output_str = srt_outputs.output_strs[i].strip(" ")
|
||||
hf_output_str = hf_outputs.output_strs[i]
|
||||
hf_output_str = hf_outputs.output_strs[i].strip(" ")
|
||||
print(f"srt_output_str={srt_output_str}")
|
||||
print(f"hf_output_str={hf_output_str}")
|
||||
rouge_l_scores = calculate_rouge_l([srt_output_str], [hf_output_str])
|
||||
|
||||
Reference in New Issue
Block a user