minor: some potential bugs (#1044)
@@ -2,7 +2,7 @@ from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""

-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Callable, List, Optional

 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -30,12 +30,13 @@ class ChunkCache(BasePrefixCache):
     def reset(self):
         self.entries = {}

-    def match_prefix(self, rid, **kwargs):
+    def match_prefix(self, rid: int, key: List[int]):
         if rid not in self.entries:
             return [], None

         entry = self.entries[rid]
-        return entry.value, entry
+        max_prefix_len = len(key)
+        return entry.value[:max_prefix_len], entry

-    def cache_finished_req(self, req: Req, token_ids=None):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
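The ChunkCache hunk above is the substantive fix: `match_prefix` used to return the entire cached `entry.value` no matter how long the requested key was, and it accepted the key only through `**kwargs`. Below is a minimal, self-contained sketch of the corrected behavior; `_Entry` and `_ChunkCacheSketch` are illustrative stand-ins (not the real sglang classes), and `value` is simplified to a plain token list rather than a tensor of KV indices.

```python
from typing import Dict, List, Optional, Tuple


class _Entry:
    """Stand-in for a cache entry; only the fields used here are assumed."""

    def __init__(self, rid: int, value: List[int]):
        self.rid = rid
        self.value = value


class _ChunkCacheSketch:
    """Sketch of the fixed lookup: never hand back more than the caller asked for."""

    def __init__(self) -> None:
        self.entries: Dict[int, _Entry] = {}

    def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], Optional[_Entry]]:
        if rid not in self.entries:
            return [], None

        entry = self.entries[rid]
        # The fix: cap the returned prefix at the length of the requested key.
        max_prefix_len = len(key)
        return entry.value[:max_prefix_len], entry


cache = _ChunkCacheSketch()
cache.entries[7] = _Entry(7, value=[11, 12, 13, 14, 15])
prefix, _ = cache.match_prefix(rid=7, key=[11, 12, 13])
assert prefix == [11, 12, 13]  # before the fix, all five cached tokens came back
```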
@@ -140,13 +140,13 @@ class InputMetadata:
         if self.forward_mode == ForwardMode.DECODE:
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
         else:
-            prefix_lens_cpu = [
+            extend_lens_cpu = [
                 len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
             ]
-            self.extend_seq_lens = torch.tensor(prefix_lens_cpu, device="cuda")
+            self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
             self.extend_start_loc = torch.zeros_like(self.seq_lens)
             self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-            self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu)
+            self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)

     @classmethod
     def from_schedule_batch(
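The InputMetadata hunk fixes a misnamed variable and the stale check built on top of it: the list fed into `extend_seq_lens` holds extend lengths (`fill_ids` minus the cached prefix), not prefix lengths, so the old `extend_no_prefix` computed from it answered "is every extend length zero?" rather than "does any request have a cached prefix?". The sketch below reproduces the corrected bookkeeping with CPU tensors and a simplified `_Req` stand-in (the real code allocates on CUDA inside `InputMetadata`).

```python
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class _Req:
    fill_ids: List[int]        # cached prefix tokens plus newly scheduled tokens
    prefix_indices: List[int]  # KV indices already held for the cached prefix


reqs = [
    _Req(fill_ids=list(range(6)), prefix_indices=[0, 1]),  # 2 cached, 4 to extend
    _Req(fill_ids=list(range(5)), prefix_indices=[]),      # nothing cached, 5 to extend
]

# Tokens each request actually extends this step (fill length minus cached prefix).
extend_lens_cpu = [len(r.fill_ids) - len(r.prefix_indices) for r in reqs]
extend_seq_lens = torch.tensor(extend_lens_cpu)

# Start offset of each request in the flattened extend batch (exclusive prefix sum).
extend_start_loc = torch.zeros_like(extend_seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)

# The fix: "no prefix" must be judged from prefix_indices, not from extend lengths.
extend_no_prefix = all(len(r.prefix_indices) == 0 for r in reqs)

print(extend_seq_lens.tolist())   # [4, 5]
print(extend_start_loc.tolist())  # [0, 4]
print(extend_no_prefix)           # False: the first request has a cached prefix
```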