Fix perf regression on small batch sizes (#3008)

Author: Lianmin Zheng
Date: 2025-01-20 03:39:49 -08:00
Committed by: GitHub
Parent: 10bfce71b3
Commit: dc1881326f

2 changed files with 11 additions and 7 deletions


@@ -47,8 +47,8 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
-        self.k_scale = 1.0
-        self.v_scale = 1.0
+        self.k_scale = None
+        self.v_scale = None

     def forward(
         self,
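Note: switching the defaults from 1.0 to None turns "no scale supplied" into a sentinel that downstream code can test on the host, instead of unconditionally dividing every tensor by 1.0. A minimal sketch of the pattern (the helper name write_scaled is illustrative, not part of this commit):

from typing import Optional

import torch

def write_scaled(cache_k: torch.Tensor, k_scale: Optional[float]) -> torch.Tensor:
    # Comparing against None is a free host-side branch; dividing a GPU
    # tensor by 1.0 would still launch a full elementwise kernel.
    if k_scale is not None:
        cache_k = cache_k / k_scale
    return cache_k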


@@ -27,7 +27,7 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import psutil

@@ -270,13 +270,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
+        k_scale: Optional[float] = None,
+        v_scale: Optional[float] = None,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
-            cache_k = (cache_k / k_scale).to(self.dtype)
-            cache_v = (cache_v / v_scale).to(self.dtype)
+            if k_scale is not None:
+                cache_k.div_(k_scale)
+            if v_scale is not None:
+                cache_v.div_(v_scale)
+            cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
             self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
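With the guard in place, the common unscaled case skips the elementwise divisions entirely, and when scales are present, div_ applies them in place instead of allocating out-of-place temporaries. A hypothetical micro-benchmark sketch (all names and shapes here are mine, not from the commit; assumes a CUDA device and a PyTorch build with FP8 dtypes) of why an unconditional division by 1.0 hurts small batches, where fixed kernel-launch overhead dominates the useful work:

import torch

x = torch.randn(8, 128, device="cuda", dtype=torch.float16)  # tiny batch

def old_path(t):
    # Old behavior: always divide, even by the neutral scale 1.0, then cast.
    return (t / 1.0).to(torch.float8_e4m3fn)

def new_path(t):
    # New behavior: when no scale is supplied, cast only.
    return t.to(torch.float8_e4m3fn)

for fn in (old_path, new_path):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(1000):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    print(fn.__name__, round(start.elapsed_time(end), 3), "ms")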