Eliminate 2 gpu ops during sampling when logit_bias is zero (#338)
Co-authored-by: hnyls2002 <hnyls2002@gmail.com>
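
Previously every new batch allocated a (batch_size, vocab_size) zeros tensor for logit_bias and unconditionally added it to the logits during sampling, even when no request used an integer logit bias. Both the allocation and the elementwise add are GPU kernels, and those are the two ops this patch eliminates by keeping logit_bias as None until a request actually needs it. A minimal standalone sketch of the pattern, outside the real Batch class (req.needs_int_bias and req.bias_row are illustrative names, not fields from this repo):

    import torch

    def build_logit_bias(reqs, bs, vocab_size, device="cuda"):
        # Lazily allocate the (bs, vocab_size) zeros tensor: the allocation kernel
        # only runs if at least one request actually carries an integer bias.
        logit_bias = None
        for i, req in enumerate(reqs):
            if req.needs_int_bias:  # illustrative flag
                if logit_bias is None:
                    logit_bias = torch.zeros(
                        (bs, vocab_size), dtype=torch.float32, device=device
                    )
                logit_bias[i] = req.bias_row  # illustrative per-request bias row
        return logit_bias

    def postprocess_logits(logits, temperatures, logit_bias):
        logits = logits.contiguous()
        logits.div_(temperatures)
        # Skip the elementwise add kernel entirely when no bias was allocated.
        if logit_bias is not None:
            logits.add_(logit_bias)
        return logits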
@@ -251,10 +251,14 @@ class Batch:
             ] = out_cache_loc[pt : pt + extend_lens[i]]
             pt += extend_lens[i]
 
-        # Handle logit bias
-        logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device)
+        # Handle logit bias but only allocate when needed
+        logit_bias = None
         for i in range(bs):
             if reqs[i].sampling_params.dtype == "int":
+                if logit_bias is None:
+                    logit_bias = torch.zeros(
+                        (bs, vocab_size), dtype=torch.float32, device=device
+                    )
                 logit_bias[i] = int_token_logit_bias
 
         # Set fields
@@ -433,9 +437,12 @@ class Batch:
             "presence_penalties",
             "logit_bias",
         ]:
-            setattr(self, item, getattr(self, item)[new_indices])
+            self_val = getattr(self, item, None)
+            # logit_bias can be None
+            if self_val is not None:
+                setattr(self, item, self_val[new_indices])
 
-    def merge(self, other):
+    def merge(self, other: "Batch"):
         self.reqs.extend(other.reqs)
 
         self.req_pool_indices = torch.concat(
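Because logit_bias may now be None, per-item batch filtering has to tolerate the missing tensor rather than indexing it blindly. A standalone sketch of the getattr-with-default pattern used above (MiniBatch and its fields are illustrative, not the real Batch class):

    import torch

    class MiniBatch:
        def __init__(self, temperatures, logit_bias=None):
            self.temperatures = temperatures
            self.logit_bias = logit_bias  # tensor or None

        def filter(self, new_indices):
            for name in ["temperatures", "logit_bias"]:
                val = getattr(self, name, None)
                # Fields that were never allocated (logit_bias is None) are left alone.
                if val is not None:
                    setattr(self, name, val[new_indices])

    b = MiniBatch(torch.ones(4))
    b.filter(torch.tensor([0, 2]))  # works even though logit_bias is None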
@@ -456,17 +463,34 @@ class Batch:
             "top_ks",
             "frequency_penalties",
             "presence_penalties",
-            "logit_bias",
         ]:
-            setattr(
-                self, item, torch.concat([getattr(self, item), getattr(other, item)])
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        # logit_bias can be None
+        if self.logit_bias is not None or other.logit_bias is not None:
+            vocab_size = (
+                self.logit_bias.shape[1]
+                if self.logit_bias is not None
+                else other.logit_bias.shape[1]
             )
+            if self.logit_bias is None:
+                self.logit_bias = torch.zeros(
+                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            if other.logit_bias is None:
+                other.logit_bias = torch.zeros(
+                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
 
     def sample(self, logits: torch.Tensor):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(self.temperatures)
-        logits.add_(self.logit_bias)
+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)
 
         has_regex = any(req.regex_fsm is not None for req in self.reqs)
         if has_regex:
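Merging is the only place where the None still has to be reconciled with a real tensor: if exactly one side carries a bias, the other side is materialized as zeros so the concat shapes match, and if both sides are None the merged batch keeps None so sample() never touches the bias. A rough standalone sketch of that reconciliation, assuming both batches target the same vocabulary size (merge_logit_bias is an illustrative helper, not a function in this repo):

    import torch

    def merge_logit_bias(a_bias, b_bias, a_len, b_len, device="cuda"):
        # Both batches run without a bias: keep the fast path (stay None).
        if a_bias is None and b_bias is None:
            return None
        vocab_size = a_bias.shape[1] if a_bias is not None else b_bias.shape[1]
        if a_bias is None:
            a_bias = torch.zeros((a_len, vocab_size), dtype=torch.float32, device=device)
        if b_bias is None:
            b_bias = torch.zeros((b_len, vocab_size), dtype=torch.float32, device=device)
        return torch.concat([a_bias, b_bias])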