[Fix, LoRA] fix LoRA with updates in main (#1545)

This commit is contained in:
Ying Sheng
2024-09-30 10:06:08 -07:00
committed by GitHub
parent 63ba2f8d7b
commit 0f4fb19bc8
5 changed files with 31 additions and 23 deletions

View File

@@ -97,9 +97,7 @@ class TestLoRA(unittest.TestCase):
)
with HFRunner(
base_path,
torch_dtype=torch_dtype,
is_generation=True,
base_path, torch_dtype=torch_dtype, model_type="generation"
) as hf_runner:
hf_outputs = hf_runner.forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
@@ -108,7 +106,7 @@ class TestLoRA(unittest.TestCase):
with HFRunner(
base_path,
torch_dtype=torch_dtype,
is_generation=True,
model_type="generation",
) as hf_runner:
hf_no_lora_outputs = hf_runner.forward(
prompts, max_new_tokens=max_new_tokens
@@ -118,7 +116,7 @@ class TestLoRA(unittest.TestCase):
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_generation=True,
model_type="generation",
) as srt_runner:
srt_no_lora_outputs = srt_runner.forward(
prompts, max_new_tokens=max_new_tokens
@@ -198,7 +196,7 @@ class TestLoRA(unittest.TestCase):
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_generation=True,
model_type="generation",
lora_paths=all_lora_paths,
max_loras_per_batch=3,
disable_cuda_graph=True,
@@ -211,7 +209,7 @@ class TestLoRA(unittest.TestCase):
with HFRunner(
base_path,
torch_dtype=torch_dtype,
is_generation=True,
model_type="generation",
output_str_only=True,
) as hf_runner:
hf_outputs = hf_runner.forward(
@@ -237,7 +235,7 @@ class TestLoRA(unittest.TestCase):
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_generation=True,
model_type="generation",
) as srt_runner:
srt_no_lora_outputs = srt_runner.forward(
prompts, max_new_tokens=max_new_tokens
@@ -247,7 +245,7 @@ class TestLoRA(unittest.TestCase):
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_generation=True,
model_type="generation",
lora_paths=all_lora_paths,
) as srt_runner:
srt_outputs = srt_runner.forward(

View File

@@ -7,7 +7,7 @@ suites = {
"minimal": [
"models/test_embedding_models.py",
"models/test_generation_models.py",
# "models/test_lora.py",
"models/test_lora.py",
"models/test_reward_models.py",
"sampling/penaltylib",
"test_chunked_prefill.py",