Unify the memory pool api and tp worker API (#1724)
This commit is contained in:
@@ -78,7 +78,7 @@ class SamplingBatchInfo:
|
||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||
is_all_greedy=top_ks.max().item() <= 1,
|
||||
vocab_size=vocab_size,
|
||||
device=batch.input_ids.device,
|
||||
device=device,
|
||||
)
|
||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||
|
||||
@@ -224,3 +224,13 @@ class SamplingBatchInfo:
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def to(self, device: str):
|
||||
for item in [
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
"top_ks",
|
||||
"min_ps",
|
||||
]:
|
||||
value = getattr(self, item)
|
||||
setattr(self, item, value.to(device, non_blocking=True))
|
||||
|
||||
Reference in New Issue
Block a user