diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index a02673dc3..91735a1b8 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -203,7 +203,6 @@ class RadixAttention(nn.Module): return self.decode_forward(q, k, v, input_metadata) def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata): - k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id) - v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id) - k_cache[input_metadata.out_cache_loc] = cache_k - v_cache[input_metadata.out_cache_loc] = cache_v + input_metadata.token_to_kv_pool.set_kv_buffer( + self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v + ) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 68cefbbf9..fef74321a 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -16,7 +16,8 @@ limitations under the License. """Memory pool.""" import logging -from typing import List, Union +from abc import ABC, abstractmethod +from typing import List, Tuple, Union import torch @@ -52,14 +53,21 @@ class ReqToTokenPool: self.free_slots = list(range(self.size)) -class BaseTokenToKVPool: +class BaseTokenToKVPool(ABC): """A memory pool that maps a token to its kv cache locations""" def __init__( self, size: int, + dtype: torch.dtype, ): self.size = size + self.dtype = dtype + if dtype == torch.float8_e5m2: + # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2 + self.store_dtype = torch.uint8 + else: + self.store_dtype = dtype # We also add one slot. This slot is used for writing dummy output from padded tokens. self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda") @@ -112,6 +120,28 @@ class BaseTokenToKVPool: # We also add one slot. This slot is used for writing dummy output from padded tokens. self.mem_state[0] = False + @abstractmethod + def get_key_buffer(self, layer_id: int) -> torch.Tensor: + raise NotImplementedError() + + @abstractmethod + def get_value_buffer(self, layer_id: int) -> torch.Tensor: + raise NotImplementedError() + + @abstractmethod + def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError() + + @abstractmethod + def set_kv_buffer( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + ) -> None: + raise NotImplementedError() + class MHATokenToKVPool(BaseTokenToKVPool): @@ -123,26 +153,52 @@ class MHATokenToKVPool(BaseTokenToKVPool): head_dim: int, layer_num: int, ): - super().__init__(size) + super().__init__(size, dtype) # [size, head_num, head_dim] for each layer self.k_buffer = [ - torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda") + torch.empty( + (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda" + ) for _ in range(layer_num) ] self.v_buffer = [ - torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda") + torch.empty( + (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda" + ) for _ in range(layer_num) ] def get_key_buffer(self, layer_id: int): + if self.store_dtype != self.dtype: + return self.k_buffer[layer_id].view(self.dtype) return self.k_buffer[layer_id] def get_value_buffer(self, layer_id: int): + if self.store_dtype != self.dtype: + return self.v_buffer[layer_id].view(self.dtype) return self.v_buffer[layer_id] def get_kv_buffer(self, layer_id: int): - return self.k_buffer[layer_id], self.v_buffer[layer_id] + return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id) + + def set_kv_buffer( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + ): + if cache_k.dtype != self.dtype: + cache_k = cache_k.to(self.dtype) + if cache_v.dtype != self.dtype: + cache_v = cache_v.to(self.dtype) + if self.store_dtype != self.dtype: + self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype) + self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype) + else: + self.k_buffer[layer_id][loc] = cache_k + self.v_buffer[layer_id][loc] = cache_v class MLATokenToKVPool(BaseTokenToKVPool): @@ -155,23 +211,41 @@ class MLATokenToKVPool(BaseTokenToKVPool): qk_rope_head_dim: int, layer_num: int, ): - super().__init__(size) + super().__init__(size, dtype) self.kv_lora_rank = kv_lora_rank self.kv_buffer = [ torch.empty( (size + 1, 1, kv_lora_rank + qk_rope_head_dim), - dtype=dtype, + dtype=self.store_dtype, device="cuda", ) for _ in range(layer_num) ] def get_key_buffer(self, layer_id: int): + if self.store_dtype != self.dtype: + return self.kv_buffer[layer_id].view(self.dtype) return self.kv_buffer[layer_id] def get_value_buffer(self, layer_id: int): + if self.store_dtype != self.dtype: + return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype) return self.kv_buffer[layer_id][..., : self.kv_lora_rank] def get_kv_buffer(self, layer_id: int): return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id) + + def set_kv_buffer( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + ): + if cache_k.dtype != self.dtype: + cache_k = cache_k.to(self.dtype) + if self.store_dtype != self.dtype: + self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype) + else: + self.kv_buffer[layer_id][loc] = cache_k diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 98daeaece..c107b3bc8 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -315,6 +315,8 @@ def update_flashinfer_indices( num_kv_heads, head_dim, 1, + data_type=model_runner.kv_cache_dtype, + q_data_type=model_runner.dtype, ) else: # extend part @@ -393,6 +395,8 @@ def update_flashinfer_indices( num_kv_heads, head_dim, 1, + data_type=model_runner.kv_cache_dtype, + q_data_type=model_runner.dtype, ) else: # extend part diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index fa55abba6..fecfc2b43 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -311,7 +311,7 @@ class ModelRunner: cell_size = ( (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim) * self.model_config.num_hidden_layers - * torch._utils._element_size(self.dtype) + * torch._utils._element_size(self.kv_cache_dtype) ) else: cell_size = ( @@ -319,7 +319,7 @@ class ModelRunner: * self.model_config.head_dim * self.model_config.num_hidden_layers * 2 - * torch._utils._element_size(self.dtype) + * torch._utils._element_size(self.kv_cache_dtype) ) rest_memory = available_gpu_memory - total_gpu_memory * ( 1 - self.mem_fraction_static @@ -333,6 +333,21 @@ class ModelRunner: max_num_reqs: int = None, max_total_tokens: int = None, ): + if self.server_args.kv_cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + elif self.server_args.kv_cache_dtype == "fp8_e5m2": + if self.server_args.disable_flashinfer or self.server_args.enable_mla: + logger.warning( + "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype" + ) + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = torch.float8_e5m2 + else: + raise ValueError( + f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}." + ) + self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory) if max_total_tokens is not None: if max_total_tokens > self.max_total_num_tokens: @@ -369,7 +384,7 @@ class ModelRunner: ): self.token_to_kv_pool = MLATokenToKVPool( self.max_total_num_tokens, - dtype=self.dtype, + dtype=self.kv_cache_dtype, kv_lora_rank=self.model_config.kv_lora_rank, qk_rope_head_dim=self.model_config.qk_rope_head_dim, layer_num=self.model_config.num_hidden_layers, @@ -380,7 +395,7 @@ class ModelRunner: else: self.token_to_kv_pool = MHATokenToKVPool( self.max_total_num_tokens, - dtype=self.dtype, + dtype=self.kv_cache_dtype, head_num=self.model_config.get_num_kv_heads(self.tp_size), head_dim=self.model_config.head_dim, layer_num=self.model_config.num_hidden_layers, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ca27f9748..8a56c02e1 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -33,6 +33,7 @@ class ServerArgs: skip_tokenizer_init: bool = False load_format: str = "auto" dtype: str = "auto" + kv_cache_dtype: str = "auto" trust_remote_code: bool = True context_length: Optional[int] = None quantization: Optional[str] = None @@ -196,6 +197,13 @@ class ServerArgs: '* "float" is shorthand for FP32 precision.\n' '* "float32" for FP32 precision.', ) + parser.add_argument( + "--kv-cache-dtype", + type=str, + default=ServerArgs.kv_cache_dtype, + choices=["auto", "fp8_e5m2"], + help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.', + ) parser.add_argument( "--trust-remote-code", action="store_true",