Optimize conflicts between CUDA graph and vocab mask tensors (#1392)

Liangsheng Yin
2024-09-13 20:27:53 -07:00
committed by GitHub
parent f3d32f888a
commit 70b6802982
32 changed files with 103 additions and 224 deletions
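For background on the conflict being fixed here: a CUDA graph replays a fixed sequence of kernels on fixed memory addresses, so any tensor that changes between decoding steps (such as a grammar vocab mask) must live in a pre-allocated buffer that is updated in place rather than being freshly allocated each step. Below is a minimal, hypothetical PyTorch sketch of that constraint; the buffer names and shapes are illustrative and not taken from this PR.

import torch

assert torch.cuda.is_available(), "this sketch needs a CUDA device"

# Persistent buffers: their addresses stay fixed across graph replays.
static_logits = torch.zeros(4, 32000, device="cuda")
static_mask = torch.zeros(4, 32000, dtype=torch.bool, device="cuda")

# Warm up outside capture so lazy CUDA initialization does not happen during capture.
torch.softmax(static_logits.masked_fill(static_mask, float("-inf")), dim=-1)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # Captured once; every replay reads the same input buffers.
    masked = static_logits.masked_fill(static_mask, float("-inf"))
    probs = torch.softmax(masked, dim=-1)

# Per decoding step: refresh the buffers in place, then replay the graph.
static_logits.copy_(torch.randn(4, 32000, device="cuda"))
static_mask.zero_()
static_mask[:, 100:200] = True  # tokens currently forbidden by the grammar
graph.replay()
print(probs[0, 150].item())  # ~0.0: masked tokens cannot be sampled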


@@ -35,21 +35,6 @@ class Sampler(CustomOp):
         self.forward_native = self.forward_cuda
         self.is_torch_compile = False
 
-    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # min-token, presence, frequency
-        if sampling_info.linear_penalties is not None:
-            logits += sampling_info.linear_penalties
-
-        # repetition
-        if sampling_info.scaling_penalties is not None:
-            logits = torch.where(
-                logits > 0,
-                logits / sampling_info.scaling_penalties,
-                logits * sampling_info.scaling_penalties,
-            )
-
-        return logits
-
     def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Post process logits
         logits = logits.contiguous()
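For reference, the removed _apply_penalties helper applies the usual additive (min-token / presence / frequency) and multiplicative (repetition) penalties. A standalone sketch of the same arithmetic, with made-up shapes and values for illustration:

import torch

# Hypothetical shapes and values, for illustration only.
batch, vocab = 2, 8
logits = torch.randn(batch, vocab)

# Additive penalties (min-token / presence / frequency): one value per (sequence, token).
linear_penalties = torch.zeros(batch, vocab)
linear_penalties[0, 3] = -2.0  # e.g. discourage token 3 for sequence 0
logits = logits + linear_penalties

# Repetition penalty: divide positive logits and multiply negative ones,
# so a penalized token moves toward lower probability in either case.
scaling_penalties = torch.full((batch, vocab), 1.2)
logits = torch.where(
    logits > 0,
    logits / scaling_penalties,
    logits * scaling_penalties,
)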
@@ -58,14 +43,6 @@ class Sampler(CustomOp):
         # FIXME: Temporary workaround for unknown bugs in torch.compile
         logits.add_(0)
 
-        if sampling_info.logit_bias is not None:
-            logits.add_(sampling_info.logit_bias)
-
-        if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
-
-        logits = self._apply_penalties(logits, sampling_info)
-
         return torch.softmax(logits, dim=-1)
 
     def forward_cuda(
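The bias and mask application removed from _get_probs follows the standard pre-softmax pattern: add any per-token logit bias, then set forbidden tokens to -inf so they receive zero probability. A self-contained sketch (shapes and values are illustrative only, not taken from this PR):

import torch

logits = torch.randn(1, 8)

# Optional per-token bias (e.g. an OpenAI-style logit_bias request field).
logit_bias = torch.zeros(1, 8)
logit_bias[0, 2] = 5.0
logits = logits + logit_bias

# Boolean mask of tokens forbidden by a grammar / regex constraint.
vocab_mask = torch.zeros(1, 8, dtype=torch.bool)
vocab_mask[0, 5:] = True
logits = logits.masked_fill(vocab_mask, float("-inf"))

probs = torch.softmax(logits, dim=-1)
assert probs[0, 5:].sum().item() == 0.0  # masked tokens can never be sampled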