diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index d5e400fbd..64cacd4c2 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -514,9 +514,10 @@ class ScheduleBatch: pt += req.extend_input_len # Set fields - self.input_ids = sum(input_ids, []) - self.req_pool_indices = torch.tensor(req_pool_indices, device="cuda") - self.seq_lens = torch.tensor(seq_lens, device="cuda") + with out_cache_loc.device: + self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32) + self.req_pool_indices = torch.tensor(req_pool_indices) + self.seq_lens = torch.tensor(seq_lens) self.extend_num_tokens = extend_num_tokens self.out_cache_loc = out_cache_loc @@ -536,7 +537,7 @@ class ScheduleBatch: req.fill_ids = req.origin_input_ids + req.output_ids req.extend_input_len = 1 - input_ids = self.input_ids + running_batch.input_ids + input_ids = torch.cat([self.input_ids, running_batch.input_ids]) out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc]) extend_num_tokens = self.extend_num_tokens + running_bs @@ -722,7 +723,9 @@ class ScheduleBatch: for r in self.reqs ] - self.input_ids = input_ids + self.input_ids = torch.tensor( + input_ids, dtype=torch.int32, device=self.seq_lens.device + ) self.seq_lens.add_(1) # Alloc mem @@ -824,7 +827,7 @@ class ModelWorkerBatch: # The forward mode forward_mode: ForwardMode # The input ids - input_ids: List[int] + input_ids: torch.Tensor # The indices of requests in the req_to_token_pool req_pool_indices: torch.Tensor # The sequence length diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 3bf96d381..a4c90be1e 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -30,6 +30,7 @@ class ReqToTokenPool: def __init__(self, size: int, max_context_len: int, device: str): self.size = size self.max_context_len = max_context_len + self.device = device self.free_slots = list(range(size)) self.req_to_token = torch.empty( (size, max_context_len), dtype=torch.int32, device=device diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 0fdf300cd..d76b981d5 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -123,7 +123,7 @@ class ForwardBatch: ret = cls( forward_mode=batch.forward_mode, batch_size=len(batch.seq_lens), - input_ids=torch.tensor(batch.input_ids, dtype=torch.int32, device=device), + input_ids=batch.input_ids, req_pool_indices=batch.req_pool_indices, seq_lens=batch.seq_lens, out_cache_loc=batch.out_cache_loc,