Fix perf regression on small batch sizes (#3008)

This commit is contained in:
Lianmin Zheng
2025-01-20 03:39:49 -08:00
committed by GitHub
parent 10bfce71b3
commit dc1881326f
2 changed files with 11 additions and 7 deletions

View File

@@ -47,8 +47,8 @@ class RadixAttention(nn.Module):
self.logit_cap = logit_cap
self.sliding_window_size = sliding_window_size or -1
self.is_cross_attention = is_cross_attention
self.k_scale = 1.0
self.v_scale = 1.0
self.k_scale = None
self.v_scale = None
def forward(
self,

View File

@@ -27,7 +27,7 @@ import logging
import threading
from enum import IntEnum
from functools import wraps
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union
import numpy as np
import psutil
@@ -270,13 +270,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
k_scale: float = 1.0,
v_scale: float = 1.0,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
):
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
cache_k = (cache_k / k_scale).to(self.dtype)
cache_v = (cache_v / v_scale).to(self.dtype)
if k_scale is not None:
cache_k.div_(k_scale)
if v_scale is not None:
cache_v.div_(v_scale)
cache_k = cache_k.to(self.dtype)
cache_v = cache_v.to(self.dtype)
if self.store_dtype != self.dtype:
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)