init
This commit is contained in:
66
vllm/model_executor/neuron_model_loader.py
Normal file
66
vllm/model_executor/neuron_model_loader.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Utilities for selecting and loading models."""
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig, DeviceConfig
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
|
||||
TORCH_DTYPE_TO_NEURON_AMP = {
|
||||
"auto": "f32",
|
||||
"half": "f16",
|
||||
"float16": "f16",
|
||||
"bfloat16": "bf16",
|
||||
"float": "f32",
|
||||
"float32": "f32",
|
||||
torch.float16: "f16",
|
||||
torch.bfloat16: "bf16",
|
||||
torch.float32: "f32",
|
||||
}
|
||||
|
||||
|
||||
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
model_cls = ModelRegistry.load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
return model_cls
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
||||
|
||||
|
||||
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
|
||||
**kwargs) -> nn.Module:
|
||||
from transformers_neuronx.config import NeuronConfig, ContinuousBatchingConfig
|
||||
|
||||
parallel_config = kwargs.get("parallel_config")
|
||||
scheduler_config = kwargs.get("scheduler_config")
|
||||
|
||||
model_class = _get_model_architecture(model_config.hf_config)
|
||||
linear_method = None
|
||||
|
||||
# Create a model instance.
|
||||
model = model_class(model_config.hf_config, linear_method)
|
||||
|
||||
continuous_batching_config = ContinuousBatchingConfig(
|
||||
batch_size_for_shared_caches=scheduler_config.max_num_seqs)
|
||||
neuron_config = NeuronConfig(
|
||||
continuous_batching=continuous_batching_config)
|
||||
|
||||
# Load the weights from the cached or downloaded files.
|
||||
model.load_weights(
|
||||
model_config.model,
|
||||
model_config.download_dir,
|
||||
model_config.load_format,
|
||||
model_config.revision,
|
||||
tp_degree=parallel_config.neuron_tp_degree,
|
||||
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
|
||||
neuron_config=neuron_config,
|
||||
context_length_estimate=[scheduler_config.max_model_len],
|
||||
n_positions=[scheduler_config.max_model_len],
|
||||
batch_size=scheduler_config.max_num_seqs)
|
||||
|
||||
return model.eval()
|
||||
Reference in New Issue
Block a user