Unify the memory pool api and tp worker API (#1724)

This commit is contained in:
Lianmin Zheng
2024-10-19 23:19:26 -07:00
committed by GitHub
parent 95946271af
commit 59cbf47626
8 changed files with 87 additions and 25 deletions

View File

@@ -78,7 +78,7 @@ class SamplingBatchInfo:
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
is_all_greedy=top_ks.max().item() <= 1,
vocab_size=vocab_size,
device=batch.input_ids.device,
device=device,
)
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
@@ -224,3 +224,13 @@ class SamplingBatchInfo:
vocab_size=self.vocab_size,
device=self.device,
)
def to(self, device: str):
for item in [
"temperatures",
"top_ps",
"top_ks",
"min_ps",
]:
value = getattr(self, item)
setattr(self, item, value.to(device, non_blocking=True))