Skip unnecessary penalizer (#1707)

Author:       Lianmin Zheng
Date:         2024-10-18 17:54:03 -07:00
Committed by: GitHub
Parent:       bc12d4033f
Commit:       2bcfba1b08
7 changed files with 104 additions and 75 deletions
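
Judging from the test changes below, orchestrator.apply() may now return None when no penalizer actually needs to run, so callers fall back to the original, untouched logits instead of paying for a no-op pass. The following is a minimal sketch of that assumed contract, not the actual implementation; the names PenalizerSketch, OrchestratorSketch, and is_required() are illustrative only.

from typing import List, Optional

import torch


class PenalizerSketch:
    """Illustrative penalizer: subtracts a fixed penalty when enabled."""

    def __init__(self, penalty: float, required: bool):
        self.penalty = penalty
        self.required = required

    def is_required(self) -> bool:
        return self.required

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        return logits - self.penalty


class OrchestratorSketch:
    """Illustrative stand-in for the batched penalizer orchestrator."""

    def __init__(self, penalizers: List[PenalizerSketch]):
        self.penalizers = penalizers

    def apply(self, logits: torch.Tensor) -> Optional[torch.Tensor]:
        # Skip unnecessary penalizers: if none are active, return None
        # instead of cloning or modifying the logits at all.
        active = [p for p in self.penalizers if p.is_required()]
        if not active:
            return None
        for p in active:
            logits = p.apply(logits)
        return logits


# Caller-side fallback, the same pattern the updated tests use below:
logits = torch.ones(2, 8)
out = OrchestratorSketch(penalizers=[]).apply(logits.clone())
if out is None:  # nothing was applied, keep the original tensor
    out = logits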


@@ -164,19 +164,20 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
                     msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
                 )

-            actual = orchestrator.apply(
-                torch.ones(
-                    size=(len(case.test_subjects), self.vocab_size),
-                    dtype=torch.float32,
-                    device=self.device,
-                )
-            )
+            original = torch.ones(
+                size=(len(case.test_subjects), self.vocab_size),
+                dtype=torch.float32,
+                device=self.device,
+            )
+            actual = orchestrator.apply(original.clone())
             expected = torch.cat(
                 tensors=[
                     subject.steps[0].expected_logits
                     for subject in case.test_subjects
                 ],
             )
+            if actual is None:
+                actual = original
             torch.testing.assert_close(
                 actual=actual,
                 expected=expected,
@@ -226,6 +227,8 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
                     device=self.device,
                 )
             )
+            if actual_logits is None:
+                continue
             filtered_expected_logits = torch.cat(
                 tensors=[
                     subject.steps[0].expected_logits
@@ -317,19 +320,20 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
                     msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
                 )

-            actual_logits = orchestrator.apply(
-                torch.ones(
-                    size=(len(filtered_subjects), self.vocab_size),
-                    dtype=torch.float32,
-                    device=self.device,
-                )
-            )
+            original = torch.ones(
+                size=(len(filtered_subjects), self.vocab_size),
+                dtype=torch.float32,
+                device=self.device,
+            )
+            actual_logits = orchestrator.apply(original.clone())
             filtered_expected_logits = torch.cat(
                 tensors=[
                     subject.steps[i].expected_logits
                     for subject in filtered_subjects
                 ],
             )
+            if actual_logits is None:
+                actual_logits = original
             torch.testing.assert_close(
                 actual=actual_logits,
                 expected=filtered_expected_logits,