minor: some potential bugs (#1044)
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 """Cache for chunked prefill, used when RadixCache is disabled."""
 
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Callable, List, Optional
 
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -30,12 +30,13 @@ class ChunkCache(BasePrefixCache):
     def reset(self):
         self.entries = {}
 
-    def match_prefix(self, rid, **kwargs):
+    def match_prefix(self, rid: int, key: List[int]):
         if rid not in self.entries:
             return [], None
 
         entry = self.entries[rid]
-        return entry.value, entry
+        max_prefix_len = len(key)
+        return entry.value[:max_prefix_len], entry
 
     def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
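The hunk above makes match_prefix truncate the cached value to the length of the queried key, instead of returning everything the entry holds. A minimal, self-contained sketch of why that matters is below; TinyChunkCache and Entry are illustrative stand-ins, not SGLang's actual classes.

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class Entry:
    rid: int
    value: List[int]  # token indices cached so far for this request


class TinyChunkCache:
    def __init__(self) -> None:
        self.entries: Dict[int, Entry] = {}

    def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], Optional[Entry]]:
        if rid not in self.entries:
            return [], None
        entry = self.entries[rid]
        # The fix: never return more tokens than the caller queried,
        # since during chunked prefill the cached value can be longer
        # than the prefix currently being matched.
        max_prefix_len = len(key)
        return entry.value[:max_prefix_len], entry


cache = TinyChunkCache()
cache.entries[7] = Entry(rid=7, value=list(range(10)))  # 10 cached tokens
prefix, _ = cache.match_prefix(rid=7, key=[0, 1, 2, 3])  # only 4 queried
assert len(prefix) == 4  # without the slice this would be 10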
@@ -140,13 +140,13 @@ class InputMetadata:
         if self.forward_mode == ForwardMode.DECODE:
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
         else:
-            prefix_lens_cpu = [
+            extend_lens_cpu = [
                 len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
             ]
-            self.extend_seq_lens = torch.tensor(prefix_lens_cpu, device="cuda")
+            self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
             self.extend_start_loc = torch.zeros_like(self.seq_lens)
             self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-            self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu)
+            self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)
 
     @classmethod
     def from_schedule_batch(
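The rename above matters because the list holds extend lengths (fill length minus cached-prefix length), not prefix lengths, so the old extend_no_prefix check answered the wrong question. Here is a standalone sketch of the same bookkeeping with made-up per-request lengths and the CUDA device dropped so it runs on CPU; the numbers are illustrative only.

import torch

fill_lens = [8, 5, 9]    # stands in for len(r.fill_ids) per request
prefix_lens = [3, 0, 6]  # stands in for len(r.prefix_indices) per request

# These are extend lengths, not prefix lengths -- hence the rename.
extend_lens_cpu = [f - p for f, p in zip(fill_lens, prefix_lens)]

extend_seq_lens = torch.tensor(extend_lens_cpu)
extend_start_loc = torch.zeros_like(extend_seq_lens)
# Each request's extend tokens start where the previous request's end.
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)

# Old check: all(x == 0 for x in prefix_lens_cpu) compared *extend*
# lengths to zero. The fix tests the actual cached-prefix lengths.
extend_no_prefix = all(p == 0 for p in prefix_lens)

print(extend_lens_cpu)            # [5, 5, 3]
print(extend_start_loc.tolist())  # [0, 5, 10]
print(extend_no_prefix)           # False: two requests have prefixes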