fix some typos (#6209)

Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
This commit is contained in:
applesaucethebun
2025-05-12 13:42:38 -04:00
committed by GitHub
parent 3ee40ff919
commit d738ab52f8
95 changed files with 276 additions and 276 deletions

View File

@@ -154,7 +154,7 @@ def run_lora_test_one_by_one(
model_case (LoRAModelCase): The model case to test.
torch_dtype (torch.dtype): The torch dtype to use.
max_new_tokens (int): The maximum number of new tokens to generate.
backend (str): The lora backend to use.
backend (str): The LoRA backend to use.
disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to True.
mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
@@ -289,7 +289,7 @@ def run_lora_test_by_batch(
test_tag: str = "",
):
"""
Run lora tests as a batch.
Run LoRA tests as a batch.
For prompt0, prompt1, ..., promptN,
we will use adaptor0, adaptor1, ..., adaptorN included in model case,
We will then compare the outputs of HF and SRT with LoRA.
@@ -301,7 +301,7 @@ def run_lora_test_by_batch(
model_case (LoRAModelCase): The model case to test.
torch_dtype (torch.dtype): The torch dtype to use.
max_new_tokens (int): The maximum number of new tokens to generate.
backend (str): The lora backend to use.
backend (str): The LoRA backend to use.
disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to True.
mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
@@ -372,8 +372,8 @@ def run_lora_test_by_batch(
print("ROUGE-L score:", rouge_score)
print("SRT output:", srt_output_str)
print("HF output:", hf_output_str)
print("SRT no lora output:", srt_no_lora_outputs.output_strs[i].strip())
print("HF no lora output:", hf_no_lora_outputs.output_strs[i].strip())
print("SRT no LoRA output:", srt_no_lora_outputs.output_strs[i].strip())
print("HF no LoRA output:", hf_no_lora_outputs.output_strs[i].strip())
assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i].strip(
" "
), (