Eliminate 2 GPU ops during sampling when logit_bias is zero (#343)
Co-authored-by: Qubitium <417764+Qubitium@users.noreply.github.com>
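
Before this change, every batch eagerly allocated a dense (bs, vocab_size) zeros tensor for logit_bias and unconditionally added it to the logits in sample(). Most requests set no logit bias, so that wasted two GPU ops: the zeros allocation when the batch is built and the elementwise add_ at sampling time. The diff below keeps logit_bias as None until a request actually needs it, and teaches the filter loop, merge(), and sample() to tolerate the None.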
@@ -251,10 +251,14 @@ class Batch:
             ] = out_cache_loc[pt : pt + extend_lens[i]]
             pt += extend_lens[i]
 
-        # Handle logit bias
-        logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device)
+        # Handle logit bias but only allocate when needed
+        logit_bias = None
         for i in range(bs):
             if reqs[i].sampling_params.dtype == "int":
+                if logit_bias is None:
+                    logit_bias = torch.zeros(
+                        (bs, vocab_size), dtype=torch.float32, device=device
+                    )
                 logit_bias[i] = int_token_logit_bias
 
         # Set fields
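
The first hunk makes the allocation lazy: the (bs, vocab_size) tensor only materializes once some request carries an "int" sampling constraint. A minimal standalone sketch of the same pattern, with illustrative Req/SamplingParams stubs and a hypothetical build_logit_bias helper that are not part of the source:

    import torch
    from dataclasses import dataclass, field

    @dataclass
    class SamplingParams:
        dtype: str = "float"  # "int" restricts the request to integer tokens

    @dataclass
    class Req:
        sampling_params: SamplingParams = field(default_factory=SamplingParams)

    def build_logit_bias(reqs, vocab_size, int_token_logit_bias, device="cpu"):
        # Allocate the (bs, vocab_size) bias tensor only if some request needs it.
        logit_bias = None
        for i, req in enumerate(reqs):
            if req.sampling_params.dtype == "int":
                if logit_bias is None:
                    # The first "int" request pays for the allocation; a batch
                    # without one never launches this kernel at all.
                    logit_bias = torch.zeros(
                        (len(reqs), vocab_size), dtype=torch.float32, device=device
                    )
                logit_bias[i] = int_token_logit_bias
        return logit_bias  # stays None on the common path

A batch of plain requests returns None here, which is what lets sample() below skip the add_.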
@@ -433,9 +437,12 @@ class Batch:
             "presence_penalties",
             "logit_bias",
         ]:
-            setattr(self, item, getattr(self, item)[new_indices])
+            self_val = getattr(self, item, None)
+            # logit_bias can be None
+            if self_val is not None:
+                setattr(self, item, self_val[new_indices])
 
-    def merge(self, other):
+    def merge(self, other: "Batch"):
         self.reqs.extend(other.reqs)
 
         self.req_pool_indices = torch.concat(
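
The filter hunk switches the per-request tensor fields to a None-tolerant lookup: getattr with a None default, indexing only when the field exists. The same hunk also adds a "Batch" annotation to merge, which is purely cosmetic. The pattern in isolation, with an illustrative two-field Batch stub:

    import torch

    class Batch:
        def __init__(self):
            self.temperatures = torch.ones(4)
            self.logit_bias = None  # may never have been allocated

        def filter(self, new_indices):
            for item in ["temperatures", "logit_bias"]:
                val = getattr(self, item, None)
                # logit_bias can be None; only index tensors that exist
                if val is not None:
                    setattr(self, item, val[new_indices])

    b = Batch()
    b.filter(torch.tensor([0, 2]))
    assert b.temperatures.shape == (2,) and b.logit_bias is None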
@@ -456,17 +463,34 @@ class Batch:
             "top_ks",
             "frequency_penalties",
             "presence_penalties",
-            "logit_bias",
         ]:
-            setattr(
-                self, item, torch.concat([getattr(self, item), getattr(other, item)])
-            )
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        # logit_bias can be None
+        if self.logit_bias is not None or other.logit_bias is not None:
+            vocab_size = (
+                self.logit_bias.shape[1]
+                if self.logit_bias is not None
+                else other.logit_bias.shape[1]
+            )
+            if self.logit_bias is None:
+                self.logit_bias = torch.zeros(
+                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            if other.logit_bias is None:
+                other.logit_bias = torch.zeros(
+                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
 
     def sample(self, logits: torch.Tensor):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(self.temperatures)
-        logits.add_(self.logit_bias)
+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)
 
         has_regex = any(req.regex_fsm is not None for req in self.reqs)
         if has_regex:
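
In merge, a zero bias is re-materialized only when exactly one side carries one, so two bias-free batches stay bias-free; sample() then skips the elementwise add_ whenever the whole batch has no bias. A standalone sketch of the merge rule, with a hypothetical merge_logit_bias helper and explicit row counts standing in for len(self.reqs) / len(other.reqs):

    import torch

    def merge_logit_bias(a, b, a_rows, b_rows, device="cpu"):
        # a and b are (rows, vocab_size) bias tensors or None.
        if a is None and b is None:
            return None  # common path: nothing allocated, nothing concatenated
        vocab_size = a.shape[1] if a is not None else b.shape[1]
        if a is None:
            a = torch.zeros((a_rows, vocab_size), dtype=torch.float32, device=device)
        if b is None:
            b = torch.zeros((b_rows, vocab_size), dtype=torch.float32, device=device)
        return torch.concat([a, b])

    assert merge_logit_bias(None, None, 2, 3) is None
    merged = merge_logit_bias(torch.ones(2, 8), None, 2, 3)
    assert merged.shape == (5, 8) and merged[2:].sum() == 0

Zero-filling the missing side keeps every request aligned with its bias row; the allocation returns only for mixed merges, never for the common all-None case.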