Add device support (#1607)
This commit is contained in:
@@ -37,6 +37,9 @@ class SamplingBatchInfo:
|
||||
linear_penalties: torch.Tensor = None
|
||||
scaling_penalties: torch.Tensor = None
|
||||
|
||||
# Device
|
||||
device: str = "cuda"
|
||||
|
||||
@classmethod
|
||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||
reqs = batch.reqs
|
||||
@@ -62,6 +65,7 @@ class SamplingBatchInfo:
|
||||
min_ps=min_ps,
|
||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||
vocab_size=vocab_size,
|
||||
device=batch.input_ids.device,
|
||||
)
|
||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||
|
||||
@@ -75,7 +79,7 @@ class SamplingBatchInfo:
|
||||
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||
vocab_size=vocab_size,
|
||||
batch=batch,
|
||||
device="cuda",
|
||||
device=batch.input_ids.device,
|
||||
Penalizers={
|
||||
penaltylib.BatchedFrequencyPenalizer,
|
||||
penaltylib.BatchedMinNewTokensPenalizer,
|
||||
@@ -107,7 +111,7 @@ class SamplingBatchInfo:
|
||||
self.linear_penalties = torch.zeros(
|
||||
(bs, self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
device=self.device,
|
||||
)
|
||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||
|
||||
@@ -119,7 +123,10 @@ class SamplingBatchInfo:
|
||||
|
||||
if has_regex:
|
||||
self.vocab_mask = torch.zeros(
|
||||
len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda"
|
||||
len(self.temperatures),
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
)
|
||||
for i, regex_fsm in enumerate(self.regex_fsms):
|
||||
if regex_fsm is not None:
|
||||
@@ -144,7 +151,12 @@ class SamplingBatchInfo:
|
||||
|
||||
@staticmethod
|
||||
def merge_bias_tensor(
|
||||
lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
|
||||
lhs: torch.Tensor,
|
||||
rhs: torch.Tensor,
|
||||
bs1: int,
|
||||
bs2: int,
|
||||
device: str,
|
||||
default: int = 0,
|
||||
):
|
||||
# bias tensor can be None
|
||||
if lhs is not None or rhs is not None:
|
||||
@@ -155,9 +167,9 @@ class SamplingBatchInfo:
|
||||
shape, dtype = rhs.shape[1:], rhs.dtype
|
||||
with torch.dtype(dtype):
|
||||
if lhs is None:
|
||||
lhs = torch.empty((bs1, *shape), device="cuda").fill_(default)
|
||||
lhs = torch.empty((bs1, *shape), device=device).fill_(default)
|
||||
if rhs is None:
|
||||
rhs = torch.empty((bs2, *shape), device="cuda").fill_(default)
|
||||
rhs = torch.empty((bs2, *shape), device=device).fill_(default)
|
||||
return torch.cat([lhs, rhs])
|
||||
|
||||
return None
|
||||
@@ -176,5 +188,5 @@ class SamplingBatchInfo:
|
||||
setattr(self, item, torch.concat([self_val, other_val]))
|
||||
|
||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||
self.logit_bias, other.logit_bias, len(self), len(other)
|
||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user