bugfix(MC2): refactor the comm group of MC2 to be compatible with PP (#7291)
### What this PR does / why we need it?
This PR refactors the communication group of MC2 to keep it consistent
with vllm's EP group, making it compatible with PP.
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -453,7 +453,6 @@ class RemoteEPDServer(RemoteOpenAIServer):
|
||||
self.env_dict.update(env_dict)
|
||||
|
||||
self.env_dict["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
|
||||
self.env_dict["VLLM_USE_V1"] = "1"
|
||||
self.env_dict["PYTORCH_NPU_ALLOC_CONF"] = "expandable_segments:True"
|
||||
self.env_dict["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
@@ -626,6 +625,126 @@ class DisaggEpdProxy(RemoteEPDServer):
|
||||
super()._terminate_server()
|
||||
|
||||
|
||||
_DP_RUNNER_START_TIMEOUT_SECONDS = 900.0
|
||||
_DP_RUNNER_REQUEST_TIMEOUT_SECONDS = 900.0
|
||||
_DP_RUNNER_SHUTDOWN_TIMEOUT_SECONDS = 30.0
|
||||
|
||||
|
||||
def _split_data_parallel_indices(num_items: int, dp_size: int) -> list[list[int]]:
|
||||
if num_items < 0:
|
||||
raise ValueError("num_items must be non-negative")
|
||||
if dp_size <= 0:
|
||||
raise ValueError("dp_size must be positive")
|
||||
|
||||
floor = num_items // dp_size
|
||||
remainder = num_items % dp_size
|
||||
|
||||
def start(rank: int) -> int:
|
||||
return rank * floor + min(rank, remainder)
|
||||
|
||||
return [list(range(start(rank), start(rank + 1))) for rank in range(dp_size)]
|
||||
|
||||
|
||||
def _slice_optional_inputs(inputs: PromptImageInput | PromptAudioInput | PromptVideoInput | None, indices: list[int]):
|
||||
if inputs is None:
|
||||
return None
|
||||
return [inputs[index] for index in indices]
|
||||
|
||||
|
||||
def _slice_list_inputs(items: list[Any], indices: list[int]) -> list[Any]:
|
||||
return [items[index] for index in indices]
|
||||
|
||||
|
||||
def _merge_data_parallel_results(total_items: int, shard_results: list[tuple[list[int], list[Any]]]) -> list[Any]:
|
||||
merged: list[Any] = [None] * total_items
|
||||
for indices, results in shard_results:
|
||||
if not indices:
|
||||
continue
|
||||
if len(indices) != len(results):
|
||||
raise RuntimeError("Mismatched result count returned by data parallel worker")
|
||||
for index, result in zip(indices, results):
|
||||
merged[index] = result
|
||||
|
||||
if any(result is None for result in merged):
|
||||
raise RuntimeError("Some data parallel results were not returned")
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def _normalize_score_inputs(text_1: str | list[str], text_2: str | list[str]) -> tuple[list[str], list[str]]:
|
||||
if isinstance(text_1, str) and isinstance(text_2, str):
|
||||
return [text_1], [text_2]
|
||||
if isinstance(text_1, str):
|
||||
return [text_1] * len(text_2), list(text_2)
|
||||
if isinstance(text_2, str):
|
||||
return list(text_1), [text_2] * len(text_1)
|
||||
if len(text_1) != len(text_2):
|
||||
raise ValueError("`text_1` and `text_2` must have the same length")
|
||||
return list(text_1), list(text_2)
|
||||
|
||||
|
||||
def _run_vllm_runner_dp_worker(conn, llm_kwargs: dict[str, Any], dp_rank: int, dp_size: int, master_port: int) -> None:
|
||||
llm = None
|
||||
try:
|
||||
os.environ["VLLM_DP_RANK"] = str(dp_rank)
|
||||
os.environ["VLLM_DP_RANK_LOCAL"] = str(dp_rank)
|
||||
os.environ["VLLM_DP_SIZE"] = str(dp_size)
|
||||
os.environ["VLLM_DP_MASTER_IP"] = "127.0.0.1"
|
||||
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
llm = LLM(**llm_kwargs)
|
||||
conn.send({"status": "ready", "rank": dp_rank})
|
||||
|
||||
while True:
|
||||
request = conn.recv()
|
||||
command = request["command"]
|
||||
if command == "shutdown":
|
||||
break
|
||||
|
||||
result: Any
|
||||
if command == "generate":
|
||||
req_outputs = llm.generate(
|
||||
request["inputs"], sampling_params=request["sampling_params"], **request["kwargs"]
|
||||
)
|
||||
result = VllmRunner._finalize_generate_outputs(req_outputs)
|
||||
elif command == "generate_w_logprobs":
|
||||
req_outputs = llm.generate(
|
||||
request["inputs"], sampling_params=request["sampling_params"], **request["kwargs"]
|
||||
)
|
||||
result = VllmRunner._final_steps_generate_w_logprobs(req_outputs)
|
||||
elif command == "classify":
|
||||
req_outputs = llm.classify(request["prompts"])
|
||||
result = [req_output.outputs.probs for req_output in req_outputs]
|
||||
elif command == "embed":
|
||||
req_outputs = llm.embed(request["inputs"], *request["args"], **request["kwargs"])
|
||||
result = [req_output.outputs.embedding for req_output in req_outputs]
|
||||
elif command == "encode":
|
||||
req_outputs = llm.encode(request["prompts"])
|
||||
result = [req_output.outputs.data for req_output in req_outputs]
|
||||
elif command == "reward":
|
||||
req_outputs = llm.reward(request["prompts"])
|
||||
result = [req_output.outputs.data for req_output in req_outputs]
|
||||
elif command == "score":
|
||||
req_outputs = llm.score(request["text_1"], request["text_2"], *request["args"], **request["kwargs"])
|
||||
result = [req_output.outputs.score for req_output in req_outputs]
|
||||
else:
|
||||
raise ValueError(f"Unsupported data parallel command: {command}")
|
||||
|
||||
conn.send({"status": "ok", "rank": dp_rank, "indices": request["indices"], "result": result})
|
||||
except Exception:
|
||||
with contextlib.suppress(Exception):
|
||||
conn.send({"status": "error", "rank": dp_rank, "traceback": traceback.format_exc()})
|
||||
raise
|
||||
finally:
|
||||
if llm is not None:
|
||||
del llm
|
||||
clear_ascend_config()
|
||||
cleanup_dist_env_and_memory()
|
||||
with contextlib.suppress(Exception):
|
||||
conn.close()
|
||||
|
||||
|
||||
class VllmRunner:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -645,6 +764,10 @@ class VllmRunner:
|
||||
quantization: str | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
data_parallel_size = int(kwargs.get("data_parallel_size", 1))
|
||||
if data_parallel_size > 1:
|
||||
raise ValueError("VllmRunner does not support `data_parallel_size > 1`; use `DPVllmRunner` instead.")
|
||||
|
||||
self.model = LLM(
|
||||
model=model_name,
|
||||
runner=runner,
|
||||
@@ -664,6 +787,22 @@ class VllmRunner:
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _finalize_generate_outputs(req_outputs: list[RequestOutput]) -> list[tuple[list[list[int]], list[str]]]:
|
||||
outputs: list[tuple[list[list[int]], list[str]]] = []
|
||||
for req_output in req_outputs:
|
||||
prompt_str = req_output.prompt
|
||||
prompt_ids = req_output.prompt_token_ids
|
||||
req_sample_output_ids: list[list[int]] = []
|
||||
req_sample_output_strs: list[str] = []
|
||||
for sample in req_output.outputs:
|
||||
output_str = sample.text
|
||||
output_ids = list(sample.token_ids)
|
||||
req_sample_output_ids.append(prompt_ids + output_ids)
|
||||
req_sample_output_strs.append((prompt_str or "") + output_str)
|
||||
outputs.append((req_sample_output_ids, req_sample_output_strs))
|
||||
return outputs
|
||||
|
||||
def get_inputs(
|
||||
self,
|
||||
prompts: list[str] | list[torch.Tensor] | list[int],
|
||||
@@ -698,7 +837,7 @@ class VllmRunner:
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: list[str] | list[torch.Tensor],
|
||||
prompts: list[str] | list[torch.Tensor] | list[list[int]],
|
||||
sampling_params: SamplingParams,
|
||||
images: PromptImageInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
@@ -706,22 +845,8 @@ class VllmRunner:
|
||||
**kwargs: Any,
|
||||
) -> list[tuple[list[list[int]], list[str]]]:
|
||||
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
||||
|
||||
req_outputs = self.model.generate(inputs, sampling_params=sampling_params, **kwargs)
|
||||
|
||||
outputs: list[tuple[list[list[int]], list[str]]] = []
|
||||
for req_output in req_outputs:
|
||||
prompt_str = req_output.prompt
|
||||
prompt_ids = req_output.prompt_token_ids
|
||||
req_sample_output_ids: list[list[int]] = []
|
||||
req_sample_output_strs: list[str] = []
|
||||
for sample in req_output.outputs:
|
||||
output_str = sample.text
|
||||
output_ids = list(sample.token_ids)
|
||||
req_sample_output_ids.append(prompt_ids + output_ids)
|
||||
req_sample_output_strs.append((prompt_str or "") + output_str)
|
||||
outputs.append((req_sample_output_ids, req_sample_output_strs))
|
||||
return outputs
|
||||
return self._finalize_generate_outputs(req_outputs)
|
||||
|
||||
@staticmethod
|
||||
def _final_steps_generate_w_logprobs(
|
||||
@@ -760,7 +885,7 @@ class VllmRunner:
|
||||
|
||||
def generate_greedy(
|
||||
self,
|
||||
prompts: list[str] | list[torch.Tensor],
|
||||
prompts: list[str] | list[torch.Tensor] | list[list[int]],
|
||||
max_tokens: int,
|
||||
images: PromptImageInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
@@ -842,6 +967,319 @@ class VllmRunner:
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
class DPVllmRunner(VllmRunner):
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
runner: RunnerOption = "auto",
|
||||
convert: ConvertOption = "auto",
|
||||
tokenizer_name: str | None = None,
|
||||
tokenizer_mode: str = "auto",
|
||||
max_model_len: int | None = 1024,
|
||||
dtype: str = "auto",
|
||||
disable_log_stats: bool = True,
|
||||
tensor_parallel_size: int = 1,
|
||||
block_size: int = 16,
|
||||
enable_chunked_prefill: bool = True,
|
||||
swap_space: int = 4,
|
||||
enforce_eager: bool | None = False,
|
||||
quantization: str | None = None,
|
||||
data_parallel_size: int = 2,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if data_parallel_size < 2:
|
||||
raise ValueError("DPVllmRunner requires `data_parallel_size >= 2`")
|
||||
|
||||
self._dp_size = data_parallel_size
|
||||
self._dp_parent_conns: list[Any] = []
|
||||
self._dp_processes: list[Any] = []
|
||||
self._dp_start_timeout = float(kwargs.pop("dp_start_timeout", _DP_RUNNER_START_TIMEOUT_SECONDS))
|
||||
self._dp_request_timeout = float(kwargs.pop("dp_request_timeout", _DP_RUNNER_REQUEST_TIMEOUT_SECONDS))
|
||||
|
||||
llm_kwargs = dict(
|
||||
model=model_name,
|
||||
runner=runner,
|
||||
convert=convert,
|
||||
tokenizer=tokenizer_name,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
trust_remote_code=True,
|
||||
dtype=dtype,
|
||||
swap_space=swap_space,
|
||||
enforce_eager=enforce_eager,
|
||||
disable_log_stats=disable_log_stats,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
max_model_len=max_model_len,
|
||||
block_size=block_size,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
quantization=quantization,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
cleanup_dist_env_and_memory()
|
||||
self._start_data_parallel_workers(llm_kwargs)
|
||||
|
||||
@property
|
||||
def model(self) -> LLM:
|
||||
raise RuntimeError("Direct access to `runner.model` is not supported by `DPVllmRunner`.")
|
||||
|
||||
def _start_data_parallel_workers(self, llm_kwargs: dict[str, Any]) -> None:
|
||||
ctx = multiprocessing.get_context("spawn")
|
||||
master_port = get_open_port()
|
||||
|
||||
try:
|
||||
for dp_rank in range(self._dp_size):
|
||||
parent_conn, child_conn = ctx.Pipe()
|
||||
proc = ctx.Process(
|
||||
target=_run_vllm_runner_dp_worker,
|
||||
args=(child_conn, llm_kwargs, dp_rank, self._dp_size, master_port),
|
||||
)
|
||||
proc.start()
|
||||
child_conn.close()
|
||||
self._dp_parent_conns.append(parent_conn)
|
||||
self._dp_processes.append(proc)
|
||||
|
||||
for rank, conn in enumerate(self._dp_parent_conns):
|
||||
if not conn.poll(self._dp_start_timeout):
|
||||
raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to start")
|
||||
message = conn.recv()
|
||||
if message["status"] != "ready":
|
||||
raise RuntimeError(
|
||||
f"Failed to start data parallel worker {rank}:\n{message.get('traceback', 'unknown error')}"
|
||||
)
|
||||
except Exception:
|
||||
self._stop_data_parallel_workers()
|
||||
raise
|
||||
|
||||
def _stop_data_parallel_workers(self) -> None:
|
||||
for conn in self._dp_parent_conns:
|
||||
with contextlib.suppress(Exception):
|
||||
conn.send({"command": "shutdown"})
|
||||
|
||||
for proc in self._dp_processes:
|
||||
proc.join(timeout=_DP_RUNNER_SHUTDOWN_TIMEOUT_SECONDS)
|
||||
if proc.is_alive():
|
||||
proc.kill()
|
||||
proc.join(timeout=5)
|
||||
|
||||
for conn in self._dp_parent_conns:
|
||||
with contextlib.suppress(Exception):
|
||||
conn.close()
|
||||
|
||||
self._dp_parent_conns.clear()
|
||||
self._dp_processes.clear()
|
||||
|
||||
def _dispatch_prompt_command(
|
||||
self,
|
||||
command: str,
|
||||
prompts: list[str] | list[torch.Tensor] | list[list[int]],
|
||||
*,
|
||||
images: PromptImageInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
audios: PromptAudioInput | None = None,
|
||||
**payload: Any,
|
||||
) -> list[Any]:
|
||||
if not prompts:
|
||||
return []
|
||||
|
||||
shard_results: list[tuple[list[int], list[Any]]] = []
|
||||
shard_indices = _split_data_parallel_indices(len(prompts), self._dp_size)
|
||||
|
||||
for rank, conn in enumerate(self._dp_parent_conns):
|
||||
indices = shard_indices[rank]
|
||||
worker_indices = indices or [0]
|
||||
worker_prompts = _slice_list_inputs(prompts, worker_indices)
|
||||
conn.send(
|
||||
{
|
||||
"command": command,
|
||||
"indices": indices,
|
||||
"inputs": self.get_inputs(
|
||||
worker_prompts,
|
||||
images=_slice_optional_inputs(images, worker_indices),
|
||||
videos=_slice_optional_inputs(videos, worker_indices),
|
||||
audios=_slice_optional_inputs(audios, worker_indices),
|
||||
),
|
||||
"prompts": worker_prompts,
|
||||
**payload,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
for rank, conn in enumerate(self._dp_parent_conns):
|
||||
if not conn.poll(self._dp_request_timeout):
|
||||
raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to finish `{command}`")
|
||||
message = conn.recv()
|
||||
if message["status"] != "ok":
|
||||
raise RuntimeError(
|
||||
f"Data parallel worker {rank} failed during `{command}`:\n"
|
||||
f"{message.get('traceback', 'unknown error')}"
|
||||
)
|
||||
shard_results.append((message["indices"], message["result"]))
|
||||
except Exception:
|
||||
self._stop_data_parallel_workers()
|
||||
raise
|
||||
|
||||
return _merge_data_parallel_results(len(prompts), shard_results)
|
||||
|
||||
def _dispatch_text_command(self, command: str, prompts: list[str]) -> list[Any]:
|
||||
if not prompts:
|
||||
return []
|
||||
|
||||
shard_results: list[tuple[list[int], list[Any]]] = []
|
||||
shard_indices = _split_data_parallel_indices(len(prompts), self._dp_size)
|
||||
|
||||
for rank, conn in enumerate(self._dp_parent_conns):
|
||||
indices = shard_indices[rank]
|
||||
worker_indices = indices or [0]
|
||||
conn.send(
|
||||
{
|
||||
"command": command,
|
||||
"indices": indices,
|
||||
"prompts": _slice_list_inputs(prompts, worker_indices),
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
for rank, conn in enumerate(self._dp_parent_conns):
|
||||
if not conn.poll(self._dp_request_timeout):
|
||||
raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to finish `{command}`")
|
||||
message = conn.recv()
|
||||
if message["status"] != "ok":
|
||||
raise RuntimeError(
|
||||
f"Data parallel worker {rank} failed during `{command}`:\n"
|
||||
f"{message.get('traceback', 'unknown error')}"
|
||||
)
|
||||
shard_results.append((message["indices"], message["result"]))
|
||||
except Exception:
|
||||
self._stop_data_parallel_workers()
|
||||
raise
|
||||
|
||||
return _merge_data_parallel_results(len(prompts), shard_results)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: list[str] | list[torch.Tensor] | list[list[int]],
|
||||
sampling_params: SamplingParams,
|
||||
images: PromptImageInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
audios: PromptAudioInput | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[tuple[list[list[int]], list[str]]]:
|
||||
return self._dispatch_prompt_command(
|
||||
"generate",
|
||||
prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios,
|
||||
sampling_params=sampling_params,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
def generate_w_logprobs(
|
||||
self,
|
||||
prompts: list[str],
|
||||
sampling_params: SamplingParams,
|
||||
images: PromptImageInput | None = None,
|
||||
audios: PromptAudioInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
|
||||
toks_str_logsprobs_prompt_logprobs = self._dispatch_prompt_command(
|
||||
"generate_w_logprobs",
|
||||
prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios,
|
||||
sampling_params=sampling_params,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
return (
|
||||
[x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
|
||||
if sampling_params.prompt_logprobs is None
|
||||
else toks_str_logsprobs_prompt_logprobs
|
||||
)
|
||||
|
||||
def classify(self, prompts: list[str]) -> list[list[float]]:
|
||||
return self._dispatch_text_command("classify", prompts)
|
||||
|
||||
def embed(
|
||||
self,
|
||||
prompts: list[str],
|
||||
images: PromptImageInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
audios: PromptAudioInput | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> list[list[float]]:
|
||||
return self._dispatch_prompt_command(
|
||||
"embed",
|
||||
prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
def encode(self, prompts: list[str]) -> list[list[float]]:
|
||||
return self._dispatch_text_command("encode", prompts)
|
||||
|
||||
def reward(self, prompts: list[str]) -> list[list[float]]:
|
||||
return self._dispatch_text_command("reward", prompts)
|
||||
|
||||
def score(
|
||||
self,
|
||||
text_1: str | list[str],
|
||||
text_2: str | list[str],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> list[float]:
|
||||
normalized_text_1, normalized_text_2 = _normalize_score_inputs(text_1, text_2)
|
||||
if not normalized_text_1:
|
||||
return []
|
||||
|
||||
shard_results: list[tuple[list[int], list[Any]]] = []
|
||||
shard_indices = _split_data_parallel_indices(len(normalized_text_1), self._dp_size)
|
||||
|
||||
for rank, conn in enumerate(self._dp_parent_conns):
|
||||
indices = shard_indices[rank]
|
||||
worker_indices = indices or [0]
|
||||
conn.send(
|
||||
{
|
||||
"command": "score",
|
||||
"indices": indices,
|
||||
"text_1": _slice_list_inputs(normalized_text_1, worker_indices),
|
||||
"text_2": _slice_list_inputs(normalized_text_2, worker_indices),
|
||||
"args": args,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
for rank, conn in enumerate(self._dp_parent_conns):
|
||||
if not conn.poll(self._dp_request_timeout):
|
||||
raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to finish `score`")
|
||||
message = conn.recv()
|
||||
if message["status"] != "ok":
|
||||
raise RuntimeError(
|
||||
f"Data parallel worker {rank} failed during `score`:\n"
|
||||
f"{message.get('traceback', 'unknown error')}"
|
||||
)
|
||||
shard_results.append((message["indices"], message["result"]))
|
||||
except Exception:
|
||||
self._stop_data_parallel_workers()
|
||||
raise
|
||||
|
||||
return _merge_data_parallel_results(len(normalized_text_1), shard_results)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self._stop_data_parallel_workers()
|
||||
clear_ascend_config()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
DataParallelVllmRunner = DPVllmRunner
|
||||
|
||||
|
||||
class HfRunner:
|
||||
def get_default_device(self):
|
||||
return "cpu" if current_platform.is_cpu() else current_platform.device_type
|
||||
|
||||
Reference in New Issue
Block a user