# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file extends the functionality of linear operations by encapsulating
custom communication groups and forward functions into classes (linear ops).

Current class inheritance structure:
CustomLinearOp
├── CustomColumnParallelOp
│   ├── MLPColumnParallelOp
│   └── SequenceColumnParallelOp
├── CustomRowParallelOp
│   ├── MLPRowParallelOp
│   ├── OProjRowParallelOp
│   ├── MatmulAllreduceRowParallelOp
│   └── SequenceRowParallelOp
└── CustomReplicatedOp

How to extend a new linear op? Taking a column parallel op as an example:
1. Inherit from CustomColumnParallelOp and create a new class
   MyColumnParallelOp.
2. [Optional] The default communication group is the TP group. If a custom
   communication group is needed, override the comm_group property.
3. Override the apply_impl method as required; the base class's apply()
   wraps it and replaces the original linear.forward.
4. Add selection logic for MyColumnParallelOp in the _get_column_parallel_op
   function, typically based on the layer prefix and configuration flags.
A row parallel op follows the same approach: inherit from CustomRowParallelOp
and register the new class in _get_row_parallel_op. An illustrative sketch
appears after the base classes below.
"""
from typing import Optional, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch_npu
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from vllm.distributed import (split_tensor_along_last_dim,
                              tensor_model_parallel_all_reduce,
                              tensor_model_parallel_reduce_scatter)
from vllm.distributed.parallel_state import get_tp_group
from vllm.forward_context import get_forward_context

from vllm_npu.distributed.parallel_state import (get_mlp_tp_group,
                                                 get_otp_group)
from vllm_npu.utils import (dense_optim_enable, enable_sp,
                            matmul_allreduce_enable, mlp_tp_enable,
                            oproj_tp_enable, shared_expert_dp_enabled)


class CustomLinearOp:

    def __init__(self, layer):
        self.layer = layer
        self.bias = None
        self.skip_bias_add = None
        self.return_bias = None
        self.quant_method = None

    # Custom communication group; it also determines how weights are sharded.
    @property
    def comm_group(self):
        return get_tp_group()

    @property
    def tp_rank(self):
        return self.comm_group.rank_in_group

    @property
    def tp_size(self):
        return self.comm_group.world_size

    # Update the attributes required by apply(), obtaining them from the layer.
    # Call this after the layer completes its initialization, i.e. at the end
    # of layer.__init__().
    def update_attrs(self):
        if hasattr(self.layer, "bias"):
            self.bias = self.layer.bias
        self.skip_bias_add = self.layer.skip_bias_add
        self.return_bias = self.layer.return_bias
        self.quant_method = self.layer.quant_method
        self.prefix = self.layer.prefix

    def apply_impl(self, input_):
        raise NotImplementedError

    # Replaces layer.forward to customize the layer computation process.
    def apply(self, input_):
        output, output_bias = self.apply_impl(input_)
        if not self.return_bias:
            return output
        return output, output_bias


class CustomColumnParallelOp(CustomLinearOp):

    def __init__(self, layer):
        super().__init__(layer)
        self.gather_output = None

    def update_attrs(self):
        super().update_attrs()
        self.gather_output = self.layer.gather_output


class CustomRowParallelOp(CustomLinearOp):

    def __init__(self, layer):
        super().__init__(layer)
        self.reduce_results = None
        self.input_is_parallel = None
        self.input_size_per_partition = None

    def update_attrs(self):
        super().update_attrs()
        self.input_is_parallel = self.layer.input_is_parallel
        self.reduce_results = self.layer.reduce_results
        self.input_size_per_partition = self.layer.input_size_per_partition

    def apply(self, input_):
        output, output_bias = self.apply_impl(input_)
        if dense_optim_enable():
            torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(
                output, self.prefix)
        if not self.return_bias:
            return output
        return output, output_bias


class CustomReplicatedOp(CustomLinearOp):

    def apply_impl(self, input_):
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self.layer, input_, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
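# ---------------------------------------------------------------------------
# Illustrative sketch (not wired in anywhere): a minimal custom column
# parallel op following the four steps in the module docstring. The class
# name MyColumnParallelOp and the choice of the MLP TP group are purely
# hypothetical; a real op would also need a matching branch added in
# _get_column_parallel_op below.
# ---------------------------------------------------------------------------
class MyColumnParallelOp(CustomColumnParallelOp):

    # Step 2 (optional): override comm_group to shard over a non-default
    # group; the MLP TP group is used here only as an example.
    @property
    def comm_group(self):
        return get_mlp_tp_group()

    # Step 3: override apply_impl; the base class's apply() wraps this and
    # replaces the original linear.forward.
    def apply_impl(self, input_):
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self.layer, input_, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias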
class MLPColumnParallelOp(CustomColumnParallelOp):

    def __init__(self, layer):
        super().__init__(layer)

    @property
    def comm_group(self):
        return get_mlp_tp_group()

    def apply_impl(
        self,
        input_: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        input_parallel = self.comm_group.all_gather(input_, 0)
        output = self.quant_method.apply(self.layer, input_parallel, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias


class MLPRowParallelOp(CustomRowParallelOp):

    def __init__(self, layer):
        super().__init__(layer)

    @property
    def comm_group(self):
        return get_mlp_tp_group()

    def apply_impl(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        assert self.quant_method is not None
        bias_ = None if (self.tp_rank > 0
                         or self.skip_bias_add) else self.layer.bias
        output_parallel = self.quant_method.apply(self.layer,
                                                  input_parallel,
                                                  bias=bias_)
        output = self.comm_group.reduce_scatter(output_parallel, 0)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
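# ---------------------------------------------------------------------------
# Data-flow note for the MLP TP pair above (an informal sketch, assuming an
# MLP TP group of size G and a per-rank activation of shape [S/G, H]):
#   gate_up_proj (MLPColumnParallelOp): all_gather along dim 0 -> [S, H],
#     then matmul with the column-sharded weight -> [S, I/G].
#   down_proj (MLPRowParallelOp): matmul with the row-sharded weight yields a
#     partial [S, H] result; reduce_scatter along dim 0 sums the partials and
#     returns the rank-local [S/G, H] slice.
# Together the two ops keep activations sharded outside the MLP while running
# the MLP itself tensor-parallel over the custom group.
# ---------------------------------------------------------------------------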
class OProjRowParallelOp(CustomRowParallelOp):

    def __init__(self, layer):
        super().__init__(layer)

    @property
    def comm_group(self):
        return get_otp_group()

    def apply_impl(
        self,
        input_: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Prepare tensors for all-to-all communication.
        local_batch_size = input_parallel.size(0)
        chunk_size = self.input_size_per_partition
        total_batch_size = local_batch_size * self.tp_size

        # Reshape tensor for efficient cross-device transfer:
        # [batch, dim] -> [tp_size, batch, chunk] -> flattened
        send_buf = (input_parallel.reshape(-1, self.tp_size,
                                           chunk_size).transpose(
                                               0, 1).contiguous().view(-1))
        # Create receive buffer.
        recv_buf = torch.empty(total_batch_size * chunk_size,
                               dtype=input_parallel.dtype,
                               device=input_parallel.device)

        # Perform all-to-all communication.
        dist.all_to_all_single(recv_buf,
                               send_buf,
                               group=self.comm_group.device_group)
        input_parallel = recv_buf.view(total_batch_size, chunk_size)

        # Only fuse bias add on rank 0 to avoid duplicate bias addition
        # with TP > 1.
        bias_ = None if (self.tp_rank > 0
                         or self.skip_bias_add) else self.bias
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self.layer,
                                                  input_parallel,
                                                  bias=bias_)

        # otp-specific: combine partial results across devices.
        output = self.comm_group.reduce_scatter(output_parallel, dim=0)
        output = output.view(input_.shape[0], self.layer.output_size)

        # Handle bias return based on configuration.
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def update_attrs(self):
        super().update_attrs()
        self.input_is_parallel = self.layer.input_is_parallel
        self.input_size_per_partition = self.layer.input_size_per_partition
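# ---------------------------------------------------------------------------
# Hypothetical helper, illustration only (not called anywhere): reproduces on
# CPU the send-buffer layout built in OProjRowParallelOp.apply_impl above,
# without performing any collective. The sizes (tp_size=2, batch=3, chunk=4)
# are arbitrary assumptions for the demo.
# ---------------------------------------------------------------------------
def _oproj_alltoall_layout_demo() -> torch.Tensor:
    tp_size, batch, chunk = 2, 3, 4
    # A [batch, dim] activation where dim = tp_size * chunk.
    x = torch.arange(batch * tp_size * chunk).reshape(batch, tp_size * chunk)
    # Same reshape as apply_impl: [batch, dim] -> [tp_size, batch, chunk],
    # then flattened so each peer's slice is contiguous.
    send_buf = (x.reshape(-1, tp_size,
                          chunk).transpose(0, 1).contiguous().view(-1))
    # The first batch*chunk elements hold the features destined for rank 0,
    # the next batch*chunk for rank 1 -- the layout all_to_all_single expects.
    return send_buf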
class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
    _HCOMM_INFO = None

    def __init__(self, layer):
        super().__init__(layer)
        self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group)

    def apply_impl(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Calculate the forward output, fusing communication and computation
        (npu_mm_all_reduce_base) whenever results must be reduced."""
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        bias_ = None if (self.tp_rank > 0
                         or self.skip_bias_add) else self.bias
        if self.reduce_results and self.tp_size > 1:
            output = torch_npu.npu_mm_all_reduce_base(input_parallel,
                                                      self.weight_t,
                                                      self.hcomm_info,
                                                      bias=bias_)
        else:
            assert self.quant_method is not None
            output = self.quant_method.apply(self.layer,
                                             input_parallel,
                                             bias=bias_)

        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    @classmethod
    def get_hcomm_info(cls, group: ProcessGroup) -> str:
        """Get the HCCL communication name for the given group."""
        if cls._HCOMM_INFO is not None:
            return cls._HCOMM_INFO

        rank = torch.distributed.get_rank(group)
        if torch.__version__ > "2.0":
            global_rank = torch.distributed.get_global_rank(group, rank)
            cls._HCOMM_INFO = group._get_backend(
                torch.device("npu")).get_hccl_comm_name(global_rank)
        else:
            cls._HCOMM_INFO = group.get_hccl_comm_name(rank)
        return cls._HCOMM_INFO

    def update_attrs(self):
        super().update_attrs()
        self.weight_t = self.layer.weight.t()


class SequenceColumnParallelOp(CustomColumnParallelOp):

    def apply_impl(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Linear layer with column parallelism.

        Implements several optimizations for dense models, such as FlashComm
        and communication-computation fusion.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
        output_parallel = self.quant_method.apply(self.layer, input_, bias)

        if self.gather_output:
            # All-gather across the partitions.
            output = self.comm_group.all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias


class SequenceRowParallelOp(CustomRowParallelOp):

    def apply_impl(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Linear layer with row parallelism.

        Implements several optimizations for dense models, such as FlashComm
        and communication-computation fusion.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        assert self.quant_method is not None
        bias_ = None if (self.tp_rank > 0
                         or self.skip_bias_add) else self.bias

        if self.tp_size == 1 or not self.reduce_results:
            output = self.quant_method.apply(self.layer,
                                             input_parallel,
                                             bias=bias_)
        else:
            output = torch.ops.vllm.matmul_and_reduce(input_parallel,
                                                      self.prefix)

        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def matmul_and_reduce(self, input_parallel: torch.Tensor,
                          bias_: Optional[Parameter]) -> torch.Tensor:
        assert self.quant_method is not None
        try:
            forward_context = get_forward_context()
            sp_enabled = forward_context.sp_enabled
            mmrs_fusion = forward_context.mmrs_fusion
        except AssertionError:
            sp_enabled = False
            mmrs_fusion = False
        x = input_parallel
        if not sp_enabled:
            output_parallel = self.layer.quant_method.apply(self.layer,
                                                            x,
                                                            bias=bias_)
            return tensor_model_parallel_all_reduce(output_parallel)

        pad_size = forward_context.pad_size
        if pad_size > 0:
            x = F.pad(x, (0, 0, 0, pad_size))
        world_size = self.layer.tp_size
        comm_mode = "aiv"
        hcom_name = get_tp_group().device_group._get_backend(
            torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)

        from vllm.model_executor.layers.linear import UnquantizedLinearMethod

        from vllm_npu.quantization.quant_config import AscendLinearMethod
        from vllm_npu.quantization.w8a8 import (AscendW8A8LinearMethod,
                                                quant_per_tensor)

        # For unquant
        if mmrs_fusion and isinstance(self.layer.quant_method,
                                      UnquantizedLinearMethod):
            output = torch_npu.npu_mm_reduce_scatter_base(
                x,
                self.layer.weight.t(),
                hcom_name,
                world_size,
                reduce_op="sum",
                bias=None,
                comm_turn=0,
                comm_mode=comm_mode)
            if bias_ is not None:
                output.add_(bias_)
        # For w8a8 quant
        elif mmrs_fusion and (
                isinstance(self.layer.quant_method, AscendLinearMethod)
                and isinstance(self.layer.quant_method.quant_method,
                               AscendW8A8LinearMethod)):
            if x.dtype != torch.int8:
                x_quant = quant_per_tensor(
                    x, self.layer.aclnn_input_scale_reciprocal,
                    self.layer.aclnn_input_offset)
            else:
                x_quant = x
            quant_bias = self.layer.quant_bias
            deq_scale = self.layer.deq_scale
            output_dtype = torch.bfloat16
            output = torch_npu.npu_mm_reduce_scatter_base(
                x_quant,
                self.layer.weight,
                hcom_name,
                world_size,
                reduce_op="sum",
                bias=None,
                comm_turn=0,
                x2_scale=deq_scale,
                output_dtype=output_dtype,
                comm_mode=comm_mode)
            output = torch.add(
                output,
                torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
        else:
            output_parallel = self.layer.quant_method.apply(self.layer,
                                                            x,
                                                            bias=bias_)
            output = tensor_model_parallel_reduce_scatter(output_parallel, 0)
        return output

    def update_attrs(self):
        super().update_attrs()
        self.input_is_parallel = self.layer.input_is_parallel
        self.reduce_results = self.layer.reduce_results
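# ---------------------------------------------------------------------------
# Dispatch note for SequenceRowParallelOp.matmul_and_reduce above (informal
# summary of the branches):
#   - sequence parallelism off -> plain matmul, then all_reduce;
#   - SP on, fusion available  -> npu_mm_reduce_scatter_base fuses the matmul
#     with the reduce-scatter (separate unquantized and w8a8 paths);
#   - SP on, no fusion         -> matmul, then reduce_scatter along dim 0.
# The F.pad in the SP paths presumably rounds the token dimension up so it
# divides evenly across TP ranks before the scatter.
# ---------------------------------------------------------------------------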
def _get_column_parallel_op(
    prefix, layer
) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
    if mlp_tp_enable() and "gate_up_proj" in prefix:
        return MLPColumnParallelOp(layer)
    if enable_sp():
        if "shared_expert" in prefix:
            return None
        if "gate_up_proj" in prefix:
            return SequenceColumnParallelOp(layer)
        if "in_proj" in prefix:
            return SequenceColumnParallelOp(layer)
        if "qkv_proj" in prefix or "conv1d" in prefix:
            return SequenceColumnParallelOp(layer)
    return None


def _get_row_parallel_op(
    prefix, layer
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                    MatmulAllreduceRowParallelOp, SequenceRowParallelOp]]:
    if "down_proj" in prefix and mlp_tp_enable():
        return MLPRowParallelOp(layer)
    if "o_proj" in prefix and oproj_tp_enable():
        return OProjRowParallelOp(layer)
    if matmul_allreduce_enable():
        return MatmulAllreduceRowParallelOp(layer)
    if enable_sp():
        if "shared_expert" in prefix:
            return None
        if "o_proj" in prefix or "out_proj" in prefix or "down_proj" in prefix:
            return SequenceRowParallelOp(layer)
    return None


def get_parallel_op(disable_tp, prefix, layer, direct):
    if disable_tp or ("shared_experts" in prefix
                      and shared_expert_dp_enabled()):
        return None, 0, 1

    custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
                              MLPRowParallelOp, OProjRowParallelOp,
                              MatmulAllreduceRowParallelOp,
                              SequenceRowParallelOp]] = None
    if direct == "row":
        custom_op = _get_row_parallel_op(prefix, layer)
    if direct == "column":
        custom_op = _get_column_parallel_op(prefix, layer)

    if custom_op is not None:
        return custom_op, custom_op.tp_rank, custom_op.tp_size

    return None, get_tp_group().rank_in_group, get_tp_group().world_size


def get_replicated_op(disable_tp, prefix,
                      layer) -> Optional[CustomReplicatedOp]:
    if disable_tp:
        return None

    return CustomReplicatedOp(layer)
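# ---------------------------------------------------------------------------
# Usage sketch (hypothetical, illustration only): how a linear layer would
# typically wire in a custom op. The layer class and attribute names below
# are assumptions for the example, not part of this module's API.
#
#   class MyRowParallelLinear(...):
#       def __init__(self, ..., prefix=""):
#           ...
#           self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
#               disable_tp=False, prefix=prefix, layer=self, direct="row")
#           ...  # create weights, then refresh the op's cached attributes
#           if self.custom_op is not None:
#               self.custom_op.update_attrs()
#
#       def forward(self, input_):
#           if self.custom_op is not None:
#               # apply() replaces the original linear.forward.
#               return self.custom_op.apply(input_)
#           return super().forward(input_)
# ---------------------------------------------------------------------------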