Fix DP load for embedding (#9165)
This commit is contained in:
@@ -612,6 +612,8 @@ class EmbeddingReqInput:
 
         if self.sampling_params is None:
             self.sampling_params = [{}] * self.batch_size
+        elif isinstance(self.sampling_params, dict):
+            self.sampling_params = [self.sampling_params] * self.batch_size
         for i in range(self.batch_size):
             self.sampling_params[i]["max_new_tokens"] = 0
 
@@ -660,6 +662,8 @@ class TokenizedEmbeddingReqInput:
     token_type_ids: List[int]
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams
+    # For data parallel rank routing
+    data_parallel_rank: Optional[int] = None
     # For dp balance
     dp_balance_id: int = -1
 
@@ -54,7 +54,7 @@ class SessionReqNode:
             prefix += " -- " + self.childs[0].req.rid
             ret = self.childs[0]._str_helper(prefix)
             for child in self.childs[1:]:
-                prefix = " " * len(origin_prefix) + " \- " + child.req.rid
+                prefix = " " * len(origin_prefix) + " \\- " + child.req.rid
                 ret += child._str_helper(prefix)
             return ret
 
Reference in New Issue
Block a user