feat: support torchair graph mode in v1 engine (#789)
### What this PR does / why we need it?
Support torchair graph mode with the v1 engine.

---------

Signed-off-by: boying <897013703@qq.com>
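Graph mode is opt-in via `additional_config`, which the model runner reads off `vllm_config.additional_config` (see the hunks below). A minimal usage sketch, assuming `additional_config` is forwarded through the engine arguments; the model name is illustrative, and `enable_graph_mode` only takes effect for MLA models:

```python
from vllm import LLM

# Hypothetical invocation; keys match what NPUModelRunner reads below.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # illustrative MLA model
    additional_config={
        "enable_graph_mode": True,       # torchair graph path (MLA-only)
        "use_cached_npu_graph": False,   # True: reuse a cached compiled graph
    },
)
```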
@@ -63,6 +63,8 @@ if TYPE_CHECKING:
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")
 
+import vllm.envs as envs
+
 
 @dataclass
 class GraphCaptureContext:
@@ -117,6 +119,12 @@ class NPUModelRunner:
         self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
         self.max_num_reqs = self.scheduler_config.max_num_seqs
 
+        self.graph_block_tables = np.zeros(
+            (self.vllm_config.scheduler_config.max_num_seqs,
+             (self.model_config.max_model_len + self.block_size - 1) //
+             self.block_size),
+            dtype=np.int32)
+
         # Model-related.
         self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
             vllm_config.parallel_config, LayerBlockType.attention)
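The pre-allocated graph block table covers the worst case: one row per request slot and, per row, enough block IDs to map `max_model_len` tokens at `block_size` granularity (the `(a + b - 1) // b` expression is ceil division). A standalone sketch of the arithmetic, with illustrative numbers:

```python
import numpy as np

# Illustrative values, not taken from any real config.
max_num_seqs = 64      # request slots the decode graph is sized for
max_model_len = 4096   # longest supported sequence
block_size = 128       # KV-cache block granularity

# Ceil division: blocks needed to cover max_model_len tokens.
max_blocks_per_req = (max_model_len + block_size - 1) // block_size

graph_block_tables = np.zeros((max_num_seqs, max_blocks_per_req), dtype=np.int32)
print(graph_block_tables.shape)  # (64, 32)
```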
@@ -307,6 +315,15 @@ class NPUModelRunner:
             self.attn_mask_len, self.dtype)
 
         self.sampler = Sampler()
+        self.enable_torchair_graph_mode = False
+        self.use_cached_npu_graph = False
+        additional_config = vllm_config.additional_config
+        if additional_config:
+            self.enable_torchair_graph_mode = additional_config.get(
+                "enable_graph_mode",
+                False) and self.vllm_config.model_config.use_mla
+            self.use_cached_npu_graph = additional_config.get(
+                "use_cached_npu_graph", False)
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -563,11 +580,19 @@ class NPUModelRunner:
         self.attn_mask = attn_mask
         self.attn_state = attn_state  # type: ignore
 
+        extra_builder_kwargs = {}
+
+        # Add graph_pad_size here
+        if self.enable_torchair_graph_mode:
+            graph_pad_size = self.scheduler_config.max_num_seqs - len(seq_lens)
+            extra_builder_kwargs['graph_pad_size'] = graph_pad_size
+
         attn_metadata = self.attn_metadata_builder.build(  # type: ignore
             num_reqs=num_reqs,
             num_actual_tokens=total_num_scheduled_tokens,
             max_query_len=max_num_scheduled_tokens,
             common_prefix_len=None,
+            **extra_builder_kwargs,
         )
 
         # Prepare input_ids
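`graph_pad_size` is the gap between the fixed batch size the graph expects (`max_num_seqs`) and the number of live requests; the decode-only path in the next hunk appends that many dummy token slots. A worked example with illustrative numbers:

```python
# Illustrative: a decode graph sized for max_num_seqs request slots.
max_num_seqs = 64
seq_lens = [17, 9, 233]  # three running requests (lengths arbitrary)

graph_pad_size = max_num_seqs - len(seq_lens)
print(graph_pad_size)  # 61 padding slots to fill the fixed-size batch
```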
@@ -582,15 +607,45 @@ class NPUModelRunner:
             self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
         input_ids = self.input_ids[:total_num_scheduled_tokens]
 
+        if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            padding = torch.zeros(graph_pad_size,
+                                  dtype=input_ids.dtype,
+                                  device=input_ids.device)
+            input_ids = torch.cat([input_ids, padding])
+            positions = torch.cat([positions, padding])
+
         # Run forward pass
         with set_forward_context(attn_metadata, self.vllm_config):
-            assert self.model is not None
-            hidden_states = self.model(
-                input_ids=input_ids,
-                positions=positions,
-                intermediate_tensors=intermediate_tensors,
-                inputs_embeds=None,
-            )
+            model_kwargs = {}
+            if self.enable_torchair_graph_mode:
+                model_kwargs["kv_caches"] = self.kv_caches
+                model_kwargs["attn_metadata"] = attn_metadata
+            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                torch._dynamo.mark_static(input_ids)
+                torch._dynamo.mark_static(positions)
+                torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                torch._dynamo.mark_static(attn_metadata.decode.input_positions)
+                torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                for kv in self.kv_caches:
+                    if isinstance(kv, tuple):
+                        torch._dynamo.mark_static(kv[0])
+                        torch._dynamo.mark_static(kv[1])
+                hidden_states = self.compile_model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=None,
+                    **model_kwargs,
+                )
+            else:
+                assert self.model is not None
+                hidden_states = self.model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=None,
+                    **model_kwargs,
+                )
 
         return hidden_states[sample_indices]
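The `mark_static` calls above tell Dynamo to treat the padded inputs, block table, slot mapping, and KV tensors as fixed-shape, so the decode-only path compiles one graph instead of re-specializing every step. A self-contained sketch of the same pattern against plain `torch.compile` (the toy model and shapes are illustrative; the default backend stands in for the npu one):

```python
import torch

model = torch.nn.Linear(32, 32)          # toy stand-in for the real model
compiled = torch.compile(model, dynamic=True)

x = torch.zeros(64, 32)                  # batch padded to a fixed size, as above
torch._dynamo.mark_static(x)             # pin x's shape: one graph, no
out = compiled(x)                        # per-step shape re-specialization
```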
@@ -879,6 +934,31 @@ class NPUModelRunner:
         logger.info("Loading model weights took %.4f GB",
                     m.consumed_memory / float(2**30))
 
+        # Adapt torch.compile to the npu_backend.
+        if self.enable_torchair_graph_mode:
+            import torchair  # type: ignore
+            from torchair import patch_for_hcom  # type: ignore
+
+            patch_for_hcom()
+            config = torchair.CompilerConfig()
+            config.experimental_config.frozen_parameter = True
+            config.experimental_config.tiling_schedule_optimize = True
+            torch.npu.set_compile_mode(jit_compile=False)
+            if not self.use_cached_npu_graph:
+                npu_backend = torchair.get_npu_backend(compiler_config=config)
+                self.compile_model = torch.compile(
+                    self.model,
+                    dynamic=True,
+                    fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                    backend=npu_backend)
+            else:
+                self.compile_model = torchair.inference.cache_compile(
+                    self.model.forward,
+                    dynamic=True,
+                    fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                    config=config,
+                    ge_cache=False)
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
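The hunk above wires up two compilation paths: `torchair.get_npu_backend(...)` plugs the Ascend graph engine into vanilla `torch.compile`, while `torchair.inference.cache_compile(...)` additionally caches the compiled graph so later launches skip recompilation. A stripped-down sketch of the non-cached path in isolation (NPU-only; assumes `torch_npu` and `torchair` are installed, and the toy model is illustrative):

```python
import torch
import torch_npu  # noqa: F401  registers the "npu" device
import torchair   # bridge between torch.compile and the Ascend graph engine

model = torch.nn.Linear(32, 32).npu()  # illustrative toy model

config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True  # inference weights are fixed

npu_backend = torchair.get_npu_backend(compiler_config=config)
compiled = torch.compile(model, dynamic=True, backend=npu_backend)
out = compiled(torch.zeros(4, 32).npu())
```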
@@ -909,10 +989,29 @@ class NPUModelRunner:
                 num_blocks, kv_cache_spec.block_size,
                 kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
             dtype = kv_cache_spec.dtype
-            kv_caches[layer_name] = torch.zeros(kv_cache_shape,
-                                                dtype=dtype,
-                                                device=self.device)
-            torch_npu.npu_format_cast(kv_caches[layer_name], 2)
+            if self.enable_torchair_graph_mode:
+                layer_kv_cache_nope = torch.zeros(
+                    kv_cache_shape[:-1] +
+                    (self.model_config.hf_text_config.kv_lora_rank, ),
+                    dtype=self.dtype,
+                    pin_memory=True,
+                    device=self.device)
+                layer_kv_cache_pe = torch.zeros(
+                    kv_cache_shape[:-1] +
+                    (self.model_config.hf_text_config.qk_rope_head_dim, ),
+                    dtype=self.dtype,
+                    pin_memory=True,
+                    device=self.device)
+                kv_caches[layer_name] = (layer_kv_cache_nope,
+                                         layer_kv_cache_pe)
+                torch_npu.npu_format_cast(kv_caches[layer_name][0], 2)
+                torch_npu.npu_format_cast(kv_caches[layer_name][1], 2)
+            else:
+                kv_caches[layer_name] = torch.zeros(kv_cache_shape,
+                                                    dtype=dtype,
+                                                    device=self.device)
+                torch_npu.npu_format_cast(kv_caches[layer_name], 2)
         else:
             # TODO: add new branches when introducing more types of
             # KV cache specs.
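In the MLA branch above, the trailing `head_size` dimension of the cache is replaced by two separate tensors: a `kv_lora_rank`-wide compressed ("nope") part and a `qk_rope_head_dim`-wide rotary part. A shape-only sketch with illustrative DeepSeek-V2-style dimensions (512 + 64 = 576):

```python
import torch

# Illustrative MLA dimensions, not read from a real config.
kv_cache_shape = (1024, 128, 1, 576)  # (num_blocks, block_size, num_kv_heads, head_size)
kv_lora_rank = 512
qk_rope_head_dim = 64                 # 512 + 64 == 576 == head_size

nope = torch.zeros(kv_cache_shape[:-1] + (kv_lora_rank, ), dtype=torch.bfloat16)
pe = torch.zeros(kv_cache_shape[:-1] + (qk_rope_head_dim, ), dtype=torch.bfloat16)
print(nope.shape, pe.shape)  # [1024, 128, 1, 512] and [1024, 128, 1, 64]
```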