mirror of
https://github.com/handsomezhuzhu/vllm-npu-plugin.git
synced 2026-02-20 19:50:15 +00:00
大改
This commit is contained in:
557
vllm_npu/torchair/torchair_model_runner.py
Normal file
557
vllm_npu/torchair/torchair_model_runner.py
Normal file
@@ -0,0 +1,557 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
|
||||
# isort: skip_file
|
||||
|
||||
import math
|
||||
import types
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import logger
|
||||
|
||||
import vllm_npu.envs as envs_ascend
|
||||
from vllm_npu.ascend_config import get_ascend_config
|
||||
from vllm_npu.platform import NPUPlatform
|
||||
from vllm_npu.torchair.utils import (
|
||||
TORCHAIR_CACHE_DIR, TorchairCommonAttentionMetadata,
|
||||
check_torchair_cache_exist, converting_weight_acl_format,
|
||||
register_torchair_model, torchair_ops_patch,
|
||||
torchair_quant_method_register, write_kv_cache_bytes_to_file)
|
||||
from vllm_npu.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
is_310p, get_ascend_soc_version,
|
||||
AscendSocVersion)
|
||||
from vllm_npu.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
|
||||
class NPUTorchairModelRunner(NPUModelRunner):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||
self.ascend_config = get_ascend_config()
|
||||
self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp
|
||||
super().__init__(vllm_config, device)
|
||||
if self.speculative_config:
|
||||
self.actual_seq_lengths_q = list(
|
||||
range(self.decode_token_per_req, self.max_num_tokens + 1,
|
||||
self.decode_token_per_req))
|
||||
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
|
||||
None, None, vllm_config, device)
|
||||
|
||||
register_torchair_model()
|
||||
torchair_ops_patch()
|
||||
torchair_quant_method_register()
|
||||
if self.enable_shared_expert_dp:
|
||||
return
|
||||
self.new_kv_cache_bytes = -1
|
||||
self.torchair_compiled_model = None # type: ignore
|
||||
self.torchair_compiled_models = {} # type: ignore
|
||||
self.use_cached_npu_graph = self.ascend_config.torchair_graph_config.use_cached_graph
|
||||
self.use_cached_kv_cache_bytes = self.ascend_config.torchair_graph_config.use_cached_kv_cache_bytes
|
||||
self.torchair_graph_batch_sizes = self.ascend_config.torchair_graph_config.graph_batch_sizes
|
||||
if self.ascend_config.torchair_graph_config.graph_batch_sizes_init:
|
||||
self.init_torchair_graph_batch_sizes()
|
||||
|
||||
self.update_torchair_graph_batch_sizes()
|
||||
|
||||
torch._dynamo.cache_size.config.cache_size_limit += len(
|
||||
self.torchair_graph_batch_sizes)
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
torch._logging.set_logs(
|
||||
recompiles=envs_ascend.vllm_npu_TRACE_RECOMPILES)
|
||||
|
||||
self._check_batch_sizes_consistency()
|
||||
|
||||
def _may_pad_kv_consumer_num_seq(self):
|
||||
# pd disaggregation scenario need redundant_batch_sizes to avoid each batch's seq_len exceed 16 tokens
|
||||
# self.max_num_reqs here is greater than the actual maximum request number
|
||||
if self.decode_token_per_req > 1 and self.is_kv_consumer:
|
||||
# applied only when speculative decoding is active
|
||||
FIA_SEQ_LEN_LIMIT = 16
|
||||
new_max_num_reqs = self.max_num_reqs + math.ceil(
|
||||
self.max_num_reqs / FIA_SEQ_LEN_LIMIT) + math.ceil(
|
||||
(self.max_num_reqs * self.decode_token_per_req) /
|
||||
(FIA_SEQ_LEN_LIMIT**2))
|
||||
if self.max_num_reqs < new_max_num_reqs:
|
||||
logger.warning(
|
||||
f"max_num_reqs is updated to {new_max_num_reqs}")
|
||||
self.max_num_reqs = new_max_num_reqs
|
||||
|
||||
def _init_mc2_tokens_capacity(self):
|
||||
# NOTE: To be clear, we need to make sure that during graph capture, the number of
|
||||
# tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
|
||||
# the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
|
||||
max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
|
||||
tp_size = self.parallel_config.tensor_parallel_size
|
||||
max_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
|
||||
max_num_tokens, tp_size)
|
||||
self.mc2_tokens_capacity = max_graph_batch_size
|
||||
|
||||
if get_ascend_soc_version(
|
||||
) == AscendSocVersion.A3 and self.mc2_tokens_capacity > 512:
|
||||
logger.error(
|
||||
f"A3: the max number of tokens must smaller then 512, but now is {self.mc2_tokens_capacity}"
|
||||
)
|
||||
if get_ascend_soc_version(
|
||||
) == AscendSocVersion.A2 and self.mc2_tokens_capacity > 256:
|
||||
logger.error(
|
||||
f"A2: the max number of tokens must smaller then 256, but now is {self.mc2_tokens_capacity}"
|
||||
)
|
||||
|
||||
def _sync_metadata_across_dp(
|
||||
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
|
||||
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
|
||||
"""Override from NPUModelRunner to pad num_tokens"""
|
||||
if self.enable_shared_expert_dp:
|
||||
# Padding is not required for shared_expert_dp cases in eager mode.
|
||||
return num_tokens, None, with_prefill, enable_dbo
|
||||
if self.dp_size == 1:
|
||||
if not with_prefill:
|
||||
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||
num_tokens)
|
||||
return maybe_padded_num_tokens, None, with_prefill, enable_dbo
|
||||
return num_tokens, None, with_prefill, enable_dbo
|
||||
|
||||
num_tokens_across_dp = torch.zeros(self.dp_size + 2,
|
||||
dtype=torch.int32,
|
||||
device="npu")
|
||||
num_tokens_across_dp[self.dp_rank] = num_tokens
|
||||
num_tokens_across_dp[-2] = int(with_prefill)
|
||||
num_tokens_across_dp[-1] = int(not enable_dbo)
|
||||
dist.all_reduce(num_tokens_across_dp,
|
||||
group=get_dp_group().device_group)
|
||||
with_prefill = bool(num_tokens_across_dp[-2])
|
||||
enable_dbo = not bool(num_tokens_across_dp[-1])
|
||||
num_tokens_across_dp = num_tokens_across_dp[:-2]
|
||||
|
||||
if not with_prefill:
|
||||
max_num_token = num_tokens_across_dp.max().item()
|
||||
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||
max_num_token)
|
||||
num_tokens_across_dp = torch.full((self.dp_size, ),
|
||||
maybe_padded_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device="npu")
|
||||
else:
|
||||
maybe_padded_num_tokens = num_tokens
|
||||
|
||||
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
|
||||
|
||||
def _build_dummy_attn_metadata(
|
||||
self,
|
||||
with_prefill: bool,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
max_query_len: int,
|
||||
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
|
||||
force_attention: bool = False,
|
||||
) -> Optional[dict[str, Any]]:
|
||||
# NOTE: If torchair graph mode and not with_prefill,
|
||||
# we can't skip_attn, it will cause graph recompile.
|
||||
if with_prefill or self.enable_shared_expert_dp:
|
||||
attn_metadata = super()._build_dummy_attn_metadata(
|
||||
with_prefill, num_reqs, num_tokens, max_query_len,
|
||||
aclgraph_runtime_mode, force_attention)
|
||||
else:
|
||||
common_attn_metadata = TorchairCommonAttentionMetadata(
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=1,
|
||||
actual_seq_lengths_q=self.actual_seq_lengths_q,
|
||||
attn_mask=self.attn_mask,
|
||||
spec_attn_mask=self.spec_attn_mask,
|
||||
decode_token_per_req=self.decode_token_per_req,
|
||||
)
|
||||
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
|
||||
common_attn_metadata)
|
||||
return attn_metadata
|
||||
|
||||
def _generate_dummy_run_hidden_states(self, with_prefill,
|
||||
is_torchair_compile, input_ids,
|
||||
positions, attn_metadata, num_tokens,
|
||||
intermediate_tensors, inputs_embeds):
|
||||
if with_prefill or self.enable_shared_expert_dp:
|
||||
if is_310p():
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
|
||||
hidden_states = super()._generate_dummy_run_hidden_states(
|
||||
with_prefill, is_torchair_compile, input_ids, positions,
|
||||
attn_metadata, num_tokens, intermediate_tensors, inputs_embeds)
|
||||
else:
|
||||
# Only mark static while compiling
|
||||
if is_torchair_compile:
|
||||
torch._dynamo.mark_static(input_ids)
|
||||
torch._dynamo.mark_static(positions)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.block_table)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
|
||||
torch._dynamo.mark_static(get_forward_context().mc2_mask)
|
||||
if hasattr(attn_metadata.decode, "sin"):
|
||||
torch._dynamo.mark_static(attn_metadata.decode.sin)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.cos)
|
||||
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
||||
if self.speculative_config:
|
||||
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
|
||||
for kv in self.kv_caches:
|
||||
assert isinstance(kv, tuple), "kv_cache must be a tuple"
|
||||
torch._dynamo.mark_static(kv[0])
|
||||
torch._dynamo.mark_static(kv[1])
|
||||
if is_310p():
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(num_tokens)
|
||||
model_kwargs = {}
|
||||
model_kwargs["kv_caches"] = self.kv_caches
|
||||
model_kwargs["attn_metadata"] = attn_metadata
|
||||
hidden_states = compiled_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=None,
|
||||
**model_kwargs,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def _convert_torch_format(self, kv_cache):
|
||||
if self.enable_shared_expert_dp:
|
||||
return super()._convert_torch_format(kv_cache)
|
||||
kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
|
||||
return kv_cache
|
||||
|
||||
def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
|
||||
# Trigger torchair graph capture for specific shapes.
|
||||
# Capture the large shapes first so that the smaller shapes
|
||||
# can reuse the memory pool allocated for the large shapes.
|
||||
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
|
||||
for _ in range(self.vllm_config.compilation_config.
|
||||
cudagraph_num_of_warmups):
|
||||
self._dummy_run(num_tokens, is_torchair_compile=True)
|
||||
self._dummy_run(num_tokens, is_torchair_compile=True)
|
||||
logger.info("Batchsize %d is compiled successfully: %d/%d.",
|
||||
num_tokens, idx + 1, len(torchair_graph_batch_sizes))
|
||||
|
||||
def _capture_model(self):
|
||||
"""Override from NPUModelRunner to use torchair graph capture."""
|
||||
if self.enable_shared_expert_dp:
|
||||
return super()._capture_model()
|
||||
# TODO(NeverRaR): Calling graph_capture(device=self.device) in
|
||||
# torchair graph capture can cause some issues, so now we just
|
||||
# temporarily split the codepath for the two different graph patterns.
|
||||
torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
|
||||
graph_num = len(torchair_graph_batch_sizes)
|
||||
|
||||
if self.use_cached_npu_graph and not check_torchair_cache_exist():
|
||||
# If caching is enabled but does not exist (either
|
||||
# use_cached_kv_cache_bytes is disabled or kv_cache_bytes are
|
||||
# different), we will compile the model twice. The first time is
|
||||
# used to generate the cache, and the second time is used to load the
|
||||
# cache to skip the overhead caused by Dynamo guard mechanism.
|
||||
logger.info(
|
||||
"Cache compilation for torchair graph is enabled. Now we compile graph to genetate"
|
||||
" torchair cache, this usually takes %.1f~%.1f mins.",
|
||||
0.5 * graph_num, 1.5 * graph_num)
|
||||
self._compile_torchair_graph(torchair_graph_batch_sizes)
|
||||
NPUPlatform.synchronize()
|
||||
# Note: We reset dynamo and reload the compiled torchair cached computation graph below
|
||||
# that was compiled above. This operation reduces graph launch time by 2-4ms and avoids
|
||||
# runtime errors caused by configuration mismatches in graph mode.
|
||||
torch._dynamo.reset()
|
||||
self.torchair_compiled_models.clear()
|
||||
if self.use_cached_npu_graph:
|
||||
logger.info(
|
||||
"Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
|
||||
0.3 * graph_num, 0.5 * graph_num)
|
||||
self._compile_torchair_graph(torchair_graph_batch_sizes)
|
||||
else:
|
||||
logger.info(
|
||||
"Capturing torchair graph, this usually takes %.1f~%.1f mins.",
|
||||
0.5 * graph_num, 1.5 * graph_num)
|
||||
self._compile_torchair_graph(torchair_graph_batch_sizes)
|
||||
|
||||
if self.use_cached_kv_cache_bytes and self.new_kv_cache_bytes > 0:
|
||||
write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
|
||||
self.new_kv_cache_bytes)
|
||||
|
||||
def _use_aclgraph(self) -> bool:
|
||||
if self.enable_shared_expert_dp:
|
||||
return super()._use_aclgraph()
|
||||
return False
|
||||
|
||||
def _check_batch_sizes_consistency(self) -> None:
|
||||
if not dist.is_initialized():
|
||||
return
|
||||
|
||||
local = torch.tensor(self.torchair_graph_batch_sizes,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
gathered_graph_batch_size = local.clone()
|
||||
dist.all_reduce(gathered_graph_batch_size,
|
||||
group=get_dp_group().cpu_group)
|
||||
expected = local * self.dp_size
|
||||
|
||||
if not torch.equal(gathered_graph_batch_size, expected):
|
||||
diff_idxs = (gathered_graph_batch_size != expected).nonzero(
|
||||
as_tuple=False).flatten().tolist()
|
||||
raise AssertionError(
|
||||
f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
|
||||
f"Local (rank {self.dp_rank}): {local.tolist()}\n"
|
||||
f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
|
||||
f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
|
||||
)
|
||||
|
||||
def _update_graph_pad_size(self, with_prefill, graph_pad_size):
|
||||
if with_prefill or self.enable_shared_expert_dp:
|
||||
super()._update_graph_pad_size(with_prefill, graph_pad_size)
|
||||
else:
|
||||
self.graph_pad_size = graph_pad_size
|
||||
|
||||
def _update_input_ids_and_positions(self, input_ids, positions,
|
||||
num_input_tokens, with_prefill,
|
||||
padded_num_tokens_across_dp):
|
||||
"""Override from NPUModelRunner to update input_ids and positions"""
|
||||
input_ids, positions = super()._update_input_ids_and_positions(
|
||||
input_ids, positions, num_input_tokens, with_prefill,
|
||||
padded_num_tokens_across_dp)
|
||||
|
||||
if with_prefill or self.enable_shared_expert_dp:
|
||||
return input_ids, positions
|
||||
else:
|
||||
input_ids = self.input_ids[:padded_num_tokens_across_dp]
|
||||
positions = self.positions[:padded_num_tokens_across_dp]
|
||||
return input_ids, positions
|
||||
|
||||
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
|
||||
padded_num_tokens_across_dp,
|
||||
input_ids, positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds):
|
||||
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
return super()._generate_process_reqs_hidden_states(
|
||||
attn_metadata, with_prefill, padded_num_tokens_across_dp,
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds)
|
||||
model_kwargs = {
|
||||
"kv_caches": self.kv_caches,
|
||||
"attn_metadata": attn_metadata
|
||||
}
|
||||
if not with_prefill:
|
||||
if is_310p():
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
padded_num_tokens_across_dp)
|
||||
hidden_states = compiled_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
assert self.model is not None
|
||||
if is_310p():
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**model_kwargs,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def _get_torchair_lazy_compiled_model(self, batch_size: int):
|
||||
if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
|
||||
raise ValueError(
|
||||
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
|
||||
)
|
||||
|
||||
compiled_model = self.torchair_compiled_models.get(
|
||||
batch_size
|
||||
) if self.use_cached_npu_graph else self.torchair_compiled_model
|
||||
|
||||
if compiled_model:
|
||||
return compiled_model
|
||||
|
||||
import torchair # type: ignore
|
||||
from torchair import patch_for_hcom # type: ignore
|
||||
|
||||
patch_for_hcom()
|
||||
|
||||
if is_310p():
|
||||
# on 300I Duo platform, we need to patch broadcast. however, this patch will be
|
||||
# overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
|
||||
from vllm_npu.patch.platform.patch_distributed import \
|
||||
communication_adaptation_310p
|
||||
communication_adaptation_310p()
|
||||
|
||||
config = torchair.CompilerConfig()
|
||||
if self.ascend_config.torchair_graph_config.mode:
|
||||
config.mode = self.ascend_config.torchair_graph_config.mode
|
||||
config.experimental_config.frozen_parameter = \
|
||||
self.ascend_config.torchair_graph_config.enable_frozen_parameter
|
||||
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
|
||||
# disable it on 300I Duo platform now.
|
||||
config.experimental_config.tiling_schedule_optimize = not is_310p()
|
||||
config.experimental_config.enable_view_optimize = \
|
||||
self.ascend_config.torchair_graph_config.enable_view_optimize
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
if not self.use_cached_npu_graph:
|
||||
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
||||
self.torchair_compiled_model = torch.compile(
|
||||
self.model,
|
||||
dynamic=not self.use_sparse,
|
||||
fullgraph=True,
|
||||
backend=npu_backend)
|
||||
return self.torchair_compiled_model
|
||||
else:
|
||||
# Generate a new forward proxy code object to prevent the invalidation of
|
||||
# compilation cache caused by dynamo retracing
|
||||
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
|
||||
forward_fn = self.model.forward
|
||||
code = forward_fn.__code__
|
||||
# Mark code object with a new proxy name
|
||||
modified_code = code.replace(co_name=forward_proxy_name, )
|
||||
|
||||
modified_func = types.FunctionType(modified_code,
|
||||
forward_fn.__globals__,
|
||||
name=forward_proxy_name,
|
||||
argdefs=forward_fn.__defaults__)
|
||||
|
||||
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
|
||||
self.model, nn.Module)
|
||||
self.torchair_compiled_models[
|
||||
batch_size] = torchair.inference.cache_compile(
|
||||
self.model.__dict__[forward_proxy_name],
|
||||
dynamic=not self.use_sparse,
|
||||
fullgraph=True,
|
||||
cache_dir=TORCHAIR_CACHE_DIR,
|
||||
config=config,
|
||||
ge_cache=False)
|
||||
return self.torchair_compiled_models[batch_size]
|
||||
|
||||
def init_torchair_graph_batch_sizes(self):
|
||||
start_graph_batch_size = 4
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
# NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
|
||||
start_graph_batch_size = max(start_graph_batch_size, tp_size)
|
||||
|
||||
while (start_graph_batch_size <= self.max_num_reqs):
|
||||
self.torchair_graph_batch_sizes.append(start_graph_batch_size)
|
||||
start_graph_batch_size *= 2
|
||||
|
||||
def select_torchair_padded_batch_size(self, batch_size: int):
|
||||
for padded_batch_size in self.torchair_graph_batch_sizes:
|
||||
if batch_size <= padded_batch_size:
|
||||
# we treat batch_size as num of requests
|
||||
return padded_batch_size
|
||||
raise ValueError(
|
||||
f"cur batch_size is invalid, torchair_graph_batch_sizes is "
|
||||
f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
|
||||
)
|
||||
|
||||
def calculate_new_torchair_graph_batch_size(self, old_graph_batch_size,
|
||||
tp_size):
|
||||
cur_graph_batch_size = (old_graph_batch_size + tp_size -
|
||||
1) // tp_size * tp_size
|
||||
# MTP > 1: Cal LCMLeast Common Multiple with graph_batch_size and tp_size,
|
||||
# Both adapter multi-dp and FIA operator
|
||||
if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
|
||||
cur_graph_batch_size = (tp_size * old_graph_batch_size) \
|
||||
// math.gcd(tp_size, old_graph_batch_size)
|
||||
return cur_graph_batch_size
|
||||
|
||||
def update_torchair_graph_batch_sizes(self):
|
||||
# return graph_batch_sizes according to the max number of tokens
|
||||
# first pad according to the number of requests
|
||||
if self.is_kv_consumer and self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
|
||||
# pd disaggregation scenario may incorrectly calculate the batch in mtp scenario, so we force set it to max_num_reqs
|
||||
self.torchair_graph_batch_sizes = [self.max_num_reqs]
|
||||
logger.warning(
|
||||
f"is kv_consumer, torch_graph_batch_sizes sets to [max_num_seqs] {[self.max_num_reqs]}"
|
||||
)
|
||||
elif len(self.torchair_graph_batch_sizes) == 0:
|
||||
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
|
||||
else:
|
||||
self.torchair_graph_batch_sizes = sorted(
|
||||
self.torchair_graph_batch_sizes)
|
||||
while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
|
||||
self.torchair_graph_batch_sizes.pop()
|
||||
if len(self.torchair_graph_batch_sizes) == 0:
|
||||
logger.warning(
|
||||
"torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
|
||||
)
|
||||
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
|
||||
if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
|
||||
self.torchair_graph_batch_sizes.append(self.max_num_reqs)
|
||||
|
||||
# padded max number tokens = max_num_req * decode_token_per_req
|
||||
self.torchair_graph_batch_sizes = [
|
||||
graph_batch_size * self.decode_token_per_req
|
||||
for graph_batch_size in self.torchair_graph_batch_sizes
|
||||
]
|
||||
|
||||
# NOTE: when enable_expert_parallel on A3, we need to check if `graph_batch_size` is divisible by `tp_size`
|
||||
# Because we use x_active_mask for dispatch/combine op on A3, which requires that input shape should be same
|
||||
# on all EP ranks
|
||||
if get_ascend_soc_version(
|
||||
) == AscendSocVersion.A3 and self.parallel_config.enable_expert_parallel:
|
||||
self._align_graph_size_divisible_by_tp_size()
|
||||
|
||||
def _align_graph_size_divisible_by_tp_size(self):
|
||||
tp_size = self.parallel_config.tensor_parallel_size
|
||||
new_graph_batch_sizes = []
|
||||
for graph_batch_size in self.torchair_graph_batch_sizes:
|
||||
cur_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
|
||||
graph_batch_size, tp_size)
|
||||
if cur_graph_batch_size not in new_graph_batch_sizes and \
|
||||
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
|
||||
new_graph_batch_sizes.append(cur_graph_batch_size)
|
||||
elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
|
||||
and self.decode_token_per_req > 1:
|
||||
logger.warning(
|
||||
f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
|
||||
f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
|
||||
)
|
||||
new_max_num_reqs = math.ceil(
|
||||
max(new_graph_batch_sizes) / self.decode_token_per_req)
|
||||
if self.max_num_reqs != new_max_num_reqs:
|
||||
logger.warning(f"max_num_reqs is updated to {new_max_num_reqs}")
|
||||
self.max_num_reqs = new_max_num_reqs
|
||||
if not (self.decode_token_per_req > 1 and self.is_kv_consumer):
|
||||
# Do not update scheduler_config.max_num_seqs in KV consumer + MTP
|
||||
# Since FIA need extra space for padding
|
||||
# Enforce self.max_num_seqs > self.scheduler_config.max_num_seqs in KV consumer + MTP
|
||||
self.scheduler_config.max_num_seqs = new_max_num_reqs
|
||||
|
||||
if new_graph_batch_sizes != self.torchair_graph_batch_sizes:
|
||||
logger.warning(
|
||||
f"torchair_graph_batch_sizes are updated to {new_graph_batch_sizes}."
|
||||
)
|
||||
self.torchair_graph_batch_sizes = new_graph_batch_sizes
|
||||
|
||||
def _build_drafter_prepare_inputs_torchair_param(self):
|
||||
if self.enable_shared_expert_dp:
|
||||
return super()._build_drafter_prepare_inputs_torchair_param()
|
||||
else:
|
||||
return True
|
||||
Reference in New Issue
Block a user