Input_embeds support (#2052)

This commit is contained in:
Rin Intachuen
2024-11-25 19:35:04 -05:00
committed by GitHub
parent 1f76fc6e3f
commit 1aea19f64b
9 changed files with 204 additions and 15 deletions

View File

@@ -178,6 +178,7 @@ class Req:
origin_input_ids: Tuple[int],
sampling_params: SamplingParams,
lora_path: Optional[str] = None,
input_embeds: Optional[List[List[float]]] = None,
session_id: Optional[str] = None,
):
# Input and output info
@@ -191,6 +192,7 @@ class Req:
self.sampling_params = sampling_params
self.lora_path = lora_path
self.input_embeds = input_embeds
# Memory pool info
self.req_pool_idx = None
@@ -448,6 +450,7 @@ class ScheduleBatch:
# Batched arguments to model runner
input_ids: torch.Tensor = None
input_embeds: torch.Tensor = None
req_pool_indices: torch.Tensor = None
seq_lens: torch.Tensor = None
# The output locations of the KV cache
@@ -631,6 +634,9 @@ class ScheduleBatch:
req_pool_indices = self.alloc_req_slots(bs)
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
input_embeds = []
pt = 0
for i, req in enumerate(reqs):
already_computed = (
req.extend_logprob_start_len + 1 + req.cached_tokens
@@ -649,6 +655,11 @@ class ScheduleBatch:
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
)
# If input_embeds are available, store them
if req.input_embeds is not None:
# If req.input_embeds is already a list, append its content directly
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
# Compute the relative logprob_start_len in an extend batch
if req.logprob_start_len >= pre_len:
extend_logprob_start_len = min(
@@ -671,6 +682,12 @@ class ScheduleBatch:
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.input_embeds = (
torch.tensor(input_embeds).to(self.device, non_blocking=True)
if input_embeds
else None
)
self.out_cache_loc = out_cache_loc
self.seq_lens_sum = sum(seq_lens)
@@ -1053,6 +1070,7 @@ class ScheduleBatch:
encoder_out_cache_loc=self.encoder_out_cache_loc,
lora_paths=[req.lora_path for req in self.reqs],
sampling_info=self.sampling_info,
input_embeds=self.input_embeds,
)
def copy(self):
@@ -1123,6 +1141,9 @@ class ModelWorkerBatch:
# Sampling info
sampling_info: SamplingBatchInfo
# The input Embeds
input_embeds: Optional[torch.tensor] = None
@triton.jit
def write_req_to_token_pool_triton(