Input_embeds support (#2052)
This commit is contained in:
@@ -178,6 +178,7 @@ class Req:
|
||||
origin_input_ids: Tuple[int],
|
||||
sampling_params: SamplingParams,
|
||||
lora_path: Optional[str] = None,
|
||||
input_embeds: Optional[List[List[float]]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
):
|
||||
# Input and output info
|
||||
@@ -191,6 +192,7 @@ class Req:
|
||||
|
||||
self.sampling_params = sampling_params
|
||||
self.lora_path = lora_path
|
||||
self.input_embeds = input_embeds
|
||||
|
||||
# Memory pool info
|
||||
self.req_pool_idx = None
|
||||
@@ -448,6 +450,7 @@ class ScheduleBatch:
|
||||
|
||||
# Batched arguments to model runner
|
||||
input_ids: torch.Tensor = None
|
||||
input_embeds: torch.Tensor = None
|
||||
req_pool_indices: torch.Tensor = None
|
||||
seq_lens: torch.Tensor = None
|
||||
# The output locations of the KV cache
|
||||
@@ -631,6 +634,9 @@ class ScheduleBatch:
|
||||
req_pool_indices = self.alloc_req_slots(bs)
|
||||
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
||||
|
||||
input_embeds = []
|
||||
|
||||
pt = 0
|
||||
for i, req in enumerate(reqs):
|
||||
already_computed = (
|
||||
req.extend_logprob_start_len + 1 + req.cached_tokens
|
||||
@@ -649,6 +655,11 @@ class ScheduleBatch:
|
||||
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
|
||||
)
|
||||
|
||||
# If input_embeds are available, store them
|
||||
if req.input_embeds is not None:
|
||||
# req.input_embeds is a list of embedding vectors (List[List[float]]); add each vector to the batch-level list
|
||||
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
|
||||
|
||||
# Compute the relative logprob_start_len in an extend batch
|
||||
if req.logprob_start_len >= pre_len:
|
||||
extend_logprob_start_len = min(
|
||||
@@ -671,6 +682,12 @@ class ScheduleBatch:
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.input_embeds = (
|
||||
torch.tensor(input_embeds).to(self.device, non_blocking=True)
|
||||
if input_embeds
|
||||
else None
|
||||
)
|
||||
|
||||
self.out_cache_loc = out_cache_loc
|
||||
|
||||
self.seq_lens_sum = sum(seq_lens)
|
||||
@@ -1053,6 +1070,7 @@ class ScheduleBatch:
|
||||
encoder_out_cache_loc=self.encoder_out_cache_loc,
|
||||
lora_paths=[req.lora_path for req in self.reqs],
|
||||
sampling_info=self.sampling_info,
|
||||
input_embeds=self.input_embeds,
|
||||
)
|
||||
|
||||
def copy(self):
|
||||
@@ -1123,6 +1141,9 @@ class ModelWorkerBatch:
|
||||
# Sampling info
|
||||
sampling_info: SamplingBatchInfo
|
||||
|
||||
# The input embeddings
|
||||
input_embeds: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@triton.jit
|
||||
def write_req_to_token_pool_triton(
|
||||
|
||||
Reference in New Issue
Block a user