Fix flakiness in LoRA batch test. (#7552)
@@ -503,6 +503,7 @@ class SRTRunner:
         disable_overlap_schedule: bool = False,
         disable_custom_all_reduce: bool = False,
         torchao_config: Optional[str] = None,
+        sleep_on_idle=False,
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
@@ -540,6 +541,7 @@ class SRTRunner:
             disable_overlap_schedule=disable_overlap_schedule,
             cuda_graph_max_bs=4,
             disable_custom_all_reduce=disable_custom_all_reduce,
+            sleep_on_idle=sleep_on_idle,
             **spec_kwargs,
         )

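A minimal usage sketch of the new runner flag follows. Everything here except sleep_on_idle, model_type, and torch_dtype is an assumption for illustration; the placeholder model path and the import path are not part of this diff:

import torch

from sglang.test.runners import SRTRunner  # assumed import path

# Hypothetical, abbreviated construction; the remaining constructor
# arguments keep their defaults.
runner = SRTRunner(
    "base-model-path",  # placeholder base model
    torch_dtype=torch.float16,
    model_type="generation",
    sleep_on_idle=True,  # new flag; per the test below, it helps force all requests into one batch
)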
@@ -18,6 +18,7 @@ import random
 import unittest
 from typing import List

+import torch
 from utils import (
     ALL_OTHER_MULTI_LORA_MODELS,
     CI_MULTI_LORA_MODELS,
@@ -46,7 +47,7 @@ TEST_MULTIPLE_BATCH_PROMPTS = [
     The Transformers are large language models,
     They're used to make predictions on text.
     """,
-    # "AI is a field of computer science focused on", TODO: Add it back after fixing its bug
+    "AI is a field of computer science focused on",
     "Computer science is the study of",
     "Write a short story.",
     "What are the main components of a computer?",
@@ -54,8 +55,36 @@ TEST_MULTIPLE_BATCH_PROMPTS = [


 class TestLoRA(CustomTestCase):
+    def _create_test_samples(
+        self, lora_adapter_paths: List[str], repeated_trials: int = 3
+    ):
+        random.seed(42)  # Ensure reproducibility
+
+        patterns = [
+            [None, lora_adapter_paths[0], lora_adapter_paths[1]],
+            [lora_adapter_paths[0], None, lora_adapter_paths[1]],
+            [lora_adapter_paths[0], lora_adapter_paths[1], None],
+            [None, lora_adapter_paths[1], None],
+            [None, None, None],
+        ]
+
+        batches = [
+            [random.choice(pattern) for _ in range(3)]
+            for pattern in patterns
+            for _ in range(repeated_trials)
+        ]
+
+        return batches
+
+    def ensure_reproducibility(self):
+        seed = 42
+        random.seed(seed)
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        torch.use_deterministic_algorithms(True)
+
     def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):

         for model_case in model_cases:
             for torch_dtype in TORCH_DTYPES:
                 max_new_tokens = 32
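To make the new sampling scheme concrete, here is a self-contained sketch of _create_test_samples with dummy adapter names (the logic mirrors the added code; the function name and adapter names are placeholders):

import random
from typing import List, Optional

def create_test_samples(
    lora_adapter_paths: List[str], repeated_trials: int = 3
) -> List[List[Optional[str]]]:
    random.seed(42)  # same fixed seed as the test, so sampling is reproducible
    patterns = [
        [None, lora_adapter_paths[0], lora_adapter_paths[1]],
        [lora_adapter_paths[0], None, lora_adapter_paths[1]],
        [lora_adapter_paths[0], lora_adapter_paths[1], None],
        [None, lora_adapter_paths[1], None],
        [None, None, None],
    ]
    # 5 patterns x repeated_trials batches; each batch draws 3 entries
    # (with replacement) from one pattern, so the None/adapter mix varies per trial.
    return [
        [random.choice(pattern) for _ in range(3)]
        for pattern in patterns
        for _ in range(repeated_trials)
    ]

batches = create_test_samples(["adapter_a", "adapter_b"])
assert len(batches) == 15  # 5 patterns * 3 trials
assert all(len(batch) == 3 for batch in batches)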
@@ -64,57 +93,6 @@ class TestLoRA(CustomTestCase):
                 lora_adapter_paths = [a.name for a in model_case.adaptors]
                 assert len(lora_adapter_paths) >= 2

-                batches = [
-                    (
-                        [
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                        ],
-                        [
-                            None,
-                            lora_adapter_paths[0],
-                            lora_adapter_paths[1],
-                        ],
-                    ),
-                    (
-                        [
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                        ],
-                        [
-                            lora_adapter_paths[0],
-                            None,
-                            lora_adapter_paths[1],
-                        ],
-                    ),
-                    (
-                        [
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                        ],
-                        [lora_adapter_paths[0], lora_adapter_paths[1], None],
-                    ),
-                    (
-                        [
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                        ],
-                        [None, lora_adapter_paths[1], None],
-                    ),
-                    (
-                        [
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
-                        ],
-                        [None, None, None],
-                    ),
-                ]
-
                 print(
                     f"\n========== Testing multiple batches on base '{base_path}' with backend={backend}, dtype={torch_dtype} ---"
                 )
@@ -128,23 +106,31 @@ class TestLoRA(CustomTestCase):
                     max_loras_per_batch=len(lora_adapter_paths) + 1,
                     lora_backend=backend,
                     disable_radix_cache=True,
+                    sleep_on_idle=True,  # Eliminate non-determinism by forcing all requests to be processed in one batch.
+                    attention_backend="torch_native",
                 )
                 hf_runner = HFRunner(
                     base_path, torch_dtype=torch_dtype, model_type="generation"
                 )

+                batches = self._create_test_samples(lora_adapter_paths)
                 with srt_runner, hf_runner:
-                    for i, (prompts, lora_paths) in enumerate(batches):
+                    for i, lora_paths in enumerate(batches, start=1):
+                        prompts = [
+                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS) for _ in range(3)
+                        ]
                         print(
-                            f"\n--- Running Batch {i+1} --- prompts: {prompts}, lora_paths: {lora_paths}"
+                            f"\n--- Running Batch {i} --- prompts: {prompts}, lora_paths: {lora_paths}"
                         )

+                        self.ensure_reproducibility()
                         srt_outputs = srt_runner.batch_forward(
                             prompts,
                             max_new_tokens=max_new_tokens,
                             lora_paths=lora_paths,
                         )

+                        self.ensure_reproducibility()
                         hf_outputs = hf_runner.forward(
                             prompts,
                             max_new_tokens=max_new_tokens,
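One caveat on the seeding added in ensure_reproducibility: torch.use_deterministic_algorithms(True) makes nondeterministic CUDA kernels raise instead of silently varying, and cuBLAS matmuls additionally require a workspace configuration set before CUDA initializes, or those ops raise a RuntimeError at call time. A standalone sketch (the env var value is one of the two documented options):

import os

# Must be set before the CUDA context is created, hence before importing
# torch in a fresh process.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

import torch

torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.use_deterministic_algorithms(True)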
@@ -167,7 +153,7 @@ class TestLoRA(CustomTestCase):
                                 f"for base '{base_path}', adaptor '{lora_paths}', backend '{backend}', prompt: '{prompts}...'"
                             )

-                        print(f"--- Batch {i+1} Comparison Passed --- ")
+                        print(f"--- Batch {i} Comparison Passed --- ")

     def test_ci_lora_models(self):
         self._run_lora_multiple_batch_on_model_cases(CI_MULTI_LORA_MODELS)
@@ -13,7 +13,7 @@ class TestFile:

 suites = {
     "per-commit": [
-        TestFile("models/lora/test_lora.py", 76),
+        TestFile("models/lora/test_lora.py", 200),
         TestFile("models/lora/test_lora_backend.py", 99),
         TestFile("models/lora/test_multi_lora_backend.py", 60),
         TestFile("models/lora/test_lora_cuda_graph.py", 250),