Revert "fix some typos" (#6244)
This commit is contained in:
@@ -47,8 +47,8 @@ TEST_CUDA_GRAPH_PADDING_PROMPTS = [
|
||||
class TestLoRACudaGraph(CustomTestCase):
|
||||
|
||||
def _run_without_cuda_graph_on_model_cases(self, model_cases: List[LoRAModelCase]):
|
||||
# Since we have already enabled CUDA graph by default in other LoRA tests,
|
||||
# we only need to run LoRA tests without CUDA graph here.
|
||||
# Since we have already enabled CUDA graph by default in other lora tests,
|
||||
# we only need to run lora tests without CUDA graph here.
|
||||
for model_case in model_cases:
|
||||
# If skip_long_prompt is True, filter out prompts longer than 1000 characters
|
||||
prompts = (
|
||||
|
||||
@@ -154,7 +154,7 @@ def run_lora_test_one_by_one(
|
||||
model_case (LoRAModelCase): The model case to test.
|
||||
torch_dtype (torch.dtype): The torch dtype to use.
|
||||
max_new_tokens (int): The maximum number of new tokens to generate.
|
||||
backend (str): The LoRA backend to use.
|
||||
backend (str): The lora backend to use.
|
||||
disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
|
||||
disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to True.
|
||||
mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
|
||||
@@ -289,7 +289,7 @@ def run_lora_test_by_batch(
|
||||
test_tag: str = "",
|
||||
):
|
||||
"""
|
||||
Run LoRA tests as a batch.
|
||||
Run lora tests as a batch.
|
||||
For prompt0, prompt1, ..., promptN,
|
||||
we will use adaptor0, adaptor1, ..., adaptorN included in model case,
|
||||
We will then compare the outputs of HF and SRT with LoRA.
|
||||
@@ -301,7 +301,7 @@ def run_lora_test_by_batch(
|
||||
model_case (LoRAModelCase): The model case to test.
|
||||
torch_dtype (torch.dtype): The torch dtype to use.
|
||||
max_new_tokens (int): The maximum number of new tokens to generate.
|
||||
backend (str): The LoRA backend to use.
|
||||
backend (str): The lora backend to use.
|
||||
disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
|
||||
disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to True.
|
||||
mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
|
||||
@@ -372,8 +372,8 @@ def run_lora_test_by_batch(
|
||||
print("ROUGE-L score:", rouge_score)
|
||||
print("SRT output:", srt_output_str)
|
||||
print("HF output:", hf_output_str)
|
||||
print("SRT no LoRA output:", srt_no_lora_outputs.output_strs[i].strip())
|
||||
print("HF no LoRA output:", hf_no_lora_outputs.output_strs[i].strip())
|
||||
print("SRT no lora output:", srt_no_lora_outputs.output_strs[i].strip())
|
||||
print("HF no lora output:", hf_no_lora_outputs.output_strs[i].strip())
|
||||
assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i].strip(
|
||||
" "
|
||||
), (
|
||||
|
||||
@@ -8,7 +8,7 @@ class TestSRTEngineWithQuantArgs(CustomTestCase):
|
||||
|
||||
def test_1_quantization_args(self):
|
||||
|
||||
# we only test fp8 because other methods are currently dependent on vLLM. We can add other methods back to test after vLLM dependency is resolved.
|
||||
# we only test fp8 because other methods are currently dependent on vllm. We can add other methods back to test after vllm dependency is resolved.
|
||||
quantization_args_list = [
|
||||
# "awq",
|
||||
"fp8",
|
||||
@@ -34,7 +34,7 @@ class TestSRTEngineWithQuantArgs(CustomTestCase):
|
||||
|
||||
def test_2_torchao_args(self):
|
||||
|
||||
# we don't test int8dq because currently there is conflict between int8dq and capture CUDA graph
|
||||
# we don't test int8dq because currently there is conflict between int8dq and capture cuda graph
|
||||
torchao_args_list = [
|
||||
# "int8dq",
|
||||
"int8wo",
|
||||
|
||||
@@ -277,7 +277,7 @@ class TestTritonAttention(CustomTestCase):
|
||||
|
||||
def test_decode_attention(self):
|
||||
# Here we just to ensure there is no error
|
||||
# TODO: correctness test
|
||||
# TODO: correctnesss test
|
||||
|
||||
# Test configurations
|
||||
configs = [
|
||||
|
||||
@@ -189,7 +189,7 @@ def init_process_hf(
|
||||
print(f"[hf] {rank=} {broadcast_time=:.3f}s")
|
||||
param_queue.put(("broadcast_time", broadcast_time))
|
||||
|
||||
# Delete the HuggingFace models to free up memory.
|
||||
# Delete the huggingface models to free up memory.
|
||||
del hf_instruct_model
|
||||
del hf_base_model
|
||||
gc.collect()
|
||||
|
||||
Reference in New Issue
Block a user