[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
175
vllm/distributed/kv_transfer/kv_lookup_buffer/base.py
Normal file
175
vllm/distributed/kv_transfer/kv_lookup_buffer/base.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file contains a new class `KVLookupBufferBase` that allows developers to
|
||||
think of KV cache operations as inserting new KV cache entries (`insert`)
|
||||
into the lookup buffer and querying existing KV caches (`drop_select`)
|
||||
from the lookup buffer.
|
||||
|
||||
This file also contains a new class `KVStoreBufferBase` that allows developers
|
||||
to manage the KVCache buffer as a simple key-value storage buffer with basic
|
||||
put/get operations.
|
||||
|
||||
These classes above are abstracted behind class `KVCacheBufferBase`.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class KVCacheBufferBase(ABC):
|
||||
"""
|
||||
Abstract base class for a KVCache buffer.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""Close the buffer and release resources.
|
||||
|
||||
This method is responsible for cleaning up resources related to the
|
||||
KVCache buffer when it is no longer needed.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class KVLookupBufferBase(KVCacheBufferBase):
|
||||
"""
|
||||
Abstract base class for a KVCache lookup buffer.
|
||||
|
||||
This class provides an abstraction for a key-value (KV) cache lookup buffer.
|
||||
|
||||
The key of the lookup buffer:
|
||||
- input_tokens: token IDs of the request
|
||||
- roi: a binary mask on top of input_tokens.
|
||||
- Purpose of roi: Since KV cache may only be available for a subset of
|
||||
tokens in the input (for example, when vLLM is connected to an external
|
||||
KV cache service), roi specifies the subset of tokens that the KV cache
|
||||
is associated with.
|
||||
- NOTE: roi can be further extended to describe which part of KV the
|
||||
current process is holding (each process may only hold a part of KV
|
||||
due to TP and PP). This is not implemented for now.
|
||||
|
||||
The value of the lookup buffer:
|
||||
- key: the key tensor in the KV cache
|
||||
- value: the value tensor in the KV cache
|
||||
- hidden: the final hidden state generated by model forwarding. This allows
|
||||
vLLM to bypass further model forwarding by transmitting the hidden state.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
hidden: torch.Tensor) -> None:
|
||||
"""Insert into the lookup buffer.
|
||||
|
||||
The functionality is similar to the following python statement
|
||||
```
|
||||
buffer[input_tokens, roi] = [key, value, hidden]
|
||||
```
|
||||
|
||||
FIXME: in the future, we should only have two arguments, key and value,
|
||||
where key is a tensor dict and value is a tensor dict.
|
||||
|
||||
FIXME: we should transmit both sampler outputs and the hidden states.
|
||||
|
||||
Args:
|
||||
input_tokens (torch.Tensor): token IDs.
|
||||
roi (torch.Tensor): A binary mask on top of the input tokens
|
||||
key (torch.Tensor): The key tensor in the KV cache.
|
||||
value (torch.Tensor): The value tensor in the KV cache.
|
||||
hidden (torch.Tensor): The final hidden state tensor generated
|
||||
during model forwarding to bypass model
|
||||
forwarding.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def drop_select(
|
||||
self, input_tokens: Optional[torch.Tensor],
|
||||
roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
|
||||
"""Select and *drop* KV cache entries from the lookup buffer.
|
||||
|
||||
The functionality is similar to the following python statements
|
||||
```
|
||||
ret = buffer.pop(input_tokens, roi)
|
||||
return ret
|
||||
```
|
||||
|
||||
If `input_tokens` and `roi` is `None`, it means selecting any of the
|
||||
KV caches in the buffer, return, and remove it from the buffer, useful
|
||||
when offloading KV cache to KV cache storage service.
|
||||
|
||||
Args:
|
||||
input_tokens (torch.Tensor): token IDs.
|
||||
roi (torch.Tensor): A binary mask on top of the input tokens
|
||||
|
||||
Returns:
|
||||
list[Optional[torch.Tensor]]: A list of tensors. Can be None.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class KVStoreBufferBase(KVCacheBufferBase):
|
||||
"""
|
||||
Abstract base class for a KVCache storage buffer with key-value semantics.
|
||||
This class provides a simple key-value storage buffer abstract with basic
|
||||
put/get operations, which enables flexible KVCache transfer granular
|
||||
control.
|
||||
|
||||
The functionality is similar to a distributed key-value store, where:
|
||||
- Key: A unique string identifier for the cached entry
|
||||
- Value:
|
||||
- Tensor to be stored and retrieved
|
||||
- None (indicating deletion or empty value)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
value: Optional[torch.Tensor],
|
||||
) -> None:
|
||||
"""Store a key-value pair in the buffer.
|
||||
|
||||
Args:
|
||||
key (str): Unique identifier for a tensor, this tensor could be the
|
||||
key cache tensor, value cache tensor, or hidden state tensor
|
||||
generated during model forwarding.
|
||||
|
||||
value (Optional[torch.Tensor]): Tensor to be stored.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""Retrieve a value from the buffer by key.
|
||||
|
||||
Args:
|
||||
key (str): Unique identifier for a tensor, this tensor could be the
|
||||
key cache tensor, value cache tensor, or hidden state tensor
|
||||
generated during model forwarding.
|
||||
|
||||
Returns:
|
||||
Optional[torch.Tensor]: Stored tensor if exists, None otherwise.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
Reference in New Issue
Block a user