From fb7421db0ddcb263b2cd1d8bbbe63282c97606aa Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Sun, 11 Aug 2024 22:35:44 -0700
Subject: [PATCH] minor: some potential bugs (#1044)

---
 python/sglang/srt/mem_cache/chunk_cache.py             | 7 ++++---
 python/sglang/srt/model_executor/forward_batch_info.py | 6 +++---
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py
index c6e6507a0..35b9171e5 100644
--- a/python/sglang/srt/mem_cache/chunk_cache.py
+++ b/python/sglang/srt/mem_cache/chunk_cache.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 """Cache for chunked prefill, used when RadixCache is disabled."""
 
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Callable, List, Optional
 
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -30,12 +30,13 @@ class ChunkCache(BasePrefixCache):
     def reset(self):
         self.entries = {}
 
-    def match_prefix(self, rid, **kwargs):
+    def match_prefix(self, rid: int, key: List[int]):
         if rid not in self.entries:
             return [], None
 
         entry = self.entries[rid]
-        return entry.value, entry
+        max_prefix_len = len(key)
+        return entry.value[:max_prefix_len], entry
 
     def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index d54b14ef2..eb7aaaf2c 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -140,13 +140,13 @@ class InputMetadata:
         if self.forward_mode == ForwardMode.DECODE:
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
         else:
-            prefix_lens_cpu = [
+            extend_lens_cpu = [
                 len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
             ]
-            self.extend_seq_lens = torch.tensor(prefix_lens_cpu, device="cuda")
+            self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
             self.extend_start_loc = torch.zeros_like(self.seq_lens)
             self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-            self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu)
+            self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)
 
     @classmethod
     def from_schedule_batch(