diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index a449d7188..0d46e7bba 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -47,8 +47,8 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
-        self.k_scale = 1.0
-        self.v_scale = 1.0
+        self.k_scale = None
+        self.v_scale = None
 
     def forward(
         self,
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index e30736722..7b9b35611 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -27,7 +27,7 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import psutil
@@ -270,13 +270,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
+        k_scale: Optional[float] = None,
+        v_scale: Optional[float] = None,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
-            cache_k = (cache_k / k_scale).to(self.dtype)
-            cache_v = (cache_v / v_scale).to(self.dtype)
+            if k_scale is not None:
+                cache_k.div_(k_scale)
+            if v_scale is not None:
+                cache_v.div_(v_scale)
+            cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
             self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
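
Note on the change: the KV scales now default to None instead of 1.0, so when no scale is configured the write path skips the division entirely, and when a scale is present the division is done in place (div_) before the dtype cast rather than allocating a new tensor. A minimal standalone sketch of the resulting scale-handling logic follows; the store_kv helper, the raw buffer tensors, and the float16 pool dtype are hypothetical stand-ins for illustration and omit the store_dtype view handling from the real set_kv_buffer.

from typing import Optional

import torch


def store_kv(
    k_buffer: torch.Tensor,  # per-layer key cache slab, already in pool dtype
    v_buffer: torch.Tensor,  # per-layer value cache slab, already in pool dtype
    loc: torch.Tensor,       # token slots to write into
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
    pool_dtype: torch.dtype,
    k_scale: Optional[float] = None,
    v_scale: Optional[float] = None,
):
    # Only convert when the incoming dtype differs from the pool dtype.
    if cache_k.dtype != pool_dtype:
        # None means "no scaling"; the old code always divided by the default 1.0.
        # Note div_ mutates the caller's tensors, unlike the previous out-of-place divide.
        if k_scale is not None:
            cache_k.div_(k_scale)
        if v_scale is not None:
            cache_v.div_(v_scale)
        cache_k = cache_k.to(pool_dtype)
        cache_v = cache_v.to(pool_dtype)
    k_buffer[loc] = cache_k
    v_buffer[loc] = cache_v


# Example usage with made-up shapes: unscaled fp32 inputs written into an fp16 pool.
if __name__ == "__main__":
    k_buf = torch.zeros(16, 4, 8, dtype=torch.float16)
    v_buf = torch.zeros(16, 4, 8, dtype=torch.float16)
    loc = torch.tensor([3, 4])
    k = torch.randn(2, 4, 8)
    v = torch.randn(2, 4, 8)
    store_kv(k_buf, v_buf, loc, k, v, pool_dtype=torch.float16)  # no scales: cast only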