Revert "fix some typos" (#6244)
This commit is contained in:
@@ -154,7 +154,7 @@ def run_lora_test_one_by_one(
|
||||
model_case (LoRAModelCase): The model case to test.
|
||||
torch_dtype (torch.dtype): The torch dtype to use.
|
||||
max_new_tokens (int): The maximum number of new tokens to generate.
|
||||
backend (str): The LoRA backend to use.
|
||||
backend (str): The lora backend to use.
|
||||
disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
|
||||
disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to True.
|
||||
mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
|
||||
@@ -289,7 +289,7 @@ def run_lora_test_by_batch(
|
||||
test_tag: str = "",
|
||||
):
|
||||
"""
|
||||
Run LoRA tests as a batch.
|
||||
Run lora tests as a batch.
|
||||
For prompt0, prompt1, ..., promptN,
|
||||
we will use adaptor0, adaptor1, ..., adaptorN included in model case,
|
||||
We will then compare the outputs of HF and SRT with LoRA.
|
||||
@@ -301,7 +301,7 @@ def run_lora_test_by_batch(
|
||||
model_case (LoRAModelCase): The model case to test.
|
||||
torch_dtype (torch.dtype): The torch dtype to use.
|
||||
max_new_tokens (int): The maximum number of new tokens to generate.
|
||||
backend (str): The LoRA backend to use.
|
||||
backend (str): The lora backend to use.
|
||||
disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
|
||||
disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to True.
|
||||
mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
|
||||
@@ -372,8 +372,8 @@ def run_lora_test_by_batch(
|
||||
print("ROUGE-L score:", rouge_score)
|
||||
print("SRT output:", srt_output_str)
|
||||
print("HF output:", hf_output_str)
|
||||
print("SRT no LoRA output:", srt_no_lora_outputs.output_strs[i].strip())
|
||||
print("HF no LoRA output:", hf_no_lora_outputs.output_strs[i].strip())
|
||||
print("SRT no lora output:", srt_no_lora_outputs.output_strs[i].strip())
|
||||
print("HF no lora output:", hf_no_lora_outputs.output_strs[i].strip())
|
||||
assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i].strip(
|
||||
" "
|
||||
), (
|
||||
|
||||
Reference in New Issue
Block a user