bugfix(MC2): refactor the comm group of MC2 to be compatible with PP (#7291)
### What this PR does / why we need it?
This PR refactors the communication group of MC2 to keep it consistent
with vllm's EP group, making it compatible with PP.
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
4
.github/workflows/scripts/config.yaml
vendored
4
.github/workflows/scripts/config.yaml
vendored
@@ -127,8 +127,6 @@ e2e-multicard-2-cards:
|
|||||||
estimated_time: 180
|
estimated_time: 180
|
||||||
- name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_qwen3_w4a4_distributed_tp2
|
- name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_qwen3_w4a4_distributed_tp2
|
||||||
estimated_time: 202
|
estimated_time: 202
|
||||||
- name: tests/e2e/multicard/2-cards/test_pipeline_parallel.py
|
|
||||||
estimated_time: 357
|
|
||||||
- name: tests/e2e/multicard/2-cards/test_prefix_caching.py
|
- name: tests/e2e/multicard/2-cards/test_prefix_caching.py
|
||||||
estimated_time: 470
|
estimated_time: 470
|
||||||
- name: tests/e2e/multicard/2-cards/test_quantization.py
|
- name: tests/e2e/multicard/2-cards/test_quantization.py
|
||||||
@@ -165,3 +163,5 @@ e2e-multicard-4-cards:
|
|||||||
is_skipped: true
|
is_skipped: true
|
||||||
- name: tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py
|
- name: tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py
|
||||||
estimated_time: 1340
|
estimated_time: 1340
|
||||||
|
- name: tests/e2e/multicard/4-cards/test_pipeline_parallel.py
|
||||||
|
estimated_time: 357
|
||||||
|
|||||||
@@ -453,7 +453,6 @@ class RemoteEPDServer(RemoteOpenAIServer):
|
|||||||
self.env_dict.update(env_dict)
|
self.env_dict.update(env_dict)
|
||||||
|
|
||||||
self.env_dict["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
|
self.env_dict["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
|
||||||
self.env_dict["VLLM_USE_V1"] = "1"
|
|
||||||
self.env_dict["PYTORCH_NPU_ALLOC_CONF"] = "expandable_segments:True"
|
self.env_dict["PYTORCH_NPU_ALLOC_CONF"] = "expandable_segments:True"
|
||||||
self.env_dict["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
self.env_dict["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||||
|
|
||||||
@@ -626,6 +625,126 @@ class DisaggEpdProxy(RemoteEPDServer):
|
|||||||
super()._terminate_server()
|
super()._terminate_server()
|
||||||
|
|
||||||
|
|
||||||
|
_DP_RUNNER_START_TIMEOUT_SECONDS = 900.0
|
||||||
|
_DP_RUNNER_REQUEST_TIMEOUT_SECONDS = 900.0
|
||||||
|
_DP_RUNNER_SHUTDOWN_TIMEOUT_SECONDS = 30.0
|
||||||
|
|
||||||
|
|
||||||
|
def _split_data_parallel_indices(num_items: int, dp_size: int) -> list[list[int]]:
|
||||||
|
if num_items < 0:
|
||||||
|
raise ValueError("num_items must be non-negative")
|
||||||
|
if dp_size <= 0:
|
||||||
|
raise ValueError("dp_size must be positive")
|
||||||
|
|
||||||
|
floor = num_items // dp_size
|
||||||
|
remainder = num_items % dp_size
|
||||||
|
|
||||||
|
def start(rank: int) -> int:
|
||||||
|
return rank * floor + min(rank, remainder)
|
||||||
|
|
||||||
|
return [list(range(start(rank), start(rank + 1))) for rank in range(dp_size)]
|
||||||
|
|
||||||
|
|
||||||
|
def _slice_optional_inputs(inputs: PromptImageInput | PromptAudioInput | PromptVideoInput | None, indices: list[int]):
    """Project optional multimodal ``inputs`` onto ``indices``; ``None`` passes through unchanged."""
    return None if inputs is None else [inputs[idx] for idx in indices]
|
||||||
|
|
||||||
|
|
||||||
|
def _slice_list_inputs(items: list[Any], indices: list[int]) -> list[Any]:
|
||||||
|
return [items[index] for index in indices]
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_data_parallel_results(total_items: int, shard_results: list[tuple[list[int], list[Any]]]) -> list[Any]:
|
||||||
|
merged: list[Any] = [None] * total_items
|
||||||
|
for indices, results in shard_results:
|
||||||
|
if not indices:
|
||||||
|
continue
|
||||||
|
if len(indices) != len(results):
|
||||||
|
raise RuntimeError("Mismatched result count returned by data parallel worker")
|
||||||
|
for index, result in zip(indices, results):
|
||||||
|
merged[index] = result
|
||||||
|
|
||||||
|
if any(result is None for result in merged):
|
||||||
|
raise RuntimeError("Some data parallel results were not returned")
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_score_inputs(text_1: str | list[str], text_2: str | list[str]) -> tuple[list[str], list[str]]:
|
||||||
|
if isinstance(text_1, str) and isinstance(text_2, str):
|
||||||
|
return [text_1], [text_2]
|
||||||
|
if isinstance(text_1, str):
|
||||||
|
return [text_1] * len(text_2), list(text_2)
|
||||||
|
if isinstance(text_2, str):
|
||||||
|
return list(text_1), [text_2] * len(text_1)
|
||||||
|
if len(text_1) != len(text_2):
|
||||||
|
raise ValueError("`text_1` and `text_2` must have the same length")
|
||||||
|
return list(text_1), list(text_2)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_vllm_runner_dp_worker(conn, llm_kwargs: dict[str, Any], dp_rank: int, dp_size: int, master_port: int) -> None:
    """Entry point for one data-parallel worker process.

    Builds an ``LLM`` engine for this DP rank, signals readiness over the
    pipe ``conn``, then serves inference commands in a request/response loop
    until a ``shutdown`` command arrives. On any failure the traceback is
    sent back to the parent (best-effort) and the exception re-raised so
    the process exits non-zero.

    Args:
        conn: child end of a ``multiprocessing`` pipe to the parent.
        llm_kwargs: keyword arguments forwarded verbatim to ``LLM``.
        dp_rank: this worker's data-parallel rank.
        dp_size: total number of data-parallel workers.
        master_port: port used for DP coordination (all workers share it).
    """
    llm = None
    try:
        # Configure vLLM's data-parallel environment before engine creation;
        # the engine reads these variables at construction time.
        os.environ["VLLM_DP_RANK"] = str(dp_rank)
        os.environ["VLLM_DP_RANK_LOCAL"] = str(dp_rank)
        os.environ["VLLM_DP_SIZE"] = str(dp_size)
        os.environ["VLLM_DP_MASTER_IP"] = "127.0.0.1"
        os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        llm = LLM(**llm_kwargs)
        # Handshake: parent blocks on this "ready" message before dispatching.
        conn.send({"status": "ready", "rank": dp_rank})

        while True:
            request = conn.recv()
            command = request["command"]
            if command == "shutdown":
                break

            result: Any
            if command == "generate":
                req_outputs = llm.generate(
                    request["inputs"], sampling_params=request["sampling_params"], **request["kwargs"]
                )
                # Static helpers keep the serialization format identical to
                # the single-process VllmRunner outputs.
                result = VllmRunner._finalize_generate_outputs(req_outputs)
            elif command == "generate_w_logprobs":
                req_outputs = llm.generate(
                    request["inputs"], sampling_params=request["sampling_params"], **request["kwargs"]
                )
                result = VllmRunner._final_steps_generate_w_logprobs(req_outputs)
            elif command == "classify":
                req_outputs = llm.classify(request["prompts"])
                result = [req_output.outputs.probs for req_output in req_outputs]
            elif command == "embed":
                req_outputs = llm.embed(request["inputs"], *request["args"], **request["kwargs"])
                result = [req_output.outputs.embedding for req_output in req_outputs]
            elif command == "encode":
                req_outputs = llm.encode(request["prompts"])
                result = [req_output.outputs.data for req_output in req_outputs]
            elif command == "reward":
                req_outputs = llm.reward(request["prompts"])
                result = [req_output.outputs.data for req_output in req_outputs]
            elif command == "score":
                req_outputs = llm.score(request["text_1"], request["text_2"], *request["args"], **request["kwargs"])
                result = [req_output.outputs.score for req_output in req_outputs]
            else:
                raise ValueError(f"Unsupported data parallel command: {command}")

            # Echo the request's global indices so the parent can reassemble
            # shards in original order.
            conn.send({"status": "ok", "rank": dp_rank, "indices": request["indices"], "result": result})
    except Exception:
        # Best-effort error report; the pipe may already be broken.
        with contextlib.suppress(Exception):
            conn.send({"status": "error", "rank": dp_rank, "traceback": traceback.format_exc()})
        raise
    finally:
        # Drop the engine reference before tearing down distributed state so
        # NPU memory is actually released by cleanup below.
        if llm is not None:
            del llm
        clear_ascend_config()
        cleanup_dist_env_and_memory()
        with contextlib.suppress(Exception):
            conn.close()
|
||||||
|
|
||||||
|
|
||||||
class VllmRunner:
|
class VllmRunner:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -645,6 +764,10 @@ class VllmRunner:
|
|||||||
quantization: str | None = None,
|
quantization: str | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
data_parallel_size = int(kwargs.get("data_parallel_size", 1))
|
||||||
|
if data_parallel_size > 1:
|
||||||
|
raise ValueError("VllmRunner does not support `data_parallel_size > 1`; use `DPVllmRunner` instead.")
|
||||||
|
|
||||||
self.model = LLM(
|
self.model = LLM(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
runner=runner,
|
runner=runner,
|
||||||
@@ -664,6 +787,22 @@ class VllmRunner:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
def _finalize_generate_outputs(req_outputs: list[RequestOutput]) -> list[tuple[list[list[int]], list[str]]]:
    """Convert raw request outputs into per-request (token-id lists, texts) pairs.

    For every request, each sample becomes the prompt token ids followed by
    the generated ids, and the prompt text (empty when absent) concatenated
    with the generated text.
    """
    finalized: list[tuple[list[list[int]], list[str]]] = []
    for req_output in req_outputs:
        base_text = req_output.prompt or ""
        base_ids = req_output.prompt_token_ids
        sample_ids: list[list[int]] = []
        sample_strs: list[str] = []
        for sample in req_output.outputs:
            sample_ids.append(base_ids + list(sample.token_ids))
            sample_strs.append(base_text + sample.text)
        finalized.append((sample_ids, sample_strs))
    return finalized
|
||||||
|
|
||||||
def get_inputs(
|
def get_inputs(
|
||||||
self,
|
self,
|
||||||
prompts: list[str] | list[torch.Tensor] | list[int],
|
prompts: list[str] | list[torch.Tensor] | list[int],
|
||||||
@@ -698,7 +837,7 @@ class VllmRunner:
|
|||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompts: list[str] | list[torch.Tensor],
|
prompts: list[str] | list[torch.Tensor] | list[list[int]],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
images: PromptImageInput | None = None,
|
images: PromptImageInput | None = None,
|
||||||
videos: PromptVideoInput | None = None,
|
videos: PromptVideoInput | None = None,
|
||||||
@@ -706,22 +845,8 @@ class VllmRunner:
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[tuple[list[list[int]], list[str]]]:
|
) -> list[tuple[list[list[int]], list[str]]]:
|
||||||
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
||||||
|
|
||||||
req_outputs = self.model.generate(inputs, sampling_params=sampling_params, **kwargs)
|
req_outputs = self.model.generate(inputs, sampling_params=sampling_params, **kwargs)
|
||||||
|
return self._finalize_generate_outputs(req_outputs)
|
||||||
outputs: list[tuple[list[list[int]], list[str]]] = []
|
|
||||||
for req_output in req_outputs:
|
|
||||||
prompt_str = req_output.prompt
|
|
||||||
prompt_ids = req_output.prompt_token_ids
|
|
||||||
req_sample_output_ids: list[list[int]] = []
|
|
||||||
req_sample_output_strs: list[str] = []
|
|
||||||
for sample in req_output.outputs:
|
|
||||||
output_str = sample.text
|
|
||||||
output_ids = list(sample.token_ids)
|
|
||||||
req_sample_output_ids.append(prompt_ids + output_ids)
|
|
||||||
req_sample_output_strs.append((prompt_str or "") + output_str)
|
|
||||||
outputs.append((req_sample_output_ids, req_sample_output_strs))
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _final_steps_generate_w_logprobs(
|
def _final_steps_generate_w_logprobs(
|
||||||
@@ -760,7 +885,7 @@ class VllmRunner:
|
|||||||
|
|
||||||
def generate_greedy(
|
def generate_greedy(
|
||||||
self,
|
self,
|
||||||
prompts: list[str] | list[torch.Tensor],
|
prompts: list[str] | list[torch.Tensor] | list[list[int]],
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
images: PromptImageInput | None = None,
|
images: PromptImageInput | None = None,
|
||||||
videos: PromptVideoInput | None = None,
|
videos: PromptVideoInput | None = None,
|
||||||
@@ -842,6 +967,319 @@ class VllmRunner:
|
|||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
|
class DPVllmRunner(VllmRunner):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
runner: RunnerOption = "auto",
|
||||||
|
convert: ConvertOption = "auto",
|
||||||
|
tokenizer_name: str | None = None,
|
||||||
|
tokenizer_mode: str = "auto",
|
||||||
|
max_model_len: int | None = 1024,
|
||||||
|
dtype: str = "auto",
|
||||||
|
disable_log_stats: bool = True,
|
||||||
|
tensor_parallel_size: int = 1,
|
||||||
|
block_size: int = 16,
|
||||||
|
enable_chunked_prefill: bool = True,
|
||||||
|
swap_space: int = 4,
|
||||||
|
enforce_eager: bool | None = False,
|
||||||
|
quantization: str | None = None,
|
||||||
|
data_parallel_size: int = 2,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
if data_parallel_size < 2:
|
||||||
|
raise ValueError("DPVllmRunner requires `data_parallel_size >= 2`")
|
||||||
|
|
||||||
|
self._dp_size = data_parallel_size
|
||||||
|
self._dp_parent_conns: list[Any] = []
|
||||||
|
self._dp_processes: list[Any] = []
|
||||||
|
self._dp_start_timeout = float(kwargs.pop("dp_start_timeout", _DP_RUNNER_START_TIMEOUT_SECONDS))
|
||||||
|
self._dp_request_timeout = float(kwargs.pop("dp_request_timeout", _DP_RUNNER_REQUEST_TIMEOUT_SECONDS))
|
||||||
|
|
||||||
|
llm_kwargs = dict(
|
||||||
|
model=model_name,
|
||||||
|
runner=runner,
|
||||||
|
convert=convert,
|
||||||
|
tokenizer=tokenizer_name,
|
||||||
|
tokenizer_mode=tokenizer_mode,
|
||||||
|
trust_remote_code=True,
|
||||||
|
dtype=dtype,
|
||||||
|
swap_space=swap_space,
|
||||||
|
enforce_eager=enforce_eager,
|
||||||
|
disable_log_stats=disable_log_stats,
|
||||||
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
|
max_model_len=max_model_len,
|
||||||
|
block_size=block_size,
|
||||||
|
enable_chunked_prefill=enable_chunked_prefill,
|
||||||
|
quantization=quantization,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
cleanup_dist_env_and_memory()
|
||||||
|
self._start_data_parallel_workers(llm_kwargs)
|
||||||
|
|
||||||
|
@property
def model(self) -> LLM:
    """Unavailable on DPVllmRunner.

    Each DP worker owns its own ``LLM`` engine inside a child process, so
    there is no single in-process engine to expose; raising here surfaces
    misuse early instead of returning a stale/partial handle.
    """
    raise RuntimeError("Direct access to `runner.model` is not supported by `DPVllmRunner`.")
|
||||||
|
|
||||||
|
def _start_data_parallel_workers(self, llm_kwargs: dict[str, Any]) -> None:
    """Spawn one worker process per DP rank and wait for all readiness handshakes.

    All workers share a single freshly-allocated master port for DP
    coordination. On any failure (spawn error, startup timeout, worker
    reporting a non-ready status) every already-started worker is torn
    down before the exception propagates.

    Raises:
        TimeoutError: a worker did not report within ``self._dp_start_timeout``.
        RuntimeError: a worker reported a startup failure (its traceback is
            included in the message).
    """
    # "spawn" avoids forking a process that may already hold device handles.
    ctx = multiprocessing.get_context("spawn")
    master_port = get_open_port()

    try:
        for dp_rank in range(self._dp_size):
            parent_conn, child_conn = ctx.Pipe()
            proc = ctx.Process(
                target=_run_vllm_runner_dp_worker,
                args=(child_conn, llm_kwargs, dp_rank, self._dp_size, master_port),
            )
            proc.start()
            # Close the parent's copy of the child end so EOF is observable
            # if the worker dies.
            child_conn.close()
            self._dp_parent_conns.append(parent_conn)
            self._dp_processes.append(proc)

        # Start all workers first, then wait: DP ranks may need to rendezvous
        # with each other before any single one can finish initializing.
        for rank, conn in enumerate(self._dp_parent_conns):
            if not conn.poll(self._dp_start_timeout):
                raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to start")
            message = conn.recv()
            if message["status"] != "ready":
                raise RuntimeError(
                    f"Failed to start data parallel worker {rank}:\n{message.get('traceback', 'unknown error')}"
                )
    except Exception:
        self._stop_data_parallel_workers()
        raise
|
||||||
|
|
||||||
|
def _stop_data_parallel_workers(self) -> None:
    """Shut down all DP worker processes; safe to call repeatedly or mid-failure.

    Sequence: ask every worker to exit cleanly, join with a timeout and
    kill stragglers, then close the pipes and forget the handles. Every
    step is best-effort (exceptions suppressed) so teardown always runs
    to completion even if some workers already died.
    """
    for conn in self._dp_parent_conns:
        # Pipe may already be broken if the worker crashed.
        with contextlib.suppress(Exception):
            conn.send({"command": "shutdown"})

    for proc in self._dp_processes:
        proc.join(timeout=_DP_RUNNER_SHUTDOWN_TIMEOUT_SECONDS)
        if proc.is_alive():
            # Force-kill a hung worker so test teardown cannot deadlock.
            proc.kill()
            proc.join(timeout=5)

    for conn in self._dp_parent_conns:
        with contextlib.suppress(Exception):
            conn.close()

    self._dp_parent_conns.clear()
    self._dp_processes.clear()
|
||||||
|
|
||||||
|
def _dispatch_prompt_command(
|
||||||
|
self,
|
||||||
|
command: str,
|
||||||
|
prompts: list[str] | list[torch.Tensor] | list[list[int]],
|
||||||
|
*,
|
||||||
|
images: PromptImageInput | None = None,
|
||||||
|
videos: PromptVideoInput | None = None,
|
||||||
|
audios: PromptAudioInput | None = None,
|
||||||
|
**payload: Any,
|
||||||
|
) -> list[Any]:
|
||||||
|
if not prompts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
shard_results: list[tuple[list[int], list[Any]]] = []
|
||||||
|
shard_indices = _split_data_parallel_indices(len(prompts), self._dp_size)
|
||||||
|
|
||||||
|
for rank, conn in enumerate(self._dp_parent_conns):
|
||||||
|
indices = shard_indices[rank]
|
||||||
|
worker_indices = indices or [0]
|
||||||
|
worker_prompts = _slice_list_inputs(prompts, worker_indices)
|
||||||
|
conn.send(
|
||||||
|
{
|
||||||
|
"command": command,
|
||||||
|
"indices": indices,
|
||||||
|
"inputs": self.get_inputs(
|
||||||
|
worker_prompts,
|
||||||
|
images=_slice_optional_inputs(images, worker_indices),
|
||||||
|
videos=_slice_optional_inputs(videos, worker_indices),
|
||||||
|
audios=_slice_optional_inputs(audios, worker_indices),
|
||||||
|
),
|
||||||
|
"prompts": worker_prompts,
|
||||||
|
**payload,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for rank, conn in enumerate(self._dp_parent_conns):
|
||||||
|
if not conn.poll(self._dp_request_timeout):
|
||||||
|
raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to finish `{command}`")
|
||||||
|
message = conn.recv()
|
||||||
|
if message["status"] != "ok":
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Data parallel worker {rank} failed during `{command}`:\n"
|
||||||
|
f"{message.get('traceback', 'unknown error')}"
|
||||||
|
)
|
||||||
|
shard_results.append((message["indices"], message["result"]))
|
||||||
|
except Exception:
|
||||||
|
self._stop_data_parallel_workers()
|
||||||
|
raise
|
||||||
|
|
||||||
|
return _merge_data_parallel_results(len(prompts), shard_results)
|
||||||
|
|
||||||
|
def _dispatch_text_command(self, command: str, prompts: list[str]) -> list[Any]:
|
||||||
|
if not prompts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
shard_results: list[tuple[list[int], list[Any]]] = []
|
||||||
|
shard_indices = _split_data_parallel_indices(len(prompts), self._dp_size)
|
||||||
|
|
||||||
|
for rank, conn in enumerate(self._dp_parent_conns):
|
||||||
|
indices = shard_indices[rank]
|
||||||
|
worker_indices = indices or [0]
|
||||||
|
conn.send(
|
||||||
|
{
|
||||||
|
"command": command,
|
||||||
|
"indices": indices,
|
||||||
|
"prompts": _slice_list_inputs(prompts, worker_indices),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for rank, conn in enumerate(self._dp_parent_conns):
|
||||||
|
if not conn.poll(self._dp_request_timeout):
|
||||||
|
raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to finish `{command}`")
|
||||||
|
message = conn.recv()
|
||||||
|
if message["status"] != "ok":
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Data parallel worker {rank} failed during `{command}`:\n"
|
||||||
|
f"{message.get('traceback', 'unknown error')}"
|
||||||
|
)
|
||||||
|
shard_results.append((message["indices"], message["result"]))
|
||||||
|
except Exception:
|
||||||
|
self._stop_data_parallel_workers()
|
||||||
|
raise
|
||||||
|
|
||||||
|
return _merge_data_parallel_results(len(prompts), shard_results)
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompts: list[str] | list[torch.Tensor] | list[list[int]],
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
images: PromptImageInput | None = None,
|
||||||
|
videos: PromptVideoInput | None = None,
|
||||||
|
audios: PromptAudioInput | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> list[tuple[list[list[int]], list[str]]]:
|
||||||
|
return self._dispatch_prompt_command(
|
||||||
|
"generate",
|
||||||
|
prompts,
|
||||||
|
images=images,
|
||||||
|
videos=videos,
|
||||||
|
audios=audios,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
kwargs=kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_w_logprobs(
|
||||||
|
self,
|
||||||
|
prompts: list[str],
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
images: PromptImageInput | None = None,
|
||||||
|
audios: PromptAudioInput | None = None,
|
||||||
|
videos: PromptVideoInput | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
|
||||||
|
toks_str_logsprobs_prompt_logprobs = self._dispatch_prompt_command(
|
||||||
|
"generate_w_logprobs",
|
||||||
|
prompts,
|
||||||
|
images=images,
|
||||||
|
videos=videos,
|
||||||
|
audios=audios,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
kwargs=kwargs,
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
[x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
|
||||||
|
if sampling_params.prompt_logprobs is None
|
||||||
|
else toks_str_logsprobs_prompt_logprobs
|
||||||
|
)
|
||||||
|
|
||||||
|
def classify(self, prompts: list[str]) -> list[list[float]]:
|
||||||
|
return self._dispatch_text_command("classify", prompts)
|
||||||
|
|
||||||
|
def embed(
|
||||||
|
self,
|
||||||
|
prompts: list[str],
|
||||||
|
images: PromptImageInput | None = None,
|
||||||
|
videos: PromptVideoInput | None = None,
|
||||||
|
audios: PromptAudioInput | None = None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> list[list[float]]:
|
||||||
|
return self._dispatch_prompt_command(
|
||||||
|
"embed",
|
||||||
|
prompts,
|
||||||
|
images=images,
|
||||||
|
videos=videos,
|
||||||
|
audios=audios,
|
||||||
|
args=args,
|
||||||
|
kwargs=kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def encode(self, prompts: list[str]) -> list[list[float]]:
|
||||||
|
return self._dispatch_text_command("encode", prompts)
|
||||||
|
|
||||||
|
def reward(self, prompts: list[str]) -> list[list[float]]:
|
||||||
|
return self._dispatch_text_command("reward", prompts)
|
||||||
|
|
||||||
|
def score(
|
||||||
|
self,
|
||||||
|
text_1: str | list[str],
|
||||||
|
text_2: str | list[str],
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> list[float]:
|
||||||
|
normalized_text_1, normalized_text_2 = _normalize_score_inputs(text_1, text_2)
|
||||||
|
if not normalized_text_1:
|
||||||
|
return []
|
||||||
|
|
||||||
|
shard_results: list[tuple[list[int], list[Any]]] = []
|
||||||
|
shard_indices = _split_data_parallel_indices(len(normalized_text_1), self._dp_size)
|
||||||
|
|
||||||
|
for rank, conn in enumerate(self._dp_parent_conns):
|
||||||
|
indices = shard_indices[rank]
|
||||||
|
worker_indices = indices or [0]
|
||||||
|
conn.send(
|
||||||
|
{
|
||||||
|
"command": "score",
|
||||||
|
"indices": indices,
|
||||||
|
"text_1": _slice_list_inputs(normalized_text_1, worker_indices),
|
||||||
|
"text_2": _slice_list_inputs(normalized_text_2, worker_indices),
|
||||||
|
"args": args,
|
||||||
|
"kwargs": kwargs,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for rank, conn in enumerate(self._dp_parent_conns):
|
||||||
|
if not conn.poll(self._dp_request_timeout):
|
||||||
|
raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to finish `score`")
|
||||||
|
message = conn.recv()
|
||||||
|
if message["status"] != "ok":
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Data parallel worker {rank} failed during `score`:\n"
|
||||||
|
f"{message.get('traceback', 'unknown error')}"
|
||||||
|
)
|
||||||
|
shard_results.append((message["indices"], message["result"]))
|
||||||
|
except Exception:
|
||||||
|
self._stop_data_parallel_workers()
|
||||||
|
raise
|
||||||
|
|
||||||
|
return _merge_data_parallel_results(len(normalized_text_1), shard_results)
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
    """Context-manager exit: tear down DP workers, then free config/dist state.

    Workers are stopped first so no child process still holds devices when
    the parent cleans up its distributed environment and memory.
    """
    self._stop_data_parallel_workers()
    clear_ascend_config()
    cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
|
DataParallelVllmRunner = DPVllmRunner
|
||||||
|
|
||||||
|
|
||||||
class HfRunner:
|
class HfRunner:
|
||||||
def get_default_device(self):
|
def get_default_device(self):
|
||||||
return "cpu" if current_platform.is_cpu() else current_platform.device_type
|
return "cpu" if current_platform.is_cpu() else current_platform.device_type
|
||||||
|
|||||||
@@ -1,49 +0,0 @@
|
|||||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
||||||
# Copyright 2023 The vLLM team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# This file is a part of the vllm-ascend project.
|
|
||||||
#
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from tests.e2e.conftest import VllmRunner
|
|
||||||
|
|
||||||
MODELS = [
|
|
||||||
"Qwen/Qwen3-0.6B",
|
|
||||||
"deepseek-ai/DeepSeek-V2-Lite-Chat",
|
|
||||||
]
|
|
||||||
|
|
||||||
TENSOR_PARALLELS = [1]
|
|
||||||
PIPELINE_PARALLELS = [2]
|
|
||||||
DIST_EXECUTOR_BACKEND = ["mp", "ray"]
|
|
||||||
|
|
||||||
prompts = [
|
|
||||||
"Hello, my name is",
|
|
||||||
"The future of AI is",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
|
||||||
@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
|
|
||||||
@pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS)
|
|
||||||
@pytest.mark.parametrize("distributed_executor_backend", DIST_EXECUTOR_BACKEND)
|
|
||||||
def test_models_pp2(model: str, tp_size: int, pp_size: int, distributed_executor_backend: str) -> None:
|
|
||||||
with VllmRunner(
|
|
||||||
model,
|
|
||||||
tensor_parallel_size=tp_size,
|
|
||||||
pipeline_parallel_size=pp_size,
|
|
||||||
cudagraph_capture_sizes=[1, 2, 4, 8],
|
|
||||||
distributed_executor_backend=distributed_executor_backend,
|
|
||||||
gpu_memory_utilization=0.7,
|
|
||||||
) as vllm_model:
|
|
||||||
vllm_model.generate_greedy(prompts, 64)
|
|
||||||
88
tests/e2e/multicard/4-cards/test_pipeline_parallel.py
Normal file
88
tests/e2e/multicard/4-cards/test_pipeline_parallel.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tests.e2e.conftest import DPVllmRunner, VllmRunner, wait_until_npu_memory_free
|
||||||
|
from tests.e2e.model_utils import check_outputs_equal
|
||||||
|
|
||||||
|
DS3 = "deepseek-ai/DeepSeek-V2-Lite-Chat"
|
||||||
|
MODELS = [
|
||||||
|
DS3,
|
||||||
|
]
|
||||||
|
MOE_MODELS = [
|
||||||
|
DS3,
|
||||||
|
]
|
||||||
|
|
||||||
|
DATA_PARALLELS = [2]
|
||||||
|
TENSOR_PARALLELS = [1,2]
|
||||||
|
PIPELINE_PARALLELS = [2]
|
||||||
|
DIST_EXECUTOR_BACKEND = ["mp", "ray"]
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
GOLDEN = [([100000, 17464, 11, 601, 1210, 317, 46462, 608, 245, 4541, 7712, 13, 2682, 6207, 317, 276, 2774, 340, 366, 254, 1608, 2784], 'Hello, my name is***** am a computer expert. My goal is to provide you with the best experience'), ([100000, 549, 3680, 280, 20838, 317, 6464, 11, 548, 359, 487, 82, 441, 1673, 895, 10694, 13, 1733, 20838, 5495, 11106, 276], 'The future of AI is bright, but it’s not without its challenges. As AI technology continues to')]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
|
||||||
|
@pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS)
|
||||||
|
@pytest.mark.parametrize("distributed_executor_backend", DIST_EXECUTOR_BACKEND)
|
||||||
|
@wait_until_npu_memory_free(target_free_percentage=0.6)
|
||||||
|
def test_models_pp2_tp2(model: str, tp_size: int, pp_size: int, distributed_executor_backend: str) -> None:
|
||||||
|
with VllmRunner(
|
||||||
|
model,
|
||||||
|
tensor_parallel_size=tp_size,
|
||||||
|
pipeline_parallel_size=pp_size,
|
||||||
|
cudagraph_capture_sizes=[1, 2, 4],
|
||||||
|
distributed_executor_backend=distributed_executor_backend,
|
||||||
|
gpu_memory_utilization=0.7,
|
||||||
|
enable_expert_parallel=model in MOE_MODELS,
|
||||||
|
) as vllm_model:
|
||||||
|
outputs = vllm_model.generate_greedy(prompts, 16)
|
||||||
|
check_outputs_equal(
|
||||||
|
outputs_0_lst=outputs,
|
||||||
|
outputs_1_lst=GOLDEN,
|
||||||
|
name_0=f"{model}-tp{tp_size}pp{pp_size}",
|
||||||
|
name_1="GOLDEN",
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("dp_size", DATA_PARALLELS)
|
||||||
|
@pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS)
|
||||||
|
@pytest.mark.parametrize("distributed_executor_backend", DIST_EXECUTOR_BACKEND)
|
||||||
|
@wait_until_npu_memory_free(target_free_percentage=0.6)
|
||||||
|
def test_models_pp2_dp2(model: str, dp_size: int, pp_size: int, distributed_executor_backend: str) -> None:
|
||||||
|
with DPVllmRunner(
|
||||||
|
model,
|
||||||
|
data_parallel_size=dp_size,
|
||||||
|
pipeline_parallel_size=pp_size,
|
||||||
|
cudagraph_capture_sizes=[1, 2, 4],
|
||||||
|
distributed_executor_backend=distributed_executor_backend,
|
||||||
|
gpu_memory_utilization=0.7,
|
||||||
|
enable_expert_parallel=model in MOE_MODELS,
|
||||||
|
) as vllm_model:
|
||||||
|
outputs = vllm_model.generate_greedy(prompts, 16)
|
||||||
|
check_outputs_equal(
|
||||||
|
outputs_0_lst=outputs,
|
||||||
|
outputs_1_lst=GOLDEN,
|
||||||
|
name_0=f"{model}-dp{dp_size}pp{pp_size}",
|
||||||
|
name_1="GOLDEN",
|
||||||
|
)
|
||||||
|
|
||||||
@@ -38,19 +38,16 @@ def init_ascend_model_parallel(
|
|||||||
global_tp_size = parallel_config.tensor_parallel_size
|
global_tp_size = parallel_config.tensor_parallel_size
|
||||||
global_dp_size = parallel_config.data_parallel_size
|
global_dp_size = parallel_config.data_parallel_size
|
||||||
global_pp_size = parallel_config.pipeline_parallel_size
|
global_pp_size = parallel_config.pipeline_parallel_size
|
||||||
|
global_pcp_size = parallel_config.prefill_context_parallel_size
|
||||||
|
|
||||||
# The layout of all ranks: ExternalDP * EP
|
# The layout of all ranks: ExternalDP * EP
|
||||||
# ExternalDP is the data parallel group that is not part of the model,
|
# ExternalDP is the data parallel group that is not part of the model,
|
||||||
# every dp rank can generate independently (in verl integration).
|
# every dp rank can generate independently (in verl integration).
|
||||||
all_ranks = torch.arange(world_size).reshape(
|
all_ranks = torch.arange(world_size).reshape(
|
||||||
-1, global_dp_size * parallel_config.prefill_context_parallel_size * global_tp_size
|
|
||||||
)
|
|
||||||
# TODO: all_ranks should be the same as vllm_all_ranks, all_ranks needs to be removed in the future.
|
|
||||||
vllm_all_ranks = torch.arange(world_size).reshape(
|
|
||||||
-1,
|
-1,
|
||||||
global_dp_size,
|
global_dp_size,
|
||||||
global_pp_size,
|
global_pp_size,
|
||||||
parallel_config.prefill_context_parallel_size,
|
global_pcp_size,
|
||||||
global_tp_size,
|
global_tp_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -59,7 +56,6 @@ def init_ascend_model_parallel(
|
|||||||
global _P_TP
|
global _P_TP
|
||||||
assert _P_TP is None, "distributed prefill tensor parallel group is already initialized"
|
assert _P_TP is None, "distributed prefill tensor parallel group is already initialized"
|
||||||
prefill_tensor_model_parallel_size = pd_tp_ratio
|
prefill_tensor_model_parallel_size = pd_tp_ratio
|
||||||
pcp_size = parallel_config.prefill_context_parallel_size
|
|
||||||
# divide alltoall groups
|
# divide alltoall groups
|
||||||
if pd_head_ratio > 1 and get_current_vllm_config().kv_transfer_config.is_kv_producer:
|
if pd_head_ratio > 1 and get_current_vllm_config().kv_transfer_config.is_kv_producer:
|
||||||
num_head_replica = get_ascend_config().num_head_replica
|
num_head_replica = get_ascend_config().num_head_replica
|
||||||
@@ -68,13 +64,16 @@ def init_ascend_model_parallel(
|
|||||||
group_ranks = all_ranks.view(-1, prefill_tensor_model_parallel_size).unbind(0)
|
group_ranks = all_ranks.view(-1, prefill_tensor_model_parallel_size).unbind(0)
|
||||||
else:
|
else:
|
||||||
group_ranks = all_ranks.clone().view(
|
group_ranks = all_ranks.clone().view(
|
||||||
global_dp_size * pcp_size, -1, num_head_replica
|
global_dp_size * global_pp_size * global_pcp_size, -1, num_head_replica
|
||||||
) # [DP_size, num_head, num_head_replica]
|
) # [DP_size, num_head, num_head_replica]
|
||||||
group_ranks = group_ranks.permute(0, 2, 1)
|
group_ranks = group_ranks.permute(0, 2, 1)
|
||||||
group_ranks = group_ranks.reshape(-1, group_ranks.size(-1)) # [DP_size * num_head_replica, num_head]
|
group_ranks = group_ranks.reshape(-1, group_ranks.size(-1)) # [DP_size * num_head_replica, num_head]
|
||||||
alltoall_group_size = group_ranks.size(-1) // remote_tp_size
|
alltoall_group_size = group_ranks.size(-1) // remote_tp_size
|
||||||
group_ranks = group_ranks.unsqueeze(-1).view(
|
group_ranks = group_ranks.unsqueeze(-1).view(
|
||||||
global_dp_size * pcp_size, num_head_replica, -1, alltoall_group_size
|
global_dp_size * global_pp_size * global_pcp_size,
|
||||||
|
num_head_replica,
|
||||||
|
-1,
|
||||||
|
alltoall_group_size,
|
||||||
) # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
|
) # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
|
||||||
group_ranks = group_ranks.reshape(-1, alltoall_group_size).unbind(0)
|
group_ranks = group_ranks.reshape(-1, alltoall_group_size).unbind(0)
|
||||||
group_ranks = [x.tolist() for x in group_ranks]
|
group_ranks = [x.tolist() for x in group_ranks]
|
||||||
@@ -82,10 +81,18 @@ def init_ascend_model_parallel(
|
|||||||
num = next((i for i, ranks in enumerate(group_ranks) if local_rank in ranks), None)
|
num = next((i for i, ranks in enumerate(group_ranks) if local_rank in ranks), None)
|
||||||
_P_TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name=f"p_tp_{num}")
|
_P_TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name=f"p_tp_{num}")
|
||||||
|
|
||||||
global _MC2
|
# EP like group ranks
|
||||||
group_ranks = all_ranks.unbind(0)
|
group_ranks = (
|
||||||
|
all_ranks.transpose(1, 2)
|
||||||
|
.reshape(
|
||||||
|
-1,
|
||||||
|
global_dp_size * global_pcp_size * global_tp_size,
|
||||||
|
)
|
||||||
|
.unbind(0)
|
||||||
|
)
|
||||||
group_ranks = [x.tolist() for x in group_ranks]
|
group_ranks = [x.tolist() for x in group_ranks]
|
||||||
|
|
||||||
|
global _MC2
|
||||||
_MC2 = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name="mc2")
|
_MC2 = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name="mc2")
|
||||||
|
|
||||||
if get_ascend_config().eplb_config.dynamic_eplb:
|
if get_ascend_config().eplb_config.dynamic_eplb:
|
||||||
@@ -94,6 +101,12 @@ def init_ascend_model_parallel(
|
|||||||
group_ranks, get_world_group().local_rank, backend, group_name="dynamic_eplb"
|
group_ranks, get_world_group().local_rank, backend, group_name="dynamic_eplb"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if get_ascend_config().multistream_overlap_gate:
|
||||||
|
global _FC3_QUANT_X
|
||||||
|
_FC3_QUANT_X = init_model_parallel_group(
|
||||||
|
group_ranks, get_world_group().local_rank, backend, group_name="fc3_quant_x"
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize fine-grained TP process groups on Ascend for four components:
|
# Initialize fine-grained TP process groups on Ascend for four components:
|
||||||
# 1. LM Head: output logits projection (`lmhead_tensor_parallel_size`)
|
# 1. LM Head: output logits projection (`lmhead_tensor_parallel_size`)
|
||||||
# 2. O Proj: attention output projection (`oproj_tensor_parallel_size`)
|
# 2. O Proj: attention output projection (`oproj_tensor_parallel_size`)
|
||||||
@@ -182,7 +195,7 @@ def init_ascend_model_parallel(
|
|||||||
# 2. If it is not None, and the module tp_group is same as the global tp_group.
|
# 2. If it is not None, and the module tp_group is same as the global tp_group.
|
||||||
# 3. If it is not None, and the module tp_group is different from the global tp_group.(eg. flashcomm2_otp)
|
# 3. If it is not None, and the module tp_group is different from the global tp_group.(eg. flashcomm2_otp)
|
||||||
group_ranks = []
|
group_ranks = []
|
||||||
pp_group_ranks = vllm_all_ranks.transpose(2, 4).reshape(-1, global_pp_size)
|
pp_group_ranks = all_ranks.transpose(2, 4).reshape(-1, global_pp_size)
|
||||||
if module_tp_group_ranks is None:
|
if module_tp_group_ranks is None:
|
||||||
# If it is None, then the TP_size of this shard weight is 1.
|
# If it is None, then the TP_size of this shard weight is 1.
|
||||||
shard_weight_group_ranks = pp_group_ranks.transpose(0, 1).unbind(0)
|
shard_weight_group_ranks = pp_group_ranks.transpose(0, 1).unbind(0)
|
||||||
@@ -209,17 +222,9 @@ def init_ascend_model_parallel(
|
|||||||
_SHARD_WEIGHT = create_shard_weight_group(None)
|
_SHARD_WEIGHT = create_shard_weight_group(None)
|
||||||
else:
|
else:
|
||||||
# For standard tp, use global tp group_ranks
|
# For standard tp, use global tp group_ranks
|
||||||
tp_group_ranks = vllm_all_ranks.view(-1, global_tp_size)
|
tp_group_ranks = all_ranks.view(-1, global_tp_size)
|
||||||
_SHARD_WEIGHT = create_shard_weight_group(tp_group_ranks)
|
_SHARD_WEIGHT = create_shard_weight_group(tp_group_ranks)
|
||||||
|
|
||||||
if get_ascend_config().multistream_overlap_gate:
|
|
||||||
global _FC3_QUANT_X
|
|
||||||
group_ranks = all_ranks.unbind(0)
|
|
||||||
group_ranks = [x.tolist() for x in group_ranks]
|
|
||||||
_FC3_QUANT_X = init_model_parallel_group(
|
|
||||||
group_ranks, get_world_group().local_rank, backend, group_name="fc3_quant_x"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def model_parallel_initialized():
|
def model_parallel_initialized():
|
||||||
return _MC2 is not None
|
return _MC2 is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user