Fix DP load for embedding (#9165)
This commit is contained in:
@@ -612,6 +612,8 @@ class EmbeddingReqInput:
|
||||
|
||||
if self.sampling_params is None:
|
||||
self.sampling_params = [{}] * self.batch_size
|
||||
elif isinstance(self.sampling_params, dict):
|
||||
self.sampling_params = [self.sampling_params] * self.batch_size
|
||||
for i in range(self.batch_size):
|
||||
self.sampling_params[i]["max_new_tokens"] = 0
|
||||
|
||||
@@ -660,6 +662,8 @@ class TokenizedEmbeddingReqInput:
|
||||
token_type_ids: List[int]
|
||||
# Dummy sampling params for compatibility
|
||||
sampling_params: SamplingParams
|
||||
# For data parallel rank routing
|
||||
data_parallel_rank: Optional[int] = None
|
||||
# For dp balance
|
||||
dp_balance_id: int = -1
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ class SessionReqNode:
|
||||
prefix += " -- " + self.childs[0].req.rid
|
||||
ret = self.childs[0]._str_helper(prefix)
|
||||
for child in self.childs[1:]:
|
||||
prefix = " " * len(origin_prefix) + " \- " + child.req.rid
|
||||
prefix = " " * len(origin_prefix) + " \\- " + child.req.rid
|
||||
ret += child._str_helper(prefix)
|
||||
return ret
|
||||
|
||||
|
||||
Reference in New Issue
Block a user