diff --git a/vllm_npu/__init__.py b/vllm_npu/__init__.py
index fa5886f..b03577b 100644
--- a/vllm_npu/__init__.py
+++ b/vllm_npu/__init__.py
@@ -15,16 +15,15 @@ def register():
     from vllm_npu.cuda_compat import _patch_cuda_to_npu
     _patch_cuda_to_npu()
 
-    # Register NPU custom ops with vLLM's CustomOp dispatch so that
-    # ops like SiluAndMul, RMSNorm, RotaryEmbedding use NPU kernels
-    # instead of falling back to CUDA (which would produce garbage).
-    _register_npu_ops()
-
     return "vllm_npu.platform.NPUPlatform"
 
 
-def _register_npu_ops():
-    """Register Ascend NPU op overrides with vLLM's CustomOp system."""
+def register_npu_ops():
+    """Register Ascend NPU op overrides with vLLM's CustomOp system.
+
+    Must be called AFTER the platform is established (e.g., during
+    worker init or check_and_update_config), NOT during register().
+    """
     from vllm.model_executor.custom_op import CustomOp
 
     from vllm_npu.ops.activation import AscendSiluAndMul
diff --git a/vllm_npu/platform.py b/vllm_npu/platform.py
index e5e128e..0c980df 100644
--- a/vllm_npu/platform.py
+++ b/vllm_npu/platform.py
@@ -180,6 +180,10 @@ class NPUPlatform(Platform):
         """Adapt vLLM configuration for NPU hardware."""
         from vllm.config import CompilationLevel
 
+        # Register NPU custom ops (must happen after platform is detected)
+        from vllm_npu import register_npu_ops
+        register_npu_ops()
+
         parallel_config = vllm_config.parallel_config
         cache_config = vllm_config.cache_config
         compilation_config = vllm_config.compilation_config