Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any

import numpy as np
import torch
@@ -60,20 +59,13 @@ class InputBatch:
    query_start_loc_np: np.ndarray
    # [num_reqs]
    seq_lens: torch.Tensor
    # [num_reqs]
    dcp_local_seq_lens: torch.Tensor | None

    # [num_tokens_after_padding]
    input_ids: torch.Tensor
    # [num_tokens_after_padding]
    positions: torch.Tensor
    # [3, num_tokens_after_padding]
    mrope_positions: torch.Tensor | None
    # [num_tokens_after_padding, hidden_size]
    inputs_embeds: torch.Tensor | None

    # layer_name -> Metadata
    attn_metadata: dict[str, Any]
    # layer_name -> slot_mapping
    slot_mappings: dict[str, torch.Tensor]

    # [total_num_logits]
    logits_indices: torch.Tensor
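Not part of the diff: a minimal sketch of how the annotated shapes fit together, using a made-up two-request batch.

import torch

# Two requests scheduled for 3 and 5 tokens this step (illustrative numbers only).
num_scheduled = torch.tensor([3, 5], dtype=torch.int32)
# query_start_loc is the exclusive prefix sum over per-request token counts: [0, 3, 8].
query_start_loc = torch.zeros(3, dtype=torch.int32)
query_start_loc[1:] = torch.cumsum(num_scheduled, dim=0)
# The flattened input_ids / positions then carry 8 real tokens, padded out to
# num_tokens_after_padding, while seq_lens holds one total length per request.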
@@ -90,14 +82,16 @@ class InputBatch:
        num_reqs: int,
        num_tokens: int,
        input_buffers: InputBuffers,
        device: torch.device,
    ) -> "InputBatch":
        assert 0 < num_reqs <= num_tokens
        device = input_buffers.device

        req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
        idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
        idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
        expanded_idx_mapping = idx_mapping
        expanded_local_pos = torch.zeros(num_reqs, dtype=torch.int32, device=device)

        num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
        num_scheduled_tokens[-1] += num_tokens % num_reqs
        assert int(num_scheduled_tokens.sum()) == num_tokens
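A quick check of the split above (illustrative, not part of the diff): each request gets num_tokens // num_reqs tokens and the remainder lands on the last request, so the counts always sum back to num_tokens.

import numpy as np

num_reqs, num_tokens = 3, 10  # made-up sizes
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
num_scheduled_tokens[-1] += num_tokens % num_reqs
print(num_scheduled_tokens)  # [3 3 4]
assert int(num_scheduled_tokens.sum()) == num_tokens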
@@ -123,7 +117,6 @@ class InputBatch:
        input_ids = input_buffers.input_ids[:num_tokens].zero_()
        positions = input_buffers.positions[:num_tokens].zero_()

        # attn_metadata = defaultdict(lambda: None)
        logits_indices = query_start_loc[1:] - 1
        cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
        cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
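Illustrative sketch (not part of the diff): logits_indices picks the flat index of each request's last scheduled token, and cu_num_logits encodes exactly one logit per request.

import torch

query_start_loc = torch.tensor([0, 3, 8], dtype=torch.int32)  # made-up two-request batch
logits_indices = query_start_loc[1:] - 1                      # tensor([2, 7]): last token of each request
cu_num_logits = torch.arange(2 + 1, dtype=torch.int32)        # tensor([0, 1, 2]): one logit per request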
@@ -141,12 +134,9 @@ class InputBatch:
            query_start_loc=query_start_loc,
            query_start_loc_np=query_start_loc_np,
            seq_lens=seq_lens,
            dcp_local_seq_lens=None,
            input_ids=input_ids,
            positions=positions,
            mrope_positions=None,
            inputs_embeds=None,
            attn_metadata=None,  # type: ignore
            slot_mappings=None,  # type: ignore
            logits_indices=logits_indices,
            cu_num_logits=cu_num_logits,
            cu_num_logits_np=cu_num_logits_np,
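One point worth noting (illustrative, not part of the diff): the input_ids and positions passed above come from input_buffers.input_ids[:num_tokens].zero_() and its sibling, and slicing plus .zero_() returns a view, so the batch tensors alias the persistent buffers rather than copying them.

import torch

buf = torch.empty(16, dtype=torch.int64)  # stand-in for a persistent input buffer
view = buf[:8].zero_()                    # slice + in-place zero returns a view, not a copy
view[0] = 42
assert buf[0] == 42                       # same storage: writes to the batch hit the buffer
assert view.data_ptr() == buf.data_ptr()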
@@ -507,6 +497,38 @@ def post_update(
    )


@triton.jit
def _post_update_pool_kernel(
    idx_mapping_ptr,
    num_computed_tokens_ptr,
    query_start_loc_ptr,
):
    batch_id = tl.program_id(0)
    query_start = tl.load(query_start_loc_ptr + batch_id)
    query_end = tl.load(query_start_loc_ptr + batch_id + 1)
    query_len = query_end - query_start

    req_state_idx = tl.load(idx_mapping_ptr + batch_id)
    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
    tl.store(num_computed_tokens_ptr + req_state_idx, num_computed + query_len)


def post_update_pool(
    # [num_reqs]
    idx_mapping: torch.Tensor,
    # [max_num_reqs]
    num_computed_tokens: torch.Tensor,
    # [num_reqs + 1]
    query_start_loc: torch.Tensor,
) -> None:
    num_reqs = idx_mapping.shape[0]
    _post_update_pool_kernel[(num_reqs,)](
        idx_mapping,
        num_computed_tokens,
        query_start_loc,
    )
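Not part of the diff: a plain-PyTorch reference sketch of what the launch above does. Each Triton program handles one request, reading its query length from query_start_loc and adding it to num_computed_tokens at the slot given by idx_mapping.

import torch

def post_update_pool_reference(
    idx_mapping: torch.Tensor,          # [num_reqs]
    num_computed_tokens: torch.Tensor,  # [max_num_reqs], updated in place
    query_start_loc: torch.Tensor,      # [num_reqs + 1]
) -> None:
    # Per-request query lengths from the prefix sums.
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    # Add each request's length into its persistent slot (slots are distinct per request).
    num_computed_tokens.index_add_(
        0, idx_mapping.long(), query_lens.to(num_computed_tokens.dtype)
    )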


@triton.jit
def _expand_idx_mapping_kernel(
    idx_mapping_ptr,
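The last kernel is truncated in this view. Purely as a hedged guess from the expanded_idx_mapping / expanded_local_pos names used in the constructor above, a per-token expansion in plain PyTorch might look like the following; the real kernel body is not shown here.

import torch

def expand_idx_mapping_reference(
    idx_mapping: torch.Tensor,      # [num_reqs]
    query_start_loc: torch.Tensor,  # [num_reqs + 1]
) -> tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical reference only; the Triton kernel above is cut off in this diff.
    query_lens = (query_start_loc[1:] - query_start_loc[:-1]).long()
    # Request-state index repeated once per scheduled token...
    expanded_idx_mapping = torch.repeat_interleave(idx_mapping, query_lens)
    # ...and each token's offset within its own request.
    total = int(query_lens.sum())
    expanded_local_pos = (
        torch.arange(total, device=idx_mapping.device)
        - torch.repeat_interleave(query_start_loc[:-1].long(), query_lens)
    )
    return expanded_idx_mapping, expanded_local_pos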