Feat: support cuda graph for LoRA (#4115)
Co-authored-by: Beichen Ma <mabeichen12@gmail.com>
@@ -94,19 +94,20 @@ ALL_OTHER_LORA_MODELS = [
]


-def run_batch_lora_test(
+def run_lora_test_one_by_one(
    prompts: List[str],
    model_case: LoRAModelCase,
    torch_dtype: torch.dtype,
    max_new_tokens: int,
    backend: str,
-    disable_cuda_graph: bool = True,
+    disable_cuda_graph: bool = False,
    disable_radix_cache: bool = True,
    mem_fraction_static: float = 0.88,
    test_tag: str = "",
):
    """
    Run LoRA test for a forward batch.

    Input a batch of prompts, and run lora tests one by one with several generate requests
    (each request will have bs=1).
    For prompt0, prompt1, ..., promptN,
    we will use adaptor0, adaptor1, ..., adaptorN included in the model case.
    We will then compare the outputs of HF and SRT with and without LoRA.
@@ -119,7 +120,7 @@ def run_batch_lora_test(
        torch_dtype (torch.dtype): The torch dtype to use.
        max_new_tokens (int): The maximum number of new tokens to generate.
        backend (str): The lora backend to use.
-        disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to True.
+        disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
        disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to True.
        mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
        test_tag (str, optional): The tag to use for the test. Defaults to "".
@@ -237,3 +238,112 @@ def run_batch_lora_test(
            f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
            f"for base '{base_path}', adaptor '{adaptor_names}', backend '{backend}', prompt: '{prompts[0][:50]}...'"
        )


def run_lora_test_by_batch(
    prompts: List[str],
    model_case: LoRAModelCase,
    torch_dtype: torch.dtype,
    max_new_tokens: int,
    backend: str,
    disable_cuda_graph: bool = False,
    disable_radix_cache: bool = True,
    mem_fraction_static: float = 0.88,
    test_tag: str = "",
):
    """
    Run LoRA tests as a batch.

    For prompt0, prompt1, ..., promptN,
    we will use adaptor0, adaptor1, ..., adaptorN included in the model case.
    We will then compare the outputs of HF and SRT with LoRA.
    If the number of prompts is larger than the number of adaptors,
    prompt i will use adaptor i % (number of adaptors).

    Args:
        prompts (List[str]): The batch of prompts to test.
        model_case (LoRAModelCase): The model case to test.
        torch_dtype (torch.dtype): The torch dtype to use.
        max_new_tokens (int): The maximum number of new tokens to generate.
        backend (str): The lora backend to use.
        disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
        disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to True.
        mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
        test_tag (str, optional): The tag to use for the test. Defaults to "".
    """
    base_path = model_case.base

    # Choose the adaptor used for each prompt in the batch (round-robin)
    i, adaptors = 0, []
    for _ in range(len(prompts)):
        adaptors.append(model_case.adaptors[i])
        i = (i + 1) % len(model_case.adaptors)
    adaptor_names = [adaptor.name for adaptor in adaptors]

    print(
        f"\n========== Testing {test_tag} on base '{model_case.base}' with backend={backend}, dtype={torch_dtype} --- "
        f"Using prompts {[p[:50] for p in prompts]} with adaptors: {adaptor_names} ---"
    )
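    # Four runs follow: SRT with LoRA, SRT without LoRA, HF with LoRA, and HF without LoRA.
    # Only the first (LoRA-enabled SRT) run passes disable_cuda_graph, so it is the one that
    # exercises the CUDA graph path when the flag is left at its new default (False).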
    with SRTRunner(
        base_path,
        torch_dtype=torch_dtype,
        model_type="generation",
        tp_size=model_case.tp_size,
        lora_paths=[adaptor.name for adaptor in model_case.adaptors],
        max_loras_per_batch=model_case.max_loras_per_batch,
        lora_backend=backend,
        disable_cuda_graph=disable_cuda_graph,
        disable_radix_cache=disable_radix_cache,
        mem_fraction_static=mem_fraction_static,
    ) as srt_runner:
        srt_outputs = srt_runner.batch_forward(
            prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
        )

    with SRTRunner(
        base_path,
        torch_dtype=torch_dtype,
        model_type="generation",
        tp_size=model_case.tp_size,
        mem_fraction_static=mem_fraction_static,
    ) as srt_runner:
        srt_no_lora_outputs = srt_runner.batch_forward(
            prompts, max_new_tokens=max_new_tokens
        )

    with HFRunner(
        base_path, torch_dtype=torch_dtype, model_type="generation"
    ) as hf_runner:
        hf_outputs = hf_runner.forward(
            prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
        )

    with HFRunner(
        base_path, torch_dtype=torch_dtype, model_type="generation"
    ) as hf_runner:
        hf_no_lora_outputs = hf_runner.forward(
            prompts,
            max_new_tokens=max_new_tokens,
        )

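    # Compare SRT and HF outputs prompt by prompt: ROUGE-L is printed for reference,
    # while the assertions below require exact string matches, with and without LoRA.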
    for i in range(len(prompts)):

        srt_output_str = srt_outputs.output_strs[i].strip()
        hf_output_str = hf_outputs.output_strs[i].strip()
        rouge_score = calculate_rouge_l([srt_output_str], [hf_output_str])[0]
        print("ROUGE-L score:", rouge_score)
        print("SRT output:", srt_output_str)
        print("HF output:", hf_output_str)
        print("SRT no lora output:", srt_no_lora_outputs.output_strs[i].strip())
        print("HF no lora output:", hf_no_lora_outputs.output_strs[i].strip())
        assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i].strip(
            " "
        ), (
            srt_outputs.output_strs[i].strip(" "),
            hf_outputs.output_strs[i].strip(" "),
        )
        assert srt_no_lora_outputs.output_strs[i].strip(
            " "
        ) == hf_no_lora_outputs.output_strs[i].strip(" "), (
            srt_no_lora_outputs.output_strs[i].strip(" "),
            hf_no_lora_outputs.output_strs[i].strip(" "),
        )
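
For reference, a minimal sketch (not part of this commit) of how a test could drive the new run_lora_test_by_batch helper with CUDA graph left enabled. The test class name, prompts, dtype, token budget, and the "triton" backend string are illustrative assumptions; only run_lora_test_by_batch, its parameters, and ALL_OTHER_LORA_MODELS come from the diff above.

import unittest

import torch


class TestLoRACudaGraphBatch(unittest.TestCase):  # hypothetical test class name
    def test_lora_models_by_batch(self):
        # Illustrative prompts; inside the helper, prompt i is paired with
        # adaptor i % len(model_case.adaptors).
        prompts = [
            "AI is a field of computer science focused on",
            "Computer science is the study of",
            "Write a short story about a robot.",
        ]
        for model_case in ALL_OTHER_LORA_MODELS:  # defined earlier in this test file
            run_lora_test_by_batch(
                prompts,
                model_case,
                torch_dtype=torch.float16,  # illustrative dtype
                max_new_tokens=32,  # illustrative token budget
                backend="triton",  # assumed LoRA backend name
                disable_cuda_graph=False,  # exercise the new CUDA graph path
                test_tag="cuda_graph_batch",
            )


if __name__ == "__main__":
    unittest.main()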