Add device support (#1607)

This commit is contained in:
Zhang, Liangang
2024-10-11 17:05:58 +08:00
committed by GitHub
parent 5476ccad8f
commit 8275049ce3
5 changed files with 96 additions and 52 deletions

View File

@@ -37,6 +37,9 @@ class SamplingBatchInfo:
linear_penalties: torch.Tensor = None
scaling_penalties: torch.Tensor = None
# Device
device: str = "cuda"
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
reqs = batch.reqs
@@ -62,6 +65,7 @@ class SamplingBatchInfo:
min_ps=min_ps,
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
vocab_size=vocab_size,
device=batch.input_ids.device,
)
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
@@ -75,7 +79,7 @@ class SamplingBatchInfo:
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
vocab_size=vocab_size,
batch=batch,
device="cuda",
device=batch.input_ids.device,
Penalizers={
penaltylib.BatchedFrequencyPenalizer,
penaltylib.BatchedMinNewTokensPenalizer,
@@ -107,7 +111,7 @@ class SamplingBatchInfo:
self.linear_penalties = torch.zeros(
(bs, self.vocab_size),
dtype=torch.float32,
device="cuda",
device=self.device,
)
self.linear_penalties = penalizer.apply(self.linear_penalties)
@@ -119,7 +123,10 @@ class SamplingBatchInfo:
if has_regex:
self.vocab_mask = torch.zeros(
len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda"
len(self.temperatures),
self.vocab_size,
dtype=torch.bool,
device=self.device,
)
for i, regex_fsm in enumerate(self.regex_fsms):
if regex_fsm is not None:
@@ -144,7 +151,12 @@ class SamplingBatchInfo:
@staticmethod
def merge_bias_tensor(
lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
lhs: torch.Tensor,
rhs: torch.Tensor,
bs1: int,
bs2: int,
device: str,
default: int = 0,
):
# bias tensor can be None
if lhs is not None or rhs is not None:
@@ -155,9 +167,9 @@ class SamplingBatchInfo:
shape, dtype = rhs.shape[1:], rhs.dtype
with torch.dtype(dtype):
if lhs is None:
lhs = torch.empty((bs1, *shape), device="cuda").fill_(default)
lhs = torch.empty((bs1, *shape), device=device).fill_(default)
if rhs is None:
rhs = torch.empty((bs2, *shape), device="cuda").fill_(default)
rhs = torch.empty((bs2, *shape), device=device).fill_(default)
return torch.cat([lhs, rhs])
return None
@@ -176,5 +188,5 @@ class SamplingBatchInfo:
setattr(self, item, torch.concat([self_val, other_val]))
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other)
self.logit_bias, other.logit_bias, len(self), len(other), self.device
)