Feat: support cuda graph for LoRA (#4115)

Co-authored-by: Beichen Ma <mabeichen12@gmail.com>
Qiaolin Yu
2025-04-29 02:30:44 -04:00
committed by GitHub
parent 2c3ea29476
commit 8c0cfca87d
13 changed files with 366 additions and 55 deletions


@@ -94,19 +94,20 @@ ALL_OTHER_LORA_MODELS = [
]
def run_batch_lora_test(
def run_lora_test_one_by_one(
prompts: List[str],
model_case: LoRAModelCase,
torch_dtype: torch.dtype,
max_new_tokens: int,
backend: str,
disable_cuda_graph: bool = True,
disable_cuda_graph: bool = False,
disable_radix_cache: bool = True,
mem_fraction_static: float = 0.88,
test_tag: str = "",
):
"""
Run Lora test for a forward batch.
Take a batch of prompts and run LoRA tests one by one, issuing a separate
generate request for each prompt (each request has bs=1).
For prompt0, prompt1, ..., promptN,
we will use adaptor0, adaptor1, ..., adaptorN included in the model case.
We will then compare the outputs of HF and SRT with and without LoRA.
@@ -119,7 +120,7 @@ def run_batch_lora_test(
torch_dtype (torch.dtype): The torch dtype to use.
max_new_tokens (int): The maximum number of new tokens to generate.
backend (str): The lora backend to use.
disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to True.
disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to True.
mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
test_tag (str, optional): The tag to use for the test. Defaults to "".
@@ -237,3 +238,112 @@ def run_batch_lora_test(
f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
f"for base '{base_path}', adaptor '{adaptor_names}', backend '{backend}', prompt: '{prompts[0][:50]}...'"
)
def run_lora_test_by_batch(
prompts: List[str],
model_case: LoRAModelCase,
torch_dtype: torch.dtype,
max_new_tokens: int,
backend: str,
disable_cuda_graph: bool = False,
disable_radix_cache: bool = True,
mem_fraction_static: float = 0.88,
test_tag: str = "",
):
"""
Run LoRA tests as a batch.
For prompt0, prompt1, ..., promptN,
we will use adaptor0, adaptor1, ..., adaptorN included in the model case.
We will then compare the outputs of HF and SRT with and without LoRA.
If the number of prompts is larger than the number of adaptors,
prompt i will use adaptor i % (number of adaptors).
Args:
prompts (List[str]): The batch of prompts to test.
model_case (LoRAModelCase): The model case to test.
torch_dtype (torch.dtype): The torch dtype to use.
max_new_tokens (int): The maximum number of new tokens to generate.
backend (str): The lora backend to use.
disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to True.
mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
test_tag (str, optional): The tag to use for the test. Defaults to "".
"""
base_path = model_case.base
# Assign an adaptor to each prompt in the batch, cycling through model_case.adaptors
i, adaptors = 0, []
for _ in range(len(prompts)):
adaptors.append(model_case.adaptors[i])
i = (i + 1) % len(model_case.adaptors)
adaptor_names = [adaptor.name for adaptor in adaptors]
print(
f"\n========== Testing {test_tag} on base '{model_case.base}' with backend={backend}, dtype={torch_dtype} --- "
f"Using prompts {[p[:50] for p in prompts]} with adaptors: {adaptor_names} ---"
)
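# SRT run with per-prompt LoRA adaptors; CUDA graph stays enabled unless disable_cuda_graph is set.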
with SRTRunner(
base_path,
torch_dtype=torch_dtype,
model_type="generation",
tp_size=model_case.tp_size,
lora_paths=[adaptor.name for adaptor in model_case.adaptors],
max_loras_per_batch=model_case.max_loras_per_batch,
lora_backend=backend,
disable_cuda_graph=disable_cuda_graph,
disable_radix_cache=disable_radix_cache,
mem_fraction_static=mem_fraction_static,
) as srt_runner:
srt_outputs = srt_runner.batch_forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
)
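# Baseline SRT run on the same prompts without LoRA.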
with SRTRunner(
base_path,
torch_dtype=torch_dtype,
model_type="generation",
tp_size=model_case.tp_size,
mem_fraction_static=mem_fraction_static,
) as srt_runner:
srt_no_lora_outputs = srt_runner.batch_forward(
prompts, max_new_tokens=max_new_tokens
)
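# HuggingFace reference run with the same per-prompt adaptors.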
with HFRunner(
base_path, torch_dtype=torch_dtype, model_type="generation"
) as hf_runner:
hf_outputs = hf_runner.forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
)
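# HuggingFace baseline run without LoRA.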
with HFRunner(
base_path, torch_dtype=torch_dtype, model_type="generation"
) as hf_runner:
hf_no_lora_outputs = hf_runner.forward(
prompts,
max_new_tokens=max_new_tokens,
)
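# Compare SRT and HF outputs prompt by prompt, both with and without LoRA.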
for i in range(len(prompts)):
srt_output_str = srt_outputs.output_strs[i].strip()
hf_output_str = hf_outputs.output_strs[i].strip()
rouge_score = calculate_rouge_l([srt_output_str], [hf_output_str])[0]
print("ROUGE-L score:", rouge_score)
print("SRT output:", srt_output_str)
print("HF output:", hf_output_str)
print("SRT no lora output:", srt_no_lora_outputs.output_strs[i].strip())
print("HF no lora output:", hf_no_lora_outputs.output_strs[i].strip())
assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i].strip(
" "
), (
srt_outputs.output_strs[i].strip(" "),
hf_outputs.output_strs[i].strip(" "),
)
assert srt_no_lora_outputs.output_strs[i].strip(
" "
) == hf_no_lora_outputs.output_strs[i].strip(" "), (
srt_no_lora_outputs.output_strs[i].strip(" "),
hf_no_lora_outputs.output_strs[i].strip(" "),
)
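
A minimal usage sketch of the two helpers follows, assuming they and ALL_OTHER_LORA_MODELS are importable from the same test utility module; the `utils` import path, prompts, dtype, backend, and token budget below are illustrative assumptions, not values taken from this commit:

import torch

# Hypothetical import path; in the test suite these names live in one utility module.
from utils import ALL_OTHER_LORA_MODELS, run_lora_test_by_batch, run_lora_test_one_by_one

# Illustrative prompts only.
EXAMPLE_PROMPTS = [
    "AI is a field of computer science focused on",
    "Computer science is the study of",
]

for model_case in ALL_OTHER_LORA_MODELS:
    # Batched path: one generate call covers all prompts; CUDA graph stays enabled by default.
    run_lora_test_by_batch(
        EXAMPLE_PROMPTS,
        model_case,
        torch_dtype=torch.float16,
        max_new_tokens=32,
        backend="triton",
        test_tag="cuda_graph_batch",
    )
    # One-by-one path: a separate bs=1 generate request per prompt.
    run_lora_test_one_by_one(
        EXAMPLE_PROMPTS,
        model_case,
        torch_dtype=torch.float16,
        max_new_tokens=32,
        backend="triton",
        test_tag="cuda_graph_one_by_one",
    )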