Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -282,61 +282,13 @@ class CompilerManager:
maybe_key += f"{compile_range.start}_{compile_range.end}"
maybe_key += f"_subgraph_{graph_index}"
with self.compile_context(compile_range):
    # There is a compilation time optimization here.
    #
    # If the (input metadata, graph, compiler config) are the same, then
    # we want to avoid compiling the same artifact again. If we didn't
    # do this optimization, the backend compilation (InductorAdaptor or
    # InductorStandaloneAdaptor)
    # is able to cache hit and produce an artifact faster if it was
    # already created, but it is still a duplicate artifact that
    # requires unnecessary work, e.g. disk IO.
    #
    # The optimization is: if the backend compilation cache hits,
    # then do an early return from the backend compilation and look up
    # which of the previous in-memory artifacts we created to reuse.
    #
    # We implemented this by monkey-patching torch (torch does not
    # easily expose the cache_key function), but in the future torch
    # should expose the cache_key function so that we can just call it
    # directly before invoking backend compilation.
    cache_key = None
    orig = torch._functorch._aot_autograd.autograd_cache.autograd_cache_key

    def autograd_cache_key(*args, **kwargs):
        result = orig(*args, **kwargs)
        if result is None:
            return None
        nonlocal cache_key
        cache_key = result[0]
        if cache_key in self.loaded_artifacts:
            raise StopCompiling()
        return result

    from unittest.mock import patch

    with (
        # Graphs that are isometric (different node names but same
        # structure) should be treated as the same.
        torch._functorch.config.patch(autograd_cache_normalize_inputs=True),
        patch(
            "torch._functorch._aot_autograd.autograd_cache.autograd_cache_key",
            autograd_cache_key,
        ),
    ):
        try:
            compiled_graph, handle = self.compiler.compile(
                graph,
                example_inputs,
                additional_inductor_config,
                compile_range,
                maybe_key,
            )
        except StopCompiling:
            assert cache_key is not None
            return self.loaded_artifacts[cache_key]
    if cache_key is not None and compiled_graph is not None:
        self.loaded_artifacts[cache_key] = compiled_graph
    compiled_graph, handle = self.compiler.compile(
        graph,
        example_inputs,
        additional_inductor_config,
        compile_range,
        maybe_key,
    )

    assert compiled_graph is not None, "Failed to compile the graph"
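The block removed above relied on a sentinel-exception trick: intercept torch's cache-key computation, and if the key is already known, abort the backend compilation and hand back the in-memory artifact. A minimal standalone sketch of that interception pattern, using only the standard library and hypothetical stand-ins (compute_key, expensive_compile) rather than the real torch internals:

from unittest.mock import patch

class StopCompiling(Exception):
    """Sentinel raised to short-circuit a compilation we already have."""

artifacts: dict[str, str] = {"key-1": "cached artifact"}

def compute_key(graph: str) -> str:
    # Hypothetical stand-in for torch's autograd_cache_key.
    return "key-1" if graph == "g1" else "key-2"

def expensive_compile(graph: str) -> str:
    # Hypothetical stand-in for the real backend compilation.
    return f"compiled({graph}) via {compute_key(graph)}"

def compile_with_reuse(graph: str) -> str:
    seen_key = None
    orig = compute_key

    def wrapped(g: str) -> str:
        nonlocal seen_key
        seen_key = orig(g)
        if seen_key in artifacts:
            raise StopCompiling()  # early return: artifact already in memory
        return seen_key

    with patch(f"{__name__}.compute_key", wrapped):
        try:
            result = expensive_compile(graph)
        except StopCompiling:
            return artifacts[seen_key]
    artifacts[seen_key] = result
    return result

print(compile_with_reuse("g1"))  # reuses the in-memory artifact
print(compile_with_reuse("g2"))  # compiles and stores a new one

The same shape appears above: the patched autograd_cache_key raises StopCompiling once it sees a key already present in self.loaded_artifacts, and the except branch returns the cached artifact instead of recompiling.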
@@ -497,7 +449,7 @@ def wrap_with_cudagraph_if_needed(
# it from the FULL cudagraph runtime mode, no matter whether it
# is wrapped around a full or piecewise fx graph.
return static_graph_wrapper_class(
    runnable=piecewise_backend,
    runnable=piecewise_backend.graph.forward,
    vllm_config=vllm_config,
    runtime_mode=CUDAGraphMode.PIECEWISE,
    cudagraph_options=CUDAGraphOptions(
@@ -780,7 +732,7 @@ class VllmBackend:
return standalone_compile_artifacts, sym_shape_indices_map, returns_tuple_map

def configure_post_pass(self) -> None:
    # self.pass_manager.configure(self.vllm_config)
    self.pass_manager.configure(self.vllm_config)

    # Post-grad custom passes are run using the post_grad_custom_post_pass
    # hook. If a pass for that hook exists, add it to the pass manager.
@@ -846,7 +798,7 @@ class VllmBackend:
    ),
)

def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any], **kwargs) -> Any:
    from .caching import (
        VllmSerializableFunction,
    )
@@ -988,7 +940,7 @@ class VllmBackend:
assert not self._called, "VllmBackend can only be called once"

self.graph = graph
self.configure_post_pass()
# self.configure_post_pass()

if self.compilation_config.use_inductor_graph_partition:
    # Let Inductor decide partitioning; avoid FX-level pre-splitting.
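The branch above only runs when use_inductor_graph_partition is enabled. A hedged sketch of turning the flag on from user code, assuming it is exposed as a CompilationConfig field (the attribute access in this hunk suggests it is, but the exact construction below is an assumption, not part of this commit):

from vllm import LLM
from vllm.config import CompilationConfig

# Assumption: use_inductor_graph_partition is a CompilationConfig field, as
# implied by self.compilation_config.use_inductor_graph_partition above.
llm = LLM(
    model="facebook/opt-125m",
    compilation_config=CompilationConfig(use_inductor_graph_partition=True),
)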
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import hashlib
import inspect
import os
@@ -144,6 +145,18 @@ class StandaloneCompiledArtifacts:
self.loaded_submodule_store = {}


@contextlib.contextmanager
def patch_pytree_map_over_slice():
    pytree._private_register_pytree_node(
        slice, lambda x: ([x.start, x.stop, x.step], None), lambda x, c: slice(*x)
    )

    try:
        yield
    finally:
        pytree._deregister_pytree_node(slice)


class VllmSerializableFunction(SerializableCallable):  # type: ignore[misc]
    """
    A wrapper around a compiled function by vllm. It will forward the tensor
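patch_pytree_map_over_slice temporarily registers slice as a pytree container so that GraphPickler can traverse slice objects appearing in the pickled state, then deregisters it on exit. A small sketch of the same idea using the public torch.utils._pytree helpers (the commit uses the private _private_register_pytree_node / _deregister_pytree_node pair so the registration can be undone; the public API below is assumed to be available in the targeted torch version):

import torch.utils._pytree as pytree

# Register slice as a pytree container: flatten to its three fields,
# rebuild with slice(*fields). Public-API variant of what the commit does.
pytree.register_pytree_node(
    slice,
    lambda s: ([s.start, s.stop, s.step], None),
    lambda fields, _ctx: slice(*fields),
)

# tree_map now descends into slices, e.g. shifting every integer bound by one.
shifted = pytree.tree_map(
    lambda x: x + 1 if isinstance(x, int) else x,
    {"window": slice(0, 8, 2)},
)
print(shifted)  # {'window': slice(1, 9, 3)}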
@@ -235,7 +248,10 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
    lambda inp: torch.empty_like(inp, device="meta"),
    state["example_inputs"],
)
with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
with (
    patch.object(GraphPickler, "reducer_override", _graph_reducer_override),
    patch_pytree_map_over_slice(),
):
    state["graph_module"] = GraphPickler.dumps(
        state["graph_module"], Options(ops_filter=None)
    )
@@ -261,7 +277,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
state = pickle.loads(data)
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
with patch_pytree_map_over_slice():
    state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
state["graph_module"].recompile()
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
@@ -184,6 +184,47 @@ def is_compile_cache_enabled(
)


def _patch_standalone_compile_atomic_save() -> None:
    """Backport of pytorch/pytorch#162432 for torch < 2.10.0.

    Patches CompiledArtifact.save() to use write_atomic for binary format,
    preventing corrupt cache files when multiple processes compile
    concurrently.
    """
    from torch._inductor.codecache import write_atomic
    from torch._inductor.standalone_compile import CompiledArtifact as cls

    if getattr(cls.save, "_vllm_patched", False):
        return

    original_save = cls.save

    def _save(
        self: Any, *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> None:
        if format != "binary":
            return original_save(self, path=path, format=format)
        from torch._dynamo.utils import dynamo_timed
        from torch._inductor.codecache import torch_key
        from torch.utils._appending_byte_serializer import BytesWriter

        with dynamo_timed("CompiledArtifact.save"):
            assert self._artifacts is not None
            artifact_bytes, cache_info = self._artifacts
            assert len(cache_info.aot_autograd_artifacts) == 1, cache_info
            key = cache_info.aot_autograd_artifacts[0]
            assert not os.path.isdir(path)
            writer = BytesWriter()
            writer.write_bytes(torch_key())
            writer.write_str(key)
            writer.write_bytes(artifact_bytes)
            write_atomic(path, writer.to_bytes())

    _save._vllm_patched = True  # type: ignore[attr-defined]
    cls.save = _save  # type: ignore[assignment]
    logger.debug("Patched %s.save for atomic writes (torch < 2.10)", cls.__name__)


class InductorStandaloneAdaptor(CompilerInterface):
    """
    The adaptor for the Inductor compiler.
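The backport matters because a plain write can leave a half-written artifact on disk if two workers race on the same cache path; write_atomic sidesteps this by writing to a temporary file and renaming it into place. A stripped-down illustration of that write-then-rename pattern using only the standard library (this shows the general technique, not torch's write_atomic itself):

import os
import tempfile

def atomic_write(path: str, data: bytes) -> None:
    # Write to a unique temp file in the same directory, then rename.
    # os.replace is atomic on POSIX and Windows, so readers either see
    # the old complete file or the new complete file, never a torn one.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        os.replace(tmp_path, path)
    except BaseException:
        os.unlink(tmp_path)
        raise

atomic_write("artifact.bin", b"compiled artifact bytes")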
@@ -197,6 +238,8 @@ class InductorStandaloneAdaptor(CompilerInterface):
name = "inductor_standalone"

def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
    if not is_torch_equal_or_newer("2.10.0"):
        _patch_standalone_compile_atomic_save()
    self.save_format = save_format

def compute_hash(self, vllm_config: VllmConfig) -> str:
@@ -224,7 +267,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
if compiler_config is not None:
    current_config.update(compiler_config)
set_inductor_config(current_config, compile_range)
set_functorch_config()
# set_functorch_config()

if compile_range.is_single_size():
    dynamic_shapes = "from_example_inputs"
@@ -325,6 +368,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
    path=path, format=self.save_format
)
compilation_counter.num_compiled_artifacts_loaded += 1
from torch._inductor.compile_fx import graph_returns_tuple

returns_tuple = graph_returns_tuple(graph)
@@ -395,7 +439,7 @@ class InductorAdaptor(CompilerInterface):
current_config["fx_graph_remote_cache"] = False

set_inductor_config(current_config, compile_range)
set_functorch_config()
# set_functorch_config()

# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
@@ -29,6 +29,8 @@ class CompilationCounter:
num_cache_entries_updated: int = 0
# The number of standalone_compile compiled artifacts saved
num_compiled_artifacts_saved: int = 0
# The number of standalone_compile compiled artifacts loaded from cache
num_compiled_artifacts_loaded: int = 0
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
stock_torch_compile_count: int = 0
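The two new counters make cache behaviour observable from tests: a cold run should bump num_compiled_artifacts_saved, a warm run num_compiled_artifacts_loaded. A hedged sketch of such a check (the counter object and field names come from this diff; run_model and the surrounding scaffolding are hypothetical):

from vllm.compilation.counter import compilation_counter

def assert_artifacts_loaded_from_cache(run_model) -> None:
    # run_model is a hypothetical callable that builds and runs the engine.
    before = compilation_counter.num_compiled_artifacts_loaded
    run_model()
    after = compilation_counter.num_compiled_artifacts_loaded
    assert after > before, "expected standalone_compile artifacts to be reused"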
@@ -21,6 +21,7 @@ from vllm.model_executor.offloader.base import get_offloader
from vllm.platforms import current_platform
from vllm.utils.torch_utils import current_stream, weak_ref_tensors
from vllm.sequence import IntermediateTensors

logger = init_logger(__name__)
@@ -204,14 +205,14 @@ class CUDAGraphWrapper:
def unwrap(self) -> Callable[..., Any]:
    # in case we need to access the original runnable.
    return self.runnable


def weak_ref_tensors_with_intermediate(self, output):
    if isinstance(output, IntermediateTensors):
        intermediate_states = IntermediateTensors(
            tensors={key: weak_ref_tensors(value) for key, value in output.tensors.items()})
        return intermediate_states
    return weak_ref_tensors(output)


def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
    forward_context = get_forward_context()
    batch_descriptor = forward_context.batch_descriptor
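weak_ref_tensors_with_intermediate extends the existing weak-ref handling to IntermediateTensors, the per-key tensor container passed between pipeline-parallel stages, so a captured CUDA graph does not hold strong references to those outputs. A rough sketch of the same dispatch as a free function (imports mirror this diff; weak_ref_tensors relies on vLLM's compiled extension, so this is illustrative rather than a drop-in utility):

from vllm.sequence import IntermediateTensors
from vllm.utils.torch_utils import weak_ref_tensors

def weak_ref_output(output):
    # Containers get their tensors weak-referenced key by key; plain tensor
    # outputs are handled by weak_ref_tensors as before.
    if isinstance(output, IntermediateTensors):
        return IntermediateTensors(
            tensors={k: weak_ref_tensors(v) for k, v in output.tensors.items()}
        )
    return weak_ref_tensors(output)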
@@ -298,12 +299,10 @@ class CUDAGraphWrapper:
# the last graph in piecewise cudagraph mode, because
# the output of the last graph will not be used by
# any other cuda graph.
# output = weak_ref_tensors(output)
output = self.weak_ref_tensors_with_intermediate(output)

# here we always use weak ref for the output
# to save memory
# entry.output = weak_ref_tensors(output)
entry.output = self.weak_ref_tensors_with_intermediate(output)
entry.cudagraph = cudagraph
@@ -53,7 +53,7 @@ class GEMMReduceScatterPattern(BasePattern):
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
    mul,
    mm_weight,
    "avg",
    "sum",
    scatter_dim=0,
    group_name=self.tp.device_group.group_name,
)
@@ -150,7 +150,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
mat2,
scale_a,
scale_b,
"avg",
"sum",
scatter_dim,  # orig_scatter_dim
scatter_dim,  # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
@@ -285,7 +285,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
mat2,
scale_a,
scale_b,
"avg",
"sum",
scatter_dim,  # orig_scatter_dim
scatter_dim,  # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
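All three patterns above switch the fused kernel's reduction from "avg" to "sum". In unfused terms the matched computation is a local matmul followed by a reduce-scatter over the tensor-parallel group; a hedged reference sketch of that unfused form with torch.distributed (requires an initialized process group; names are illustrative):

import torch
import torch.distributed as dist

def matmul_reduce_scatter_sum(a: torch.Tensor, b: torch.Tensor, group) -> torch.Tensor:
    # Unfused reference for the pattern replaced above: each rank computes a
    # partial matmul, then the results are summed and scattered row-wise
    # (scatter_dim=0), matching the "sum" reduce op used by the fused kernel.
    partial = a @ b
    world_size = dist.get_world_size(group)
    out = partial.new_empty(partial.shape[0] // world_size, partial.shape[1])
    dist.reduce_scatter_tensor(out, partial, op=dist.ReduceOp.SUM, group=group)
    return out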
@@ -5,7 +5,6 @@ import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload

import vllm.model_executor.layers.quantization.utils.fp8_utils  # noqa: F401
from vllm._aiter_ops import rocm_aiter_ops
@@ -15,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    QuantKey,
    ScaleDesc,
    kFp8Dynamic128Sym,
)
from vllm.platforms import current_platform
@@ -312,7 +312,9 @@ class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
    self.matched_count = self.patterns.apply(graph)
    logger.debug("Replaced %s patterns", self.matched_count)
    logger.debug(
        "%s Replaced %s patterns", self.__class__.__name__, self.matched_count
    )

def uuid(self) -> str:
    fusion_patterns = [
@@ -332,9 +334,11 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()

def __init__(self, quant_op: OpOverload) -> None:
def __init__(self) -> None:
    self.silu_and_mul_matcher = MatcherSiluAndMul()
    self.quant_op = quant_op
    self.quant_matcher = MatcherQuantFP8(
        quant_key=kFp8Dynamic128Sym, match_rocm_aiter=True
    )

def get_inputs(self) -> list[torch.Tensor]:
    return [
@@ -346,7 +350,7 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
    input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    at1 = self.silu_and_mul_matcher(input)
    at2 = self.quant_op(at1, 128)
    at2 = self.quant_matcher(at1)
    return at2[0], at2[1]

def replacement(
@@ -370,11 +374,6 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""

AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default

QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]

@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
    super().__init__(config)
@@ -383,8 +382,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
    pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)

for quant_op in self.QUANT_OPS:
    AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
AiterSiluMulFp8GroupQuantPattern().register(self.patterns)

self.dump_patterns(config, self.patterns)
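With the quant_op parameter gone, the pattern builds its own MatcherQuantFP8 and the pass registers a single instance. For background, these fusion passes sit on top of Inductor's pattern matcher: a pattern/replacement pair is traced from example inputs and registered into a PatternMatcherPass, which the pass later applies to the graph (see self.patterns.apply(graph) earlier in this diff). A generic, hedged sketch of that registration API, deliberately unrelated to the AITER ops themselves:

import torch
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)

def pattern(x: torch.Tensor) -> torch.Tensor:
    # Subgraph to search for: two identical relu calls added together.
    return torch.relu(x) + torch.relu(x)

def replacement(x: torch.Tensor) -> torch.Tensor:
    # What matched subgraphs are rewritten to.
    return torch.relu(x) * 2

patterns = PatternMatcherPass()
# Trace pattern/replacement on example inputs and register the pair;
# a pass that later calls patterns.apply(graph) performs the rewrite.
register_replacement(pattern, replacement, [torch.empty(8)], fwd_only, patterns)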
@@ -18,7 +18,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
from vllm.platforms import current_platform

from ..inductor_pass import enable_fake_mode
from ..utility.noop_elimination import NoOpEliminationPass
@@ -215,9 +214,6 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
)


FP8_DTYPE = current_platform.fp8_dtype()


class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    def __init__(
        self,
@@ -37,6 +37,14 @@ class FixFunctionalizationPass(VllmInductorPass):
self.nodes_to_remove: list[torch.fx.Node] = []
count = 0

rope_targets = [torch.ops._C.rotary_embedding.default]

if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
    rope_targets.append(
        torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default
    )

for node in graph.nodes:
    if not is_func(node, auto_functionalized):
        continue  # Avoid deep if-elif nesting
@@ -44,7 +52,7 @@ class FixFunctionalizationPass(VllmInductorPass):
kwargs = node.kwargs
at_target = node.args[0]

if at_target == torch.ops._C.rotary_embedding.default:
if at_target in rope_targets:
    query = kwargs["query"]
    key = kwargs["key"]
    getitem_nodes = self.getitem_users(node)
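The membership test above lets a single defunctionalization branch cover both rope ops. For context, the pass scans the fx graph for auto_functionalized higher-order nodes and inspects the wrapped op in args[0]; a minimal standalone sketch of that scan (is_func in the diff is vLLM's shorthand for the same node check):

import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized

def find_auto_functionalized(gm: torch.fx.GraphModule, targets) -> list[torch.fx.Node]:
    # Collect auto_functionalized(op, ...) nodes whose wrapped op is one of
    # the given targets, the same shape of check FixFunctionalizationPass does.
    matches = []
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is auto_functionalized:
            wrapped_op = node.args[0]
            if wrapped_op in targets:
                matches.append(node)
    return matches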
@@ -298,18 +298,18 @@ class PiecewiseBackend:
    else list(args)
)

with (
    torch._functorch.config.patch("bundled_autograd_cache", True),
):
    range_entry.runnable = self.vllm_backend.compiler_manager.compile(
        self.graph,
        args_list,
        self.vllm_backend.inductor_config,
        self.compilation_config,
        compile_range=range_entry.compile_range,
        graph_index=self.piecewise_compile_index,
        num_graphs=self.total_piecewise_compiles,
    )
# with (
#     torch._functorch.config.patch("bundled_autograd_cache", True),
# ):
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
    self.graph,
    args_list,
    self.vllm_backend.inductor_config,
    self.compilation_config,
    compile_range=range_entry.compile_range,
    graph_index=self.piecewise_compile_index,
    num_graphs=self.total_piecewise_compiles,
)

range_entry.compiled = True
self.to_be_compiled_ranges.remove(range_entry.compile_range)
@@ -319,3 +319,52 @@ class TorchCompileWithNoGuardsWrapper:
        yield
    finally:
        self.__class__.forward.__code__ = original


def reset_compile_wrapper(model: torch.nn.Module) -> None:
    """
    Clean up compiled model and captured CUDA graphs for elastic EP.
    """
    if not isinstance(model, TorchCompileWithNoGuardsWrapper) and hasattr(
        model, "model"
    ):
        model = model.model
    if not isinstance(model, TorchCompileWithNoGuardsWrapper):
        return
    # model.do_not_compile is set by the @support_torch_compile decorator
    if hasattr(model, "do_not_compile") and model.do_not_compile:
        return
    from vllm.compilation.counter import compilation_counter

    # reset the compilation counter
    compilation_counter.num_models_seen = 0
    compilation_counter.num_graphs_seen = 0
    compilation_counter.num_piecewise_graphs_seen = 0
    compilation_counter.num_piecewise_capturable_graphs_seen = 0
    compilation_counter.num_backend_compilations = 0
    compilation_counter.num_gpu_runner_capture_triggers = 0
    compilation_counter.num_cudagraph_captured = 0
    compilation_counter.num_inductor_compiles = 0
    compilation_counter.num_eager_compiles = 0
    compilation_counter.num_cache_entries_updated = 0
    compilation_counter.num_compiled_artifacts_saved = 0
    compilation_counter.stock_torch_compile_count = 0

    # Clear the AOT compiled function so the model is forced to
    # recompile on the next call. Without this, decorators.py
    # __call__ uses the stale aot_compiled_fn whose torchinductor
    # kernels have old parameters (expert_map size for example)
    # baked in as compile-time constants.
    if hasattr(model, "aot_compiled_fn"):
        model.aot_compiled_fn = None
    if hasattr(model, "was_aot_compile_fn_loaded_from_disk"):
        model.was_aot_compile_fn_loaded_from_disk = False

    # Reset the cache_dir so VllmBackend recomputes the hash
    # (data_parallel_size changed, so the config hash differs).
    compilation_config = model.vllm_config.compilation_config
    compilation_config.cache_dir = ""
    compilation_config.local_cache_dir = ""

    model.__class__.forward.__code__ = model.original_code_object()
    TorchCompileWithNoGuardsWrapper.__init__(model)
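reset_compile_wrapper is the hook for elastic expert parallelism: after the EP/DP topology changes, stale counters, the cached aot_compiled_fn, and the compilation cache dir must all be dropped so the next forward recompiles against the new parameters. A hedged sketch of where a caller might invoke it (the import path and the scale-up helpers are assumptions; only reset_compile_wrapper itself comes from this commit):

from vllm.compilation.wrapper import reset_compile_wrapper  # assumed module path

def rescale_experts(model_runner, new_expert_map) -> None:
    # Hypothetical elastic-EP flow: swap in the new expert mapping, then force
    # the wrapper to recompile so kernels stop using the old expert_map size
    # that was baked in as a compile-time constant.
    model_runner.update_expert_map(new_expert_map)  # hypothetical helper
    reset_compile_wrapper(model_runner.model)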