Make input_ids a torch.Tensor (#1568)
This commit is contained in:
@@ -514,9 +514,10 @@ class ScheduleBatch:
|
|||||||
pt += req.extend_input_len
|
pt += req.extend_input_len
|
||||||
|
|
||||||
# Set fields
|
# Set fields
|
||||||
self.input_ids = sum(input_ids, [])
|
with out_cache_loc.device:
|
||||||
self.req_pool_indices = torch.tensor(req_pool_indices, device="cuda")
|
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
|
||||||
self.seq_lens = torch.tensor(seq_lens, device="cuda")
|
self.req_pool_indices = torch.tensor(req_pool_indices)
|
||||||
|
self.seq_lens = torch.tensor(seq_lens)
|
||||||
|
|
||||||
self.extend_num_tokens = extend_num_tokens
|
self.extend_num_tokens = extend_num_tokens
|
||||||
self.out_cache_loc = out_cache_loc
|
self.out_cache_loc = out_cache_loc
|
||||||
@@ -536,7 +537,7 @@ class ScheduleBatch:
|
|||||||
req.fill_ids = req.origin_input_ids + req.output_ids
|
req.fill_ids = req.origin_input_ids + req.output_ids
|
||||||
req.extend_input_len = 1
|
req.extend_input_len = 1
|
||||||
|
|
||||||
input_ids = self.input_ids + running_batch.input_ids
|
input_ids = torch.cat([self.input_ids, running_batch.input_ids])
|
||||||
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
|
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
|
||||||
extend_num_tokens = self.extend_num_tokens + running_bs
|
extend_num_tokens = self.extend_num_tokens + running_bs
|
||||||
|
|
||||||
@@ -722,7 +723,9 @@ class ScheduleBatch:
|
|||||||
for r in self.reqs
|
for r in self.reqs
|
||||||
]
|
]
|
||||||
|
|
||||||
self.input_ids = input_ids
|
self.input_ids = torch.tensor(
|
||||||
|
input_ids, dtype=torch.int32, device=self.seq_lens.device
|
||||||
|
)
|
||||||
self.seq_lens.add_(1)
|
self.seq_lens.add_(1)
|
||||||
|
|
||||||
# Alloc mem
|
# Alloc mem
|
||||||
@@ -824,7 +827,7 @@ class ModelWorkerBatch:
|
|||||||
# The forward mode
|
# The forward mode
|
||||||
forward_mode: ForwardMode
|
forward_mode: ForwardMode
|
||||||
# The input ids
|
# The input ids
|
||||||
input_ids: List[int]
|
input_ids: torch.Tensor
|
||||||
# The indices of requests in the req_to_token_pool
|
# The indices of requests in the req_to_token_pool
|
||||||
req_pool_indices: torch.Tensor
|
req_pool_indices: torch.Tensor
|
||||||
# The sequence length
|
# The sequence length
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class ReqToTokenPool:
|
|||||||
def __init__(self, size: int, max_context_len: int, device: str):
|
def __init__(self, size: int, max_context_len: int, device: str):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.max_context_len = max_context_len
|
self.max_context_len = max_context_len
|
||||||
|
self.device = device
|
||||||
self.free_slots = list(range(size))
|
self.free_slots = list(range(size))
|
||||||
self.req_to_token = torch.empty(
|
self.req_to_token = torch.empty(
|
||||||
(size, max_context_len), dtype=torch.int32, device=device
|
(size, max_context_len), dtype=torch.int32, device=device
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ class ForwardBatch:
|
|||||||
ret = cls(
|
ret = cls(
|
||||||
forward_mode=batch.forward_mode,
|
forward_mode=batch.forward_mode,
|
||||||
batch_size=len(batch.seq_lens),
|
batch_size=len(batch.seq_lens),
|
||||||
input_ids=torch.tensor(batch.input_ids, dtype=torch.int32, device=device),
|
input_ids=batch.input_ids,
|
||||||
req_pool_indices=batch.req_pool_indices,
|
req_pool_indices=batch.req_pool_indices,
|
||||||
seq_lens=batch.seq_lens,
|
seq_lens=batch.seq_lens,
|
||||||
out_cache_loc=batch.out_cache_loc,
|
out_cache_loc=batch.out_cache_loc,
|
||||||
|
|||||||
Reference in New Issue
Block a user