Llama3.2 vision model support (#1551)

Liangsheng Yin
2024-10-21 15:01:21 -07:00
committed by GitHub
parent 00611286a1
commit 94cde10920
21 changed files with 1562 additions and 122 deletions

python/sglang/srt/mem_cache/memory_pool.py

@@ -26,6 +26,8 @@ from typing import List, Tuple, Union
 import torch
 
+from sglang.srt.layers.radix_attention import RadixAttention
+
 logger = logging.getLogger(__name__)
@@ -41,13 +43,17 @@ class ReqToTokenPool:
         )
         self.free_slots = list(range(size))
         self.write_records = []
         self.use_records = use_records
 
-        if use_records:
+        # records all write operations
+        if self.use_records:
             self.write = self.write_with_records
         else:
             self.write = self.write_without_records
 
+    def write(self, indices, values):
+        # Keep the signature for type checking, will be initialized during runtime
+        raise NotImplementedError()
 
     def available_size(self):
         return len(self.free_slots)
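
The hunk above rebinds ReqToTokenPool.write once in __init__ instead of branching on use_records at every call, which keeps the hot write path free of the check. A minimal sketch of that dispatch pattern follows; the bodies of write_with_records and write_without_records are assumptions for illustration, not the commit's actual implementations:

import torch

class ReqToTokenPoolSketch:
    # Sketch only: mirrors the rebinding in the hunk above; the buffer
    # shape and the record format are assumptions.
    def __init__(self, size: int, max_context_len: int, use_records: bool):
        self.req_to_token = torch.zeros((size, max_context_len), dtype=torch.int32)
        self.write_records = []
        self.use_records = use_records

        # Bind `write` to the right implementation once, at construction time.
        if self.use_records:
            self.write = self.write_with_records
        else:
            self.write = self.write_without_records

    def write(self, indices, values):
        # Keep the signature for type checking; rebound in __init__.
        raise NotImplementedError()

    def write_without_records(self, indices, values):
        self.req_to_token[indices] = values

    def write_with_records(self, indices, values):
        self.req_to_token[indices] = values
        # Record each write so the sequence can be replayed later.
        self.write_records.append((indices, values))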
@@ -154,7 +160,7 @@ class BaseTokenToKVPool:
def set_kv_buffer(
self,
layer_id: int,
layer: RadixAttention,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
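
This is the core interface change of the refactor: set_kv_buffer now receives the whole RadixAttention layer object instead of a bare integer layer_id, so each pool can derive the id (and, potentially, other per-layer attributes) from the layer itself. A call-site sketch under that reading, with the surrounding function and parameter names invented for illustration:

from sglang.srt.layers.radix_attention import RadixAttention

def save_kv_sketch(layer: RadixAttention, token_to_kv_pool, loc, cache_k, cache_v):
    # Before this commit the pool only received an integer id:
    #     token_to_kv_pool.set_kv_buffer(layer.layer_id, loc, cache_k, cache_v)
    # After it, the pool receives the layer and reads layer.layer_id internally:
    token_to_kv_pool.set_kv_buffer(layer, loc, cache_k, cache_v)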
@@ -209,11 +215,12 @@ class MHATokenToKVPool(BaseTokenToKVPool):
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
     ):
+        layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if cache_v.dtype != self.dtype:
             cache_v = cache_v.to(self.dtype)
@@ -265,11 +272,12 @@ class MLATokenToKVPool(BaseTokenToKVPool):
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
     ):
+        layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if self.store_dtype != self.dtype:
@@ -324,13 +332,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
         cache_label: torch.Tensor,
     ):
         # NOTE(Andy): ignore the dtype check
+        layer_id = layer.layer_id
         self.k_buffer[layer_id][loc] = cache_k
         self.v_buffer[layer_id][loc] = cache_v
         self.label_buffer[layer_id][loc] = cache_label
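
All of the pools now follow the same pattern: accept the RadixAttention layer, read layer.layer_id on entry, then index their per-layer buffers exactly as before. A condensed, self-contained sketch of the MHA variant's resulting behavior; the constructor and buffer layout are simplified assumptions, not the commit's code:

import torch

class MHATokenToKVPoolSketch:
    # Sketch of MHATokenToKVPool.set_kv_buffer after this commit.
    def __init__(self, num_layers: int, size: int, head_num: int, head_dim: int, dtype: torch.dtype):
        self.dtype = dtype
        self.k_buffer = [torch.zeros(size, head_num, head_dim, dtype=dtype) for _ in range(num_layers)]
        self.v_buffer = [torch.zeros(size, head_num, head_dim, dtype=dtype) for _ in range(num_layers)]

    def set_kv_buffer(self, layer, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor):
        # The index is derived from the layer object, as in the diff above.
        layer_id = layer.layer_id
        # Cast into the pool's storage dtype when the incoming cache differs
        # (e.g., a low-precision KV cache with higher-precision activations).
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
        if cache_v.dtype != self.dtype:
            cache_v = cache_v.to(self.dtype)
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v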