Fix potential flakiness in test_lora_qwen3 (#10250)
This commit is contained in:
@@ -24,6 +24,7 @@ from utils import (
|
|||||||
CI_MULTI_LORA_MODELS,
|
CI_MULTI_LORA_MODELS,
|
||||||
TORCH_DTYPES,
|
TORCH_DTYPES,
|
||||||
LoRAModelCase,
|
LoRAModelCase,
|
||||||
|
ensure_reproducibility,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.test.runners import HFRunner, SRTRunner
|
from sglang.test.runners import HFRunner, SRTRunner
|
||||||
@@ -76,13 +77,6 @@ class TestLoRA(CustomTestCase):
|
|||||||
|
|
||||||
return batches
|
return batches
|
||||||
|
|
||||||
def ensure_reproducibility(self):
|
|
||||||
seed = 42
|
|
||||||
random.seed(seed)
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
torch.cuda.manual_seed_all(seed)
|
|
||||||
torch.use_deterministic_algorithms(True)
|
|
||||||
|
|
||||||
def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):
|
def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):
|
||||||
for model_case in model_cases:
|
for model_case in model_cases:
|
||||||
for torch_dtype in TORCH_DTYPES:
|
for torch_dtype in TORCH_DTYPES:
|
||||||
@@ -121,14 +115,14 @@ class TestLoRA(CustomTestCase):
|
|||||||
f"\n--- Running Batch {i} --- prompts: {prompts}, lora_paths: {lora_paths}"
|
f"\n--- Running Batch {i} --- prompts: {prompts}, lora_paths: {lora_paths}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.ensure_reproducibility()
|
ensure_reproducibility()
|
||||||
srt_outputs = srt_runner.batch_forward(
|
srt_outputs = srt_runner.batch_forward(
|
||||||
prompts,
|
prompts,
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
lora_paths=lora_paths,
|
lora_paths=lora_paths,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.ensure_reproducibility()
|
ensure_reproducibility()
|
||||||
hf_outputs = hf_runner.forward(
|
hf_outputs = hf_runner.forward(
|
||||||
prompts,
|
prompts,
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import random
|
|||||||
import unittest
|
import unittest
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase
|
from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase, ensure_reproducibility
|
||||||
|
|
||||||
from sglang.test.runners import HFRunner, SRTRunner
|
from sglang.test.runners import HFRunner, SRTRunner
|
||||||
from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci
|
from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci
|
||||||
@@ -59,19 +59,18 @@ TEST_MULTIPLE_BATCH_PROMPTS = [
|
|||||||
The Transformers are large language models,
|
The Transformers are large language models,
|
||||||
They're used to make predictions on text.
|
They're used to make predictions on text.
|
||||||
""",
|
""",
|
||||||
# "AI is a field of computer science focused on", TODO: Add it back after fixing its bug
|
"AI is a field of computer science focused on",
|
||||||
"Computer science is the study of",
|
"Computer science is the study of",
|
||||||
"Write a short story.",
|
"Write a short story.",
|
||||||
"What are the main components of a computer?",
|
"What are the main components of a computer?",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class TestLoRA(CustomTestCase):
|
class TestLoRAQwen3(CustomTestCase):
|
||||||
|
|
||||||
def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):
|
def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):
|
||||||
for model_case in model_cases:
|
for model_case in model_cases:
|
||||||
for torch_dtype in TORCH_DTYPES:
|
for torch_dtype in TORCH_DTYPES:
|
||||||
max_new_tokens = 10
|
max_new_tokens = 32
|
||||||
backend = "triton"
|
backend = "triton"
|
||||||
base_path = model_case.base
|
base_path = model_case.base
|
||||||
lora_adapter_paths = [a.name for a in model_case.adaptors]
|
lora_adapter_paths = [a.name for a in model_case.adaptors]
|
||||||
@@ -133,6 +132,7 @@ class TestLoRA(CustomTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize runners
|
# Initialize runners
|
||||||
|
ensure_reproducibility()
|
||||||
srt_runner = SRTRunner(
|
srt_runner = SRTRunner(
|
||||||
base_path,
|
base_path,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -140,7 +140,11 @@ class TestLoRA(CustomTestCase):
|
|||||||
lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]],
|
lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]],
|
||||||
max_loras_per_batch=len(lora_adapter_paths) + 1,
|
max_loras_per_batch=len(lora_adapter_paths) + 1,
|
||||||
lora_backend=backend,
|
lora_backend=backend,
|
||||||
|
sleep_on_idle=True, # Eliminate non-determinism by forcing all requests to be processed in one batch.
|
||||||
|
attention_backend="torch_native",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ensure_reproducibility()
|
||||||
hf_runner = HFRunner(
|
hf_runner = HFRunner(
|
||||||
base_path,
|
base_path,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import random
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -386,3 +387,11 @@ def run_lora_test_by_batch(
|
|||||||
srt_no_lora_outputs.output_strs[i].strip(" "),
|
srt_no_lora_outputs.output_strs[i].strip(" "),
|
||||||
hf_no_lora_outputs.output_strs[i].strip(" "),
|
hf_no_lora_outputs.output_strs[i].strip(" "),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_reproducibility():
|
||||||
|
seed = 42
|
||||||
|
random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
torch.use_deterministic_algorithms(True)
|
||||||
|
|||||||
Reference in New Issue
Block a user