Revert "chore: update torch v2.5.1" (#2063)

This commit is contained in:
Lianmin Zheng
2024-11-17 15:29:38 -08:00
committed by GitHub
parent 3b878863f7
commit c1f401fc58
10 changed files with 37 additions and 174 deletions

View File

@@ -410,23 +410,37 @@ def monkey_patch_vllm_dummy_weight_loader():
Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.
"""
from vllm.config import VllmConfig
from vllm.model_executor.model_loader.loader import (
CacheConfig,
DeviceConfig,
DummyModelLoader,
LoRAConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
_initialize_model,
initialize_dummy_weights,
nn,
set_default_torch_dtype,
)
def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
with set_default_torch_dtype(vllm_config.model_config.dtype):
with torch.device(vllm_config.device_config.device):
def load_model(
self,
*,
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(
vllm_config.model_config,
model_config,
self.load_config,
vllm_config.lora_config,
vllm_config.cache_config,
lora_config,
cache_config,
)
for _, module in model.named_modules():
@@ -498,60 +512,6 @@ def maybe_set_triton_cache_manager() -> None:
os.environ["TRITON_CACHE_MANAGER"] = manager
def monkey_patch_vllm_model_config():
    """Install a custom ``_resolve_task`` on vllm's ``ModelConfig``.

    The replacement decides which tasks a model supports from the
    ``architectures`` and ``auto_map`` attributes of its Hugging Face config
    and is attached to ``ModelConfig`` with ``setattr``.
    """
    from typing import Dict, Set, Tuple, Union

    from transformers import PretrainedConfig
    from vllm.config import ModelConfig, TaskOption, _Task

    def _resolve_task(
        self,
        task_option: Union[TaskOption, _Task],
        hf_config: PretrainedConfig,
    ) -> Tuple[Set[_Task], _Task]:
        """Return ``(supported_tasks, selected_task)`` for ``hf_config``.

        Args:
            task_option: The task requested by the caller.
            hf_config: Config object inspected for ``architectures`` and
                ``auto_map``.

        Returns:
            The set of supported tasks and the selected task, which is
            ``task_option`` itself.

        Raises:
            ValueError: If ``task_option`` is not among the supported tasks.
        """
        # The attribute can exist but hold None, in which case the getattr
        # default is never used; normalise to an empty list so the
        # iteration below cannot fail with a TypeError.
        architectures = getattr(hf_config, "architectures", None) or []
        if isinstance(architectures, str):
            architectures = [architectures]

        non_generation_models = {
            "LlamaEmbeddingModel",
            "MistralModel",
            "LlamaForSequenceClassification",
            "LlamaForSequenceClassificationWithNormal_Weights",
            "InternLM2ForRewardModel",
        }
        is_generation = not any(
            arch in non_generation_models for arch in architectures
        )

        # Same None-vs-missing normalisation as for ``architectures``.
        auto_map = getattr(hf_config, "auto_map", None) or {}
        has_sequence_classification = any(
            "ForSequenceClassification" in v for v in auto_map.values()
        )

        task_support: Dict[_Task, bool] = {
            "generate": is_generation,
            "embedding": (not is_generation) or has_sequence_classification,
        }
        supported_tasks = {
            task for task, is_supported in task_support.items() if is_supported
        }

        if task_option not in supported_tasks:
            msg = (
                f"This model does not support the '{task_option}' task. "
                f"Supported tasks: {supported_tasks}"
            )
            raise ValueError(msg)

        return supported_tasks, task_option

    setattr(ModelConfig, "_resolve_task", _resolve_task)
class CustomCacheManager(FileCacheManager):
# Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
def __init__(self, key, override=False, dump=False):