Fix perf regression on small batch sizes (#3008)
This commit is contained in:
@@ -47,8 +47,8 @@ class RadixAttention(nn.Module):
|
||||
self.logit_cap = logit_cap
|
||||
self.sliding_window_size = sliding_window_size or -1
|
||||
self.is_cross_attention = is_cross_attention
|
||||
self.k_scale = 1.0
|
||||
self.v_scale = 1.0
|
||||
self.k_scale = None
|
||||
self.v_scale = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -27,7 +27,7 @@ import logging
|
||||
import threading
|
||||
from enum import IntEnum
|
||||
from functools import wraps
|
||||
from typing import List, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
@@ -270,13 +270,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
||||
loc: torch.Tensor,
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
k_scale: Optional[float] = None,
|
||||
v_scale: Optional[float] = None,
|
||||
):
|
||||
layer_id = layer.layer_id
|
||||
if cache_k.dtype != self.dtype:
|
||||
cache_k = (cache_k / k_scale).to(self.dtype)
|
||||
cache_v = (cache_v / v_scale).to(self.dtype)
|
||||
if k_scale is not None:
|
||||
cache_k.div_(k_scale)
|
||||
if v_scale is not None:
|
||||
cache_v.div_(v_scale)
|
||||
cache_k = cache_k.to(self.dtype)
|
||||
cache_v = cache_v.to(self.dtype)
|
||||
if self.store_dtype != self.dtype:
|
||||
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
|
||||
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
|
||||
|
||||
Reference in New Issue
Block a user