[Fix, LoRA] fix LoRA with updates in main (#1545)
@@ -97,9 +97,7 @@ class TestLoRA(unittest.TestCase):
         )

         with HFRunner(
-            base_path,
-            torch_dtype=torch_dtype,
-            is_generation=True,
+            base_path, torch_dtype=torch_dtype, model_type="generation"
         ) as hf_runner:
             hf_outputs = hf_runner.forward(
                 prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
@@ -108,7 +106,7 @@ class TestLoRA(unittest.TestCase):
         with HFRunner(
             base_path,
             torch_dtype=torch_dtype,
-            is_generation=True,
+            model_type="generation",
         ) as hf_runner:
             hf_no_lora_outputs = hf_runner.forward(
                 prompts, max_new_tokens=max_new_tokens
@@ -118,7 +116,7 @@ class TestLoRA(unittest.TestCase):
             base_path,
             tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_generation=True,
+            model_type="generation",
         ) as srt_runner:
             srt_no_lora_outputs = srt_runner.forward(
                 prompts, max_new_tokens=max_new_tokens
@@ -198,7 +196,7 @@ class TestLoRA(unittest.TestCase):
             base_path,
             tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_generation=True,
+            model_type="generation",
             lora_paths=all_lora_paths,
             max_loras_per_batch=3,
             disable_cuda_graph=True,
@@ -211,7 +209,7 @@ class TestLoRA(unittest.TestCase):
         with HFRunner(
             base_path,
             torch_dtype=torch_dtype,
-            is_generation=True,
+            model_type="generation",
             output_str_only=True,
         ) as hf_runner:
             hf_outputs = hf_runner.forward(
@@ -237,7 +235,7 @@ class TestLoRA(unittest.TestCase):
             base_path,
             tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_generation=True,
+            model_type="generation",
         ) as srt_runner:
             srt_no_lora_outputs = srt_runner.forward(
                 prompts, max_new_tokens=max_new_tokens
@@ -247,7 +245,7 @@ class TestLoRA(unittest.TestCase):
             base_path,
             tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_generation=True,
+            model_type="generation",
             lora_paths=all_lora_paths,
         ) as srt_runner:
             srt_outputs = srt_runner.forward(
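Every hunk in the TestLoRA file makes the same mechanical substitution: the boolean is_generation=True keyword is replaced by the string-valued model_type="generation" parameter that HFRunner and SRTRunner on main now accept. As a rough illustration of the new call pattern only (the model path and prompt below are placeholders, and the import path is assumed):

import torch
from sglang.test.runners import HFRunner  # import path assumed

base_path = "meta-llama/Llama-2-7b-hf"  # placeholder base model
prompts = ["The capital of France is"]  # placeholder prompt

# The runners are context managers; model_type="generation" replaces the
# old boolean is_generation=True keyword.
with HFRunner(
    base_path, torch_dtype=torch.float16, model_type="generation"
) as hf_runner:
    hf_outputs = hf_runner.forward(prompts, max_new_tokens=32)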
@@ -7,7 +7,7 @@ suites = {
     "minimal": [
         "models/test_embedding_models.py",
         "models/test_generation_models.py",
-        # "models/test_lora.py",
+        "models/test_lora.py",
         "models/test_reward_models.py",
         "sampling/penaltylib",
         "test_chunked_prefill.py",
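The suite-file change simply un-comments "models/test_lora.py" so the LoRA tests run again as part of the "minimal" suite. For illustration, a minimal sketch of how a suites mapping like this can drive a runner; the helper below is hypothetical, not the repo's actual run_suite.py:

import subprocess
import sys

suites = {
    "minimal": [
        "models/test_embedding_models.py",
        "models/test_generation_models.py",
        "models/test_lora.py",  # re-enabled by this commit
        "models/test_reward_models.py",
    ],
}

def run_suite(name: str) -> int:
    """Run each file in the named suite as its own process; return the failure count."""
    failures = 0
    for path in suites[name]:
        # A separate process per file keeps one crash from taking down the whole suite.
        result = subprocess.run([sys.executable, path])
        failures += result.returncode != 0
    return failures

if __name__ == "__main__":
    sys.exit(run_suite("minimal"))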