diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 73e122a75..caeaa7736 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -137,9 +137,6 @@ class RadixAttention(nn.Module): def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata): key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id) + key_buffer[input_metadata.out_cache_loc] = cache_k value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id) - if input_metadata.out_cache_loc is not None: - key_buffer[input_metadata.out_cache_loc] = cache_k - value_buffer[input_metadata.out_cache_loc] = cache_v - else: - raise RuntimeError() + value_buffer[input_metadata.out_cache_loc] = cache_v diff --git a/python/sglang/srt/managers/controller/cuda_graph_runner.py b/python/sglang/srt/managers/controller/cuda_graph_runner.py index ad3225aa6..2e37e55b5 100644 --- a/python/sglang/srt/managers/controller/cuda_graph_runner.py +++ b/python/sglang/srt/managers/controller/cuda_graph_runner.py @@ -132,7 +132,8 @@ class CudaGraphRunner: index = bisect.bisect_left(self.batch_size_list, raw_bs) bs = self.batch_size_list[index] if bs != raw_bs: - self.seq_lens.fill_(1) + self.seq_lens.zero_() + self.position_ids_offsets.fill_(1) self.out_cache_loc.zero_() # Common inputs @@ -168,4 +169,4 @@ class CudaGraphRunner: prefill_top_logprobs=None, decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None, ) - return output \ No newline at end of file + return output diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py index 89cd851a0..b572e120e 100644 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ b/python/sglang/srt/managers/controller/tp_worker.py @@ -315,7 +315,7 @@ class ModelTpServer: def get_new_fill_batch(self) -> Optional[Batch]: running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0 - if running_bs > self.max_running_requests: + if running_bs >= self.max_running_requests: return # Compute matched prefix length @@ -393,7 +393,7 @@ class ModelTpServer: else: break - if running_bs + len(can_run_list) > self.max_running_requests: + if running_bs + len(can_run_list) >= self.max_running_requests: break if len(can_run_list) == 0: diff --git a/python/sglang/srt/memory_pool.py b/python/sglang/srt/memory_pool.py index f5b032218..245e6ef08 100644 --- a/python/sglang/srt/memory_pool.py +++ b/python/sglang/srt/memory_pool.py @@ -46,7 +46,7 @@ class TokenToKVPool: # [size, key/value, head_num, head_dim] for each layer self.kv_data = [ - torch.empty((size, 2, head_num, head_dim), dtype=dtype, device="cuda") + torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num) ] @@ -127,4 +127,4 @@ class TokenToKVPool: self.total_ref_ct = 0 # We also add one slot. This slot is used for writing dummy output from padded tokens. - self.add_refs(torch.tensor([0], dtype=torch.int32)) \ No newline at end of file + self.add_refs(torch.tensor([0], dtype=torch.int32))