Skip unnecessary penalizer (#1707)
This commit is contained in:
@@ -164,19 +164,20 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
|
||||
msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
|
||||
)
|
||||
|
||||
actual = orchestrator.apply(
|
||||
torch.ones(
|
||||
size=(len(case.test_subjects), self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
original = torch.ones(
|
||||
size=(len(case.test_subjects), self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
actual = orchestrator.apply(original.clone())
|
||||
expected = torch.cat(
|
||||
tensors=[
|
||||
subject.steps[0].expected_logits
|
||||
for subject in case.test_subjects
|
||||
],
|
||||
)
|
||||
if actual is None:
|
||||
actual = original
|
||||
torch.testing.assert_close(
|
||||
actual=actual,
|
||||
expected=expected,
|
||||
@@ -226,6 +227,8 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
|
||||
device=self.device,
|
||||
)
|
||||
)
|
||||
if actual_logits is None:
|
||||
continue
|
||||
filtered_expected_logits = torch.cat(
|
||||
tensors=[
|
||||
subject.steps[0].expected_logits
|
||||
@@ -317,19 +320,20 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
|
||||
msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
|
||||
)
|
||||
|
||||
actual_logits = orchestrator.apply(
|
||||
torch.ones(
|
||||
size=(len(filtered_subjects), self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
original = torch.ones(
|
||||
size=(len(filtered_subjects), self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
actual_logits = orchestrator.apply(original.clone())
|
||||
filtered_expected_logits = torch.cat(
|
||||
tensors=[
|
||||
subject.steps[i].expected_logits
|
||||
for subject in filtered_subjects
|
||||
],
|
||||
)
|
||||
if actual_logits is None:
|
||||
actual_logits = original
|
||||
torch.testing.assert_close(
|
||||
actual=actual_logits,
|
||||
expected=filtered_expected_logits,
|
||||
|
||||
Reference in New Issue
Block a user