Llama3.2 vision model support (#1551)
This commit is contained in:
@@ -26,6 +26,8 @@ from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -41,13 +43,17 @@ class ReqToTokenPool:
|
||||
)
|
||||
self.free_slots = list(range(size))
|
||||
self.write_records = []
|
||||
self.use_records = use_records
|
||||
|
||||
if use_records:
|
||||
# records all write operations
|
||||
if self.use_records:
|
||||
self.write = self.write_with_records
|
||||
else:
|
||||
self.write = self.write_without_records
|
||||
|
||||
def write(self, indices, values):
    """Placeholder that fixes the signature of the pool's write operation.

    This stub exists so that type checkers see a ``write`` method. The
    surrounding class rebinds ``self.write`` to ``write_with_records`` or
    ``write_without_records`` depending on ``self.use_records``, so this
    body only runs if that rebinding did not take place.

    Args:
        indices: Forwarded positionally by callers; unused by this stub.
        values: Forwarded positionally by callers; unused by this stub.

    Raises:
        NotImplementedError: Always.
    """
    # Keep the signature for type checking; the real implementation is
    # assigned at runtime.
    raise NotImplementedError()
|
||||
|
||||
def available_size(self):
    """Return the number of entries currently held in ``self.free_slots``."""
    return len(self.free_slots)
|
||||
|
||||
@@ -154,7 +160,7 @@ class BaseTokenToKVPool:
|
||||
|
||||
def set_kv_buffer(
|
||||
self,
|
||||
layer_id: int,
|
||||
layer: RadixAttention,
|
||||
loc: torch.Tensor,
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
@@ -209,11 +215,12 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
||||
|
||||
def set_kv_buffer(
|
||||
self,
|
||||
layer_id: int,
|
||||
layer: RadixAttention,
|
||||
loc: torch.Tensor,
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
):
|
||||
layer_id = layer.layer_id
|
||||
if cache_k.dtype != self.dtype:
|
||||
cache_k = cache_k.to(self.dtype)
|
||||
if cache_v.dtype != self.dtype:
|
||||
@@ -265,11 +272,12 @@ class MLATokenToKVPool(BaseTokenToKVPool):
|
||||
|
||||
def set_kv_buffer(
|
||||
self,
|
||||
layer_id: int,
|
||||
layer: RadixAttention,
|
||||
loc: torch.Tensor,
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
):
|
||||
layer_id = layer.layer_id
|
||||
if cache_k.dtype != self.dtype:
|
||||
cache_k = cache_k.to(self.dtype)
|
||||
if self.store_dtype != self.dtype:
|
||||
@@ -324,13 +332,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
|
||||
|
||||
def set_kv_buffer(
|
||||
self,
|
||||
layer_id: int,
|
||||
layer: RadixAttention,
|
||||
loc: torch.Tensor,
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
cache_label: torch.Tensor,
|
||||
):
|
||||
# NOTE(Andy): ignore the dtype check
|
||||
layer_id = layer.layer_id
|
||||
self.k_buffer[layer_id][loc] = cache_k
|
||||
self.v_buffer[layer_id][loc] = cache_v
|
||||
self.label_buffer[layer_id][loc] = cache_label
|
||||
|
||||
Reference in New Issue
Block a user