Upgrade to vLLM 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -53,6 +53,13 @@ class CachedRequestState:
|
||||
pooling_params: PoolingParams | None = None
|
||||
pooling_states: PoolingStates | None = None
|
||||
|
||||
# for multi layer eagle proposer
|
||||
cached_len: torch.Tensor | None = None
|
||||
cached_token_ids: torch.Tensor | None = None
|
||||
cached_hidden_states: torch.Tensor | None = None
|
||||
cached_slot_mappings: torch.Tensor | None = None
|
||||
cached_positions: torch.Tensor | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
self.prompt_token_ids, self.prompt_embeds
|
||||
@@ -95,6 +102,8 @@ class InputBatch:
|
||||
is_spec_decode: bool = False,
|
||||
is_pooling_model: bool = False,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
multi_layer_eagle_num: int = 0,
|
||||
hidden_size: int | None = None,
|
||||
):
|
||||
self.is_pooling_model = is_pooling_model
|
||||
self.is_spec_decode = is_spec_decode
|
||||
@@ -211,6 +220,46 @@ class InputBatch:
|
||||
)
|
||||
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
|
||||
|
||||
# Multi layer eagle
|
||||
self.multi_layer_eagle_num = multi_layer_eagle_num
|
||||
if multi_layer_eagle_num > 0:
|
||||
self.cached_len = torch.zeros(
|
||||
(max_num_reqs,), dtype=torch.int64, device=device
|
||||
)
|
||||
self.cached_token_ids = torch.zeros(
|
||||
(
|
||||
max_num_reqs,
|
||||
multi_layer_eagle_num,
|
||||
),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.cached_hidden_states = torch.zeros(
|
||||
(
|
||||
max_num_reqs,
|
||||
multi_layer_eagle_num,
|
||||
hidden_size,
|
||||
),
|
||||
dtype=torch.float,
|
||||
device=device,
|
||||
)
|
||||
self.cached_slot_mappings = torch.zeros(
|
||||
(
|
||||
max_num_reqs,
|
||||
multi_layer_eagle_num,
|
||||
),
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
self.cached_positions = torch.zeros(
|
||||
(
|
||||
max_num_reqs,
|
||||
multi_layer_eagle_num,
|
||||
),
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# lora related
|
||||
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
|
||||
self.lora_id_to_request_ids: dict[int, set[str]] = {}
|
||||
@@ -425,6 +474,13 @@ class InputBatch:
|
||||
# Speculative decoding: by default 1 token is generated.
|
||||
self.num_accepted_tokens_cpu[req_index] = 1
|
||||
|
||||
if self.multi_layer_eagle_num > 0:
|
||||
self.cached_len[req_index] = request.cached_len
|
||||
self.cached_token_ids[req_index] = request.cached_token_ids
|
||||
self.cached_hidden_states[req_index] = request.cached_hidden_states
|
||||
self.cached_slot_mappings[req_index] = request.cached_slot_mappings
|
||||
self.cached_positions[req_index] = request.cached_positions
|
||||
|
||||
# Add request lora ID
|
||||
if request.lora_request:
|
||||
lora_id = request.lora_request.lora_int_id
|
||||
@@ -623,6 +679,24 @@ class InputBatch:
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1],
|
||||
)
|
||||
|
||||
if self.multi_layer_eagle_num > 0:
|
||||
self.cached_len[i1], self.cached_len[i2] = (
|
||||
self.cached_len[i2],
|
||||
self.cached_len[i1],
|
||||
)
|
||||
self.cached_token_ids[[i1, i2], ...] = self.cached_token_ids[
|
||||
[i2, i1], ...
|
||||
]
|
||||
self.cached_hidden_states[[i1, i2], ...] = self.cached_hidden_states[
|
||||
[i2, i1], ...
|
||||
]
|
||||
self.cached_slot_mappings[[i1, i2], ...] = self.cached_slot_mappings[
|
||||
[i2, i1], ...
|
||||
]
|
||||
self.cached_positions[[i1, i2], ...] = self.cached_positions[
|
||||
[i2, i1], ...
|
||||
]
|
||||
|
||||
def condense(self) -> None:
|
||||
"""Slide non-empty requests down into lower, empty indices.
|
||||
|
||||
@@ -745,6 +819,21 @@ class InputBatch:
|
||||
if bad_words_token_ids is not None:
|
||||
self.bad_words_token_ids[empty_index] = bad_words_token_ids
|
||||
|
||||
if self.multi_layer_eagle_num > 0:
|
||||
self.cached_len[empty_index] = self.cached_len[last_req_index]
|
||||
self.cached_token_ids[empty_index] = self.cached_token_ids[
|
||||
last_req_index
|
||||
]
|
||||
self.cached_hidden_states[empty_index] = self.cached_hidden_states[
|
||||
last_req_index
|
||||
]
|
||||
self.cached_slot_mappings[empty_index] = self.cached_slot_mappings[
|
||||
last_req_index
|
||||
]
|
||||
self.cached_positions[empty_index] = self.cached_positions[
|
||||
last_req_index
|
||||
]
|
||||
|
||||
# Decrement last_req_index since it is now empty.
|
||||
last_req_index -= 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user