Optimize conflicts between CUDA graph and vocab mask tensors (#1392)
@@ -35,21 +35,6 @@ class Sampler(CustomOp):
         self.forward_native = self.forward_cuda
         self.is_torch_compile = False
 
-    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # min-token, presence, frequency
-        if sampling_info.linear_penalties is not None:
-            logits += sampling_info.linear_penalties
-
-        # repetition
-        if sampling_info.scaling_penalties is not None:
-            logits = torch.where(
-                logits > 0,
-                logits / sampling_info.scaling_penalties,
-                logits * sampling_info.scaling_penalties,
-            )
-
-        return logits
-
     def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Post process logits
         logits = logits.contiguous()
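
The removed _apply_penalties folds the additive penalties (min-token, presence, frequency) into sampling_info.linear_penalties and then applies the multiplicative repetition penalty. The torch.where branch divides positive logits and multiplies negative ones, so a penalty above 1 lowers a repeated token's probability in both cases. A standalone sketch with made-up values, not the commit's code:

import torch

# Made-up logits and per-token repetition penalties
# (values > 1 discourage tokens that have already appeared).
logits = torch.tensor([2.0, -1.0, 0.5])
penalties = torch.tensor([1.2, 1.2, 1.0])

# Positive logits shrink toward zero; negative logits are pushed further
# below zero. Either way, the penalized token loses probability mass.
penalized = torch.where(logits > 0, logits / penalties, logits * penalties)
print(penalized)  # tensor([ 1.6667, -1.2000,  0.5000])
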
@@ -58,14 +43,6 @@ class Sampler(CustomOp):
         # FIXME: Temporary workaround for unknown bugs in torch.compile
         logits.add_(0)
 
-        if sampling_info.logit_bias is not None:
-            logits.add_(sampling_info.logit_bias)
-
-        if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
-
-        logits = self._apply_penalties(logits, sampling_info)
-
         return torch.softmax(logits, dim=-1)
 
     def forward_cuda(
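
As the title indicates, the conflict is that CUDA graph replay requires every captured tensor to keep a stable memory address, while out-of-place ops such as masked_fill above return a freshly allocated tensor on each call. Below is a minimal sketch of the usual remedy, persistent buffers plus in-place masked_fill_; the buffer names are hypothetical and this is not the commit's actual fix:

import torch

batch_size, vocab_size = 8, 32000

# Hypothetical persistent buffers: allocated once so CUDA graph replay
# always sees the same addresses.
logits_buf = torch.zeros(batch_size, vocab_size, device="cuda")
mask_buf = torch.zeros(batch_size, vocab_size, dtype=torch.bool, device="cuda")

# (Warmup iterations before capture are omitted for brevity.)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # In-place masking keeps the captured address valid across replays;
    # out-of-place masked_fill would allocate a new tensor instead.
    logits_buf.masked_fill_(mask_buf, float("-inf"))
    probs = torch.softmax(logits_buf, dim=-1)

# Each decoding step: refresh the buffers in place, then replay.
logits_buf.copy_(torch.randn(batch_size, vocab_size, device="cuda"))
mask_buf.copy_(torch.rand(batch_size, vocab_size, device="cuda") > 0.99)
graph.replay()  # probs now holds this step's results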