Simplify sampler and its error handling (#1441)

This commit is contained in:
Lianmin Zheng
2024-09-16 21:23:31 -07:00
committed by GitHub
parent 27b557aea7
commit 2fa5cec775
4 changed files with 32 additions and 159 deletions

View File

@@ -34,56 +34,6 @@ class SamplingBatchInfo:
linear_penalties: torch.Tensor = None
scaling_penalties: torch.Tensor = None
def __len__(self):
return len(self.temperatures)
def can_run_in_cuda_graph(self):
# Vocab bias and min_ps are not supported in CUDA graph
return (
self.logit_bias is None
and self.linear_penalties is None
and self.scaling_penalties is None
and not self.need_min_p_sampling
)
@classmethod
def dummy_one(cls, max_bs: int, vocab_size: int):
ret = cls(vocab_size=vocab_size)
with torch.device("cuda"):
ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float)
ret.top_ps = torch.ones((max_bs,), dtype=torch.float)
ret.top_ks = torch.ones((max_bs,), dtype=torch.int)
ret.vocab_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool)
return ret
def __getitem__(self, key):
if isinstance(key, slice):
# NOTE:This method is only used in CUDA graph
assert self.can_run_in_cuda_graph()
return SamplingBatchInfo(
vocab_size=self.vocab_size,
temperatures=self.temperatures[key],
top_ps=self.top_ps[key],
top_ks=self.top_ks[key],
vocab_mask=self.vocab_mask[key],
)
else:
raise NotImplementedError
def inplace_assign(self, bs: int, other: SamplingBatchInfo):
# NOTE:This method is only used in CUDA graph
assert self.can_run_in_cuda_graph()
self.vocab_size = other.vocab_size
self.temperatures[:bs] = other.temperatures
self.top_ps[:bs] = other.top_ps
self.top_ks[:bs] = other.top_ks
if other.vocab_mask is None:
self.vocab_mask[:bs].fill_(False)
else:
self.vocab_mask[:bs] = other.vocab_mask
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
reqs = batch.reqs
@@ -130,6 +80,9 @@ class SamplingBatchInfo:
return ret
def __len__(self):
return len(self.temperatures)
def update_penalties(self):
self.scaling_penalties = None
self.linear_penalties = None