Fix perf regression on small batch sizes (#3008)

This commit is contained in:
Lianmin Zheng
2025-01-20 03:39:49 -08:00
committed by GitHub
parent 10bfce71b3
commit dc1881326f
2 changed files with 11 additions and 7 deletions

View File

@@ -27,7 +27,7 @@ import logging
import threading
from enum import IntEnum
from functools import wraps
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
import numpy as np
import psutil
@@ -270,13 +270,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
-k_scale: float = 1.0,
-v_scale: float = 1.0,
+k_scale: Optional[float] = None,
+v_scale: Optional[float] = None,
):
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
-cache_k = (cache_k / k_scale).to(self.dtype)
-cache_v = (cache_v / v_scale).to(self.dtype)
+if k_scale is not None:
+    cache_k.div_(k_scale)
+if v_scale is not None:
+    cache_v.div_(v_scale)
+cache_k = cache_k.to(self.dtype)
+cache_v = cache_v.to(self.dtype)
if self.store_dtype != self.dtype:
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)