Add device support (#1607)
This commit is contained in:
@@ -423,6 +423,9 @@ class ScheduleBatch:
|
||||
# Stream
|
||||
has_stream: bool = False
|
||||
|
||||
# device
|
||||
device: str = "cuda"
|
||||
|
||||
# Has regex
|
||||
has_regex: bool = False
|
||||
|
||||
@@ -439,6 +442,7 @@ class ScheduleBatch:
|
||||
tree_cache=tree_cache,
|
||||
return_logprob=return_logprob,
|
||||
has_stream=has_stream,
|
||||
device=req_to_token_pool.device,
|
||||
has_regex=has_regex,
|
||||
)
|
||||
|
||||
|
||||
@@ -81,10 +81,11 @@ class ModelRunner:
|
||||
# Parse args
|
||||
self.model_config = model_config
|
||||
self.mem_fraction_static = mem_fraction_static
|
||||
self.device = server_args.device
|
||||
self.gpu_id = gpu_id
|
||||
self.tp_rank = tp_rank
|
||||
self.tp_size = tp_size
|
||||
self.nccl_port = nccl_port
|
||||
self.dist_port = nccl_port
|
||||
self.server_args = server_args
|
||||
self.is_multimodal_model = is_multimodal_model(
|
||||
self.model_config.hf_config.architectures
|
||||
@@ -132,39 +133,45 @@ class ModelRunner:
|
||||
server_args.max_running_requests,
|
||||
server_args.max_total_tokens,
|
||||
)
|
||||
self.init_cublas()
|
||||
self.init_attention_backend()
|
||||
self.init_cuda_graphs()
|
||||
if self.device == "cuda":
|
||||
self.init_cublas()
|
||||
self.init_attention_backend()
|
||||
self.init_cuda_graphs()
|
||||
else:
|
||||
self.init_attention_backend()
|
||||
|
||||
def init_torch_distributed(self):
|
||||
logger.info("Init torch distributed begin.")
|
||||
# Init torch distributed
|
||||
torch.cuda.set_device(self.gpu_id)
|
||||
logger.info("Init nccl begin.")
|
||||
if self.device == "cuda":
|
||||
torch.cuda.set_device(self.gpu_id)
|
||||
backend = "nccl"
|
||||
|
||||
if not self.server_args.enable_p2p_check:
|
||||
monkey_patch_vllm_p2p_access_check(self.gpu_id)
|
||||
|
||||
if self.server_args.dist_init_addr:
|
||||
nccl_init_method = f"tcp://{self.server_args.dist_init_addr}"
|
||||
dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
|
||||
else:
|
||||
nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
|
||||
dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
|
||||
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
|
||||
init_distributed_environment(
|
||||
backend="nccl",
|
||||
backend=backend,
|
||||
world_size=self.tp_size,
|
||||
rank=self.tp_rank,
|
||||
local_rank=self.gpu_id,
|
||||
distributed_init_method=nccl_init_method,
|
||||
distributed_init_method=dist_init_method,
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
||||
min_per_gpu_memory = get_available_gpu_memory(
|
||||
self.gpu_id, distributed=self.tp_size > 1
|
||||
self.device, self.gpu_id, distributed=self.tp_size > 1
|
||||
)
|
||||
self.tp_group = get_tp_group()
|
||||
|
||||
# Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
|
||||
# so we disable padding in cuda graph.
|
||||
if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
|
||||
if self.device == "cuda" and not all(
|
||||
in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
|
||||
):
|
||||
self.server_args.disable_cuda_graph_padding = True
|
||||
logger.info(
|
||||
"Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
|
||||
@@ -172,7 +179,7 @@ class ModelRunner:
|
||||
|
||||
# Check memory for tensor parallelism
|
||||
if self.tp_size > 1:
|
||||
local_gpu_memory = get_available_gpu_memory(self.gpu_id)
|
||||
local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
if min_per_gpu_memory < local_gpu_memory * 0.9:
|
||||
raise ValueError(
|
||||
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
|
||||
@@ -182,23 +189,22 @@ class ModelRunner:
|
||||
|
||||
def load_model(self):
|
||||
logger.info(
|
||||
f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
||||
f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
)
|
||||
|
||||
# This can reduce thread conflicts and speed up weight loading.
|
||||
torch.set_num_threads(1)
|
||||
|
||||
if torch.cuda.get_device_capability()[0] < 8:
|
||||
logger.info(
|
||||
"Compute capability below sm80. Use float16 due to lack of bfloat16 support."
|
||||
)
|
||||
self.server_args.dtype = "float16"
|
||||
if torch.cuda.get_device_capability()[1] < 5:
|
||||
raise RuntimeError("SGLang only supports sm75 and above.")
|
||||
if self.device == "cuda":
|
||||
if torch.cuda.get_device_capability()[0] < 8:
|
||||
logger.info(
|
||||
"Compute capability below sm80. Use float16 due to lack of bfloat16 support."
|
||||
)
|
||||
self.server_args.dtype = "float16"
|
||||
if torch.cuda.get_device_capability()[1] < 5:
|
||||
raise RuntimeError("SGLang only supports sm75 and above.")
|
||||
|
||||
# Prepare the vllm model config
|
||||
monkey_patch_vllm_dummy_weight_loader()
|
||||
self.device_config = DeviceConfig()
|
||||
self.load_config = LoadConfig(load_format=self.server_args.load_format)
|
||||
self.vllm_model_config = VllmModelConfig(
|
||||
model=self.server_args.model_path,
|
||||
@@ -220,7 +226,7 @@ class ModelRunner:
|
||||
self.model = get_model(
|
||||
model_config=self.vllm_model_config,
|
||||
load_config=self.load_config,
|
||||
device_config=self.device_config,
|
||||
device_config=DeviceConfig(self.device),
|
||||
parallel_config=None,
|
||||
scheduler_config=None,
|
||||
lora_config=None,
|
||||
@@ -240,7 +246,7 @@ class ModelRunner:
|
||||
f"Load weight end. "
|
||||
f"type={type(self.model).__name__}, "
|
||||
f"dtype={self.dtype}, "
|
||||
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
||||
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
)
|
||||
|
||||
def update_weights(self, model_path: str, load_format: str):
|
||||
@@ -254,10 +260,10 @@ class ModelRunner:
|
||||
|
||||
logger.info(
|
||||
f"Update weights begin. "
|
||||
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
||||
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
)
|
||||
|
||||
target_device = torch.device(self.device_config.device)
|
||||
target_device = torch.device(self.device)
|
||||
|
||||
try:
|
||||
# TODO: Use a better method to check this
|
||||
@@ -343,7 +349,7 @@ class ModelRunner:
|
||||
|
||||
def profile_max_num_token(self, total_gpu_memory: int):
|
||||
available_gpu_memory = get_available_gpu_memory(
|
||||
self.gpu_id, distributed=self.tp_size > 1
|
||||
self.device, self.gpu_id, distributed=self.tp_size > 1
|
||||
)
|
||||
if (
|
||||
self.model_config.attention_arch == AttentionArch.MLA
|
||||
@@ -409,11 +415,10 @@ class ModelRunner:
|
||||
4096,
|
||||
)
|
||||
|
||||
device = "cuda"
|
||||
self.req_to_token_pool = ReqToTokenPool(
|
||||
size=max_num_reqs + 1,
|
||||
max_context_len=self.model_config.context_len + 4,
|
||||
device=device,
|
||||
device=self.device,
|
||||
)
|
||||
if (
|
||||
self.model_config.attention_arch == AttentionArch.MLA
|
||||
@@ -425,7 +430,7 @@ class ModelRunner:
|
||||
kv_lora_rank=self.model_config.kv_lora_rank,
|
||||
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
||||
layer_num=self.model_config.num_hidden_layers,
|
||||
device=device,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
self.token_to_kv_pool = MHATokenToKVPool(
|
||||
@@ -434,11 +439,11 @@ class ModelRunner:
|
||||
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
||||
head_dim=self.model_config.head_dim,
|
||||
layer_num=self.model_config.num_hidden_layers,
|
||||
device=device,
|
||||
device=self.device,
|
||||
)
|
||||
logger.info(
|
||||
f"Memory pool end. "
|
||||
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
||||
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
)
|
||||
|
||||
def init_cublas(self):
|
||||
|
||||
@@ -37,6 +37,9 @@ class SamplingBatchInfo:
|
||||
linear_penalties: torch.Tensor = None
|
||||
scaling_penalties: torch.Tensor = None
|
||||
|
||||
# Device
|
||||
device: str = "cuda"
|
||||
|
||||
@classmethod
|
||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||
reqs = batch.reqs
|
||||
@@ -62,6 +65,7 @@ class SamplingBatchInfo:
|
||||
min_ps=min_ps,
|
||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||
vocab_size=vocab_size,
|
||||
device=batch.input_ids.device,
|
||||
)
|
||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||
|
||||
@@ -75,7 +79,7 @@ class SamplingBatchInfo:
|
||||
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||
vocab_size=vocab_size,
|
||||
batch=batch,
|
||||
device="cuda",
|
||||
device=batch.input_ids.device,
|
||||
Penalizers={
|
||||
penaltylib.BatchedFrequencyPenalizer,
|
||||
penaltylib.BatchedMinNewTokensPenalizer,
|
||||
@@ -107,7 +111,7 @@ class SamplingBatchInfo:
|
||||
self.linear_penalties = torch.zeros(
|
||||
(bs, self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
device=self.device,
|
||||
)
|
||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||
|
||||
@@ -119,7 +123,10 @@ class SamplingBatchInfo:
|
||||
|
||||
if has_regex:
|
||||
self.vocab_mask = torch.zeros(
|
||||
len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda"
|
||||
len(self.temperatures),
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
)
|
||||
for i, regex_fsm in enumerate(self.regex_fsms):
|
||||
if regex_fsm is not None:
|
||||
@@ -144,7 +151,12 @@ class SamplingBatchInfo:
|
||||
|
||||
@staticmethod
|
||||
def merge_bias_tensor(
|
||||
lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
|
||||
lhs: torch.Tensor,
|
||||
rhs: torch.Tensor,
|
||||
bs1: int,
|
||||
bs2: int,
|
||||
device: str,
|
||||
default: int = 0,
|
||||
):
|
||||
# bias tensor can be None
|
||||
if lhs is not None or rhs is not None:
|
||||
@@ -155,9 +167,9 @@ class SamplingBatchInfo:
|
||||
shape, dtype = rhs.shape[1:], rhs.dtype
|
||||
with torch.dtype(dtype):
|
||||
if lhs is None:
|
||||
lhs = torch.empty((bs1, *shape), device="cuda").fill_(default)
|
||||
lhs = torch.empty((bs1, *shape), device=device).fill_(default)
|
||||
if rhs is None:
|
||||
rhs = torch.empty((bs2, *shape), device="cuda").fill_(default)
|
||||
rhs = torch.empty((bs2, *shape), device=device).fill_(default)
|
||||
return torch.cat([lhs, rhs])
|
||||
|
||||
return None
|
||||
@@ -176,5 +188,5 @@ class SamplingBatchInfo:
|
||||
setattr(self, item, torch.concat([self_val, other_val]))
|
||||
|
||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||
self.logit_bias, other.logit_bias, len(self), len(other)
|
||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||
)
|
||||
|
||||
@@ -36,6 +36,7 @@ class ServerArgs:
|
||||
skip_tokenizer_init: bool = False
|
||||
load_format: str = "auto"
|
||||
dtype: str = "auto"
|
||||
device: str = "cuda"
|
||||
kv_cache_dtype: str = "auto"
|
||||
trust_remote_code: bool = True
|
||||
context_length: Optional[int] = None
|
||||
@@ -237,6 +238,13 @@ class ServerArgs:
|
||||
'* "float" is shorthand for FP32 precision.\n'
|
||||
'* "float32" for FP32 precision.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
choices=["cuda"],
|
||||
help="The device type.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
|
||||
@@ -140,26 +140,41 @@ def calculate_time(show=False, min_cost_ms=0.0):
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_available_gpu_memory(gpu_id, distributed=False):
|
||||
def get_available_gpu_memory(device, gpu_id, distributed=False):
|
||||
"""
|
||||
Get available memory for cuda:gpu_id device.
|
||||
When distributed is True, the available memory is the minimum available memory of all GPUs.
|
||||
"""
|
||||
num_gpus = torch.cuda.device_count()
|
||||
assert gpu_id < num_gpus
|
||||
if device == "cuda":
|
||||
num_gpus = torch.cuda.device_count()
|
||||
assert gpu_id < num_gpus
|
||||
|
||||
if torch.cuda.current_device() != gpu_id:
|
||||
print(
|
||||
f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
|
||||
"which may cause useless memory allocation for torch CUDA context.",
|
||||
)
|
||||
if torch.cuda.current_device() != gpu_id:
|
||||
print(
|
||||
f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
|
||||
"which may cause useless memory allocation for torch CUDA context.",
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
|
||||
torch.cuda.empty_cache()
|
||||
free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
|
||||
|
||||
elif device == "xpu":
|
||||
num_gpus = torch.xpu.device_count()
|
||||
assert gpu_id < num_gpus
|
||||
|
||||
if torch.xpu.current_device() != gpu_id:
|
||||
print(
|
||||
f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ",
|
||||
"which may cause useless memory allocation for torch XPU context.",
|
||||
)
|
||||
torch.xpu.empty_cache()
|
||||
used_memory = torch.xpu.memory_allocated()
|
||||
total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
|
||||
free_gpu_memory = total_gpu_memory - used_memory
|
||||
|
||||
if distributed:
|
||||
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
|
||||
torch.device("cuda", gpu_id)
|
||||
torch.device(device, gpu_id)
|
||||
)
|
||||
torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
|
||||
free_gpu_memory = tensor.item()
|
||||
|
||||
Reference in New Issue
Block a user