Fix perf regression on small batch sizes (#3008)
This commit is contained in:
@@ -47,8 +47,8 @@ class RadixAttention(nn.Module):
|
|||||||
self.logit_cap = logit_cap
|
self.logit_cap = logit_cap
|
||||||
self.sliding_window_size = sliding_window_size or -1
|
self.sliding_window_size = sliding_window_size or -1
|
||||||
self.is_cross_attention = is_cross_attention
|
self.is_cross_attention = is_cross_attention
|
||||||
self.k_scale = 1.0
|
self.k_scale = None
|
||||||
self.v_scale = 1.0
|
self.v_scale = None
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ import logging
|
|||||||
import threading
|
import threading
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import List, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import psutil
|
import psutil
|
||||||
@@ -270,13 +270,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
|||||||
loc: torch.Tensor,
|
loc: torch.Tensor,
|
||||||
cache_k: torch.Tensor,
|
cache_k: torch.Tensor,
|
||||||
cache_v: torch.Tensor,
|
cache_v: torch.Tensor,
|
||||||
k_scale: float = 1.0,
|
k_scale: Optional[float] = None,
|
||||||
v_scale: float = 1.0,
|
v_scale: Optional[float] = None,
|
||||||
):
|
):
|
||||||
layer_id = layer.layer_id
|
layer_id = layer.layer_id
|
||||||
if cache_k.dtype != self.dtype:
|
if cache_k.dtype != self.dtype:
|
||||||
cache_k = (cache_k / k_scale).to(self.dtype)
|
if k_scale is not None:
|
||||||
cache_v = (cache_v / v_scale).to(self.dtype)
|
cache_k.div_(k_scale)
|
||||||
|
if v_scale is not None:
|
||||||
|
cache_v.div_(v_scale)
|
||||||
|
cache_k = cache_k.to(self.dtype)
|
||||||
|
cache_v = cache_v.to(self.dtype)
|
||||||
if self.store_dtype != self.dtype:
|
if self.store_dtype != self.dtype:
|
||||||
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
|
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
|
||||||
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
|
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user