Sync from v0.13
tests/v1/tpu/test_spmd_model_weight_loading.py (new file, 78 lines)
@@ -0,0 +1,78 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
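"""Test SPMD model weight loading on TPU: a checkpoint is loaded via
TPUModelLoader and sharded across an xs.Mesh spanning all local TPU
devices. Intended to run on a TPU VM, e.g.:
    pytest -s tests/v1/tpu/test_spmd_model_weight_loading.py
"""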
import gc
import tempfile

import numpy as np
import pytest
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tpu import TPUModelLoader
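

# Builds a vLLM config for the given model and initializes a single-process
# distributed environment (world size 1, rank 0) over a file-based gloo
# store. The gloo backend only covers CPU-side coordination here; the TPU
# sharding itself is applied later by GSPMD.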
def _setup_environment(model):
    engine_args = EngineArgs(
        model=model,
    )
    vllm_config = engine_args.create_engine_config()
    with set_current_vllm_config(vllm_config):
        temp_file = tempfile.mkstemp()[1]
        init_distributed_environment(
            1,
            0,
            local_rank=0,
            distributed_init_method=f"file://{temp_file}",
            backend="gloo",
        )
        # Under single-worker mode, the full model is initialized first and
        # then partitioned using GSPMD.
        ensure_model_parallel_initialized(1, 1)
    return vllm_config


MESH = None
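

# The mesh is cached in a module-level global so that xr.use_spmd() and the
# mesh construction run only once per process; every parametrized test case
# reuses the same (num_devices, 1) mesh, with all devices along the "x" axis.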
def _get_spmd_mesh():
    global MESH
    if MESH is None:
        xr.use_spmd()
        num_devices = xr.global_runtime_device_count()
        mesh_shape = (num_devices, 1)
        device_ids = np.array(range(num_devices))
        MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
    return MESH


@pytest.mark.parametrize(
    "model",
    [
        "Qwen/Qwen2-1.5B-Instruct",
        # Skip large models due to CI runner disk space limitations
        # "meta-llama/Llama-3.1-8B-Instruct",
        # "meta-llama/Llama-3.1-70B-Instruct",
    ],
)
def test_tpu_model_loader(model):
    # Skip the 70B test if there are fewer than 8 chips.
    # TODO: Query the chip count via the torch_xla API; the query API does
    # not work under SPMD yet, even though this test runs in SPMD mode.
    if "70B" in model and xr.global_runtime_device_count() < 8:
        pytest.skip("Skipping 70B model if the TPU VM has fewer than 8 "
                    "chips, to avoid OOM.")
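
    # Load the model through the TPU loader, which shards the weights onto
    # the SPMD mesh; the test passes if loading completes without error.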
    vllm_config = _setup_environment(model)
    loader = TPUModelLoader(load_config=vllm_config.load_config)
    mesh = _get_spmd_mesh()
    model = loader.load_model(vllm_config, vllm_config.model_config, mesh)
    del model
    gc.collect()