# adapted from vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# SPDX-License-Identifier: Apache-2.0

# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# mypy: ignore-errors

from typing import Optional, Union

import torch
import torch.nn.functional as F
import triton
import triton.language as tl

PAD_SLOT_ID = -1


def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """Reference (pure PyTorch) causal depthwise 1D convolution.

    x: (batch, dim, seqlen)
        An unbatched (dim, seqlen) input also works, because F.conv1d accepts
        unbatched input; causal_conv1d_fn relies on this.
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
        Prepended to x in place of zero left-padding.
    return_final_states: if True, also return the last ``width - 1`` input
        columns (left-padded with zeros when the input is shorter).
    final_states_out: (batch, dim, width - 1)
        If given, the final states are copied into it in place.
    activation: None, "silu" or "swish" (swish is applied as SiLU).

    out: (batch, dim, seqlen), in the dtype of ``x``.
    Returns ``(out, None)`` or ``(out, final_states_out)``.
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    if initial_states is None:
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    else:
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
    # Keep only the causal outputs, i.e. the first ``seqlen`` positions.
    out = out[..., :seqlen]
    if return_final_states:
        # A negative left pad crops, so this yields exactly the last
        # ``width - 1`` columns (zero-padded on the left if x is shorter).
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
            dtype_in)  # (batch, dim, width - 1)
        if final_states_out is not None:
            final_states_out.copy_(final_states)
        else:
            final_states_out = final_states
    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
    return (out, None) if not return_final_states else (out, final_states_out)


def causal_conv1d_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    query_start_loc: Optional[torch.Tensor] = None,
    cache_indices: Optional[torch.Tensor] = None,
    has_initial_state: Optional[torch.Tensor] = None,
    conv_states: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
    pad_slot_id: int = PAD_SLOT_ID,
):
    """Variable-length causal conv1d, computed sequence by sequence with
    causal_conv1d_ref (pure PyTorch; no Triton kernel is launched).

    x: (dim, cu_seq_len)
        All sequences are concatenated from left to right along the last axis.
    weight: (dim, width)
    bias: (dim,)
    query_start_loc: (batch + 1) int32
        The cumulative sequence lengths of the sequences in the batch,
        prepended by 0, used to split ``x`` into sequences. For example:
        query_start_loc = torch.Tensor([0,10,16,17]), x.shape=(dim,17).
        If None, ``x`` is treated as a single sequence.
    cache_indices: (batch) int32
        The state slot of each sequence:
        conv_state = conv_states[cache_indices[batch_id]].
        If None, sequence ``i`` uses slot ``i``.
    has_initial_state: (batch) bool
        Whether a sequence starts from its cached conv state. If None, no
        sequence does.
    conv_states: (..., dim, width - 1)
        Updated in place with the final state of every processed sequence.
        If None, no state is read or written.
    activation: either None or "silu" or "swish"
    pad_slot_id: int
        Sequences whose cache index equals this value are not processed.
        For example, with cache_indices = [pad_slot_id, 1, 20, pad_slot_id],
        sequences 0 and 3 are skipped. Their output positions are zero.

    out: (dim, cu_seq_len), same shape and dtype as ``x``. Each sequence's
    output sits at the same token positions as its input.
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

    if query_start_loc is None:
        seqlens = [x.shape[-1]]
    else:
        seqlens = (query_start_loc[1:] - query_start_loc[:-1]).tolist()
    splits = torch.split(x, seqlens, dim=-1)

    # Preallocate so every sequence is written at its own offset. Skipping a
    # padded sequence must not shift the outputs of the sequences after it.
    out = torch.zeros(x.shape, dtype=x.dtype, device=x.device)
    start = 0
    for i, x_s in enumerate(splits):
        seq_slice = slice(start, start + x_s.shape[-1])
        start = seq_slice.stop

        cache_idx = i if cache_indices is None else int(cache_indices[i])
        if cache_idx == pad_slot_id:
            continue

        state = None if conv_states is None else conv_states[cache_idx]
        use_initial = (state is not None and has_initial_state is not None
                       and bool(has_initial_state[i]))
        out_s, _ = causal_conv1d_ref(
            x_s,
            weight,
            bias,
            initial_states=state if use_initial else None,
            return_final_states=state is not None,
            # unsqueeze(0) returns a view, so the copy_ inside
            # causal_conv1d_ref writes the final state into conv_states.
            final_states_out=None if state is None else state.unsqueeze(0),
            activation=activation,
        )
        out[..., seq_slice] = out_s
    return out


# Single-step / few-token causal conv1d update kernel (decode path).
#
# Launch grid: program_id(0) selects the sequence, and program_id(1) selects
# a BLOCK_N-wide slice of channels. Each program does the following:
#   1. loads the cached window (width - 1 columns) for its cache line,
#   2. shifts the cached conv state left and appends the new x tokens,
#      writing the result back to conv_state,
#   3-5. computes the depthwise convolution (+ bias, + optional SiLU) for
#      each of the ``seqlen`` new tokens and stores it to o_ptr.
# Only KERNEL_WIDTH 2, 3 and 4 are implemented (see the branches below).
@triton.jit()
def _causal_conv1d_update_kernel(
    # Pointers to matrices
    x_ptr,  # (batch, dim, seqlen)
    w_ptr,  # (dim, width)
    bias_ptr,
    conv_state_ptr,
    cache_seqlens_ptr,  # circular buffer (not read by this kernel)
    conv_state_indices_ptr,
    num_accepted_tokens_ptr,
    intermediate_conv_window_ptr,
    o_ptr,  # (batch, dim, seqlen)
    # Matrix dimensions
    batch: int,
    dim: tl.constexpr,
    seqlen: tl.constexpr,
    state_len: tl.constexpr,
    num_cache_lines: tl.constexpr,  # added to support vLLM larger cache lines
    # Strides
    stride_x_seq: tl.constexpr,
    stride_x_dim: tl.constexpr,
    stride_x_token: tl.constexpr,
    stride_w_dim: tl.constexpr,
    stride_w_width: tl.constexpr,
    stride_conv_state_seq: tl.constexpr,
    stride_conv_state_dim: tl.constexpr,
    stride_conv_state_tok: tl.constexpr,
    stride_state_indices: tl.constexpr,
    stride_inter_seq: tl.constexpr,
    stride_inter_step: tl.constexpr,
    stride_inter_dim: tl.constexpr,
    stride_inter_win: tl.constexpr,
    stride_o_seq: tl.constexpr,
    stride_o_dim: tl.constexpr,
    stride_o_token: tl.constexpr,
    # others
    pad_slot_id: tl.constexpr,
    # Meta-parameters
    HAS_BIAS: tl.constexpr,
    KERNEL_WIDTH: tl.constexpr,
    SILU_ACTIVATION: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    NP2_STATELEN: tl.constexpr,
    USE_PAD_SLOT: tl.constexpr,
    BLOCK_N: tl.constexpr,
    SAVE_INTERMEDIATE: tl.constexpr,
):
    # ruff: noqa: E501
    idx_seq = tl.program_id(0)
    if idx_seq >= batch:
        return

    # [BLOCK_N,] elements along the feature-dimension (channel)
    idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

    if IS_CONTINUOUS_BATCHING:
        # The conv_state row for this sequence is looked up via the index table.
        conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq * stride_state_indices).to(
            tl.int64)
    else:
        conv_state_batch_coord = idx_seq
    if USE_PAD_SLOT:  # noqa
        if conv_state_batch_coord == pad_slot_id:
            # not processing as this is not the actual sequence
            return

    if IS_SPEC_DECODING:
        # The rolling of conv state:
        #
        # Before forward, the conv_state is:
        # [history1, history2, ..., historyM].
        #
        # After forward, the conv_state becomes:
        # [history2, ..., historyM, draft1, draft2, ..., draftN].
        #
        # After acceptance, it becomes:
        #
        # - accept 1 tokens: [history2, ..., historyM, draft1]
        # - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
        # - and so on.
        conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq) - 1
    else:
        conv_state_token_offset = 0

    # STEP 1: READ init_state data
    conv_states_base = (conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
                        (idx_feats * stride_conv_state_dim))
    mask_w = idx_feats < dim

    prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
    if KERNEL_WIDTH >= 2:
        conv_states_ptrs = prior_tokens  # [BLOCK_N]
        col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
    if KERNEL_WIDTH >= 3:
        conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok  # [BLOCK_N]
        col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
    if KERNEL_WIDTH >= 4:
        conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok  # [BLOCK_N]
        col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
    # KERNEL_WIDTH == 5 is not implemented here; the host wrapper rejects it.

    # STEP 2: assume state_len > seqlen
    idx_tokens = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]

    # The conv_state updates works in a sliding window manner,
    # at each forward pass, the tokens are shift by 1, so we
    # load since idx_tokens + 1.
    conv_state_ptrs_source = (
        conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
        conv_state_token_offset * stride_conv_state_tok +
        (idx_feats * stride_conv_state_dim)[None, :] +
        ((idx_tokens + 1) * stride_conv_state_tok)[:, None]
    )  # [BLOCK_M, BLOCK_N]
    # Rows that still come from the old state (the first state_len - seqlen).
    mask = ((conv_state_batch_coord < num_cache_lines)
            & ((idx_tokens + seqlen) < state_len)[:, None]
            & (idx_feats < dim)[None, :])
    conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)

    # Rows at positions >= VAL are filled from the new x tokens.
    VAL = state_len - seqlen
    x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim)  # [BLOCK_N]

    x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None])  # [BLOCK_M, BLOCK_N]

    mask_x = ((idx_tokens - VAL >= 0)[:, None]  # token-index
              & (idx_tokens - VAL < seqlen)[:, None]  # token-index
              & (idx_feats < dim)[None, :])  # feature-index
    loaded_x = tl.load(x_ptrs, mask_x, 0.0)
    tl.debug_barrier()

    new_conv_state = tl.where(mask, conv_state, loaded_x)

    conv_state_base = (conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
                       (idx_feats * stride_conv_state_dim))  # [BLOCK_N,]
    conv_state_ptrs_target = (conv_state_base + (idx_tokens * stride_conv_state_tok)[:, None])  # [BLOCK_M, BLOCK_N]
    mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
    tl.store(conv_state_ptrs_target, new_conv_state, mask)

    # STEP 3: init accumulator
    if HAS_BIAS:
        bias = bias_ptr + idx_feats
        mask_bias = idx_feats < dim
        acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32)  # [BLOCK_N]
    else:
        acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)

    # STEP 4:
    # PRE-LOAD WEIGHTS
    # first kernel column, configured for weights to handle BLOCK_N features in range
    w_base = w_ptr + (idx_feats * stride_w_dim)  # [BLOCK_N,]
    mask_w = idx_feats < dim
    if KERNEL_WIDTH >= 2:
        w_ptrs = w_base + (0 * stride_w_width)  # [BLOCK_N] tensor
        w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
        w_ptrs = w_base + (1 * stride_w_width)  # [BLOCK_N] tensor
        w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
    if KERNEL_WIDTH >= 3:
        w_ptrs = w_base + (2 * stride_w_width)  # [BLOCK_N] tensor
        w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
    if KERNEL_WIDTH >= 4:
        w_ptrs = w_base + (3 * stride_w_width)  # [BLOCK_N] tensor
        w_col3 = tl.load(w_ptrs, mask_w, other=0.0)

    x_base_1d = x_base  # starting of chunk [BLOCK_N]
    mask_x_1d = idx_feats < dim

    # STEP 5: compute each token
    for idx_token in tl.static_range(seqlen):
        acc = acc_preload

        matrix_w = w_col0
        matrix_x = col0
        for j in tl.static_range(KERNEL_WIDTH):
            # The last tap (j == KERNEL_WIDTH - 1) reads the current x token.
            # The earlier taps read the cached window columns.
            if KERNEL_WIDTH == 2:
                if j == 1:  # KERNEL_WIDTH-1:
                    matrix_w = w_col1
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
            elif KERNEL_WIDTH == 3:
                if j == 1:
                    matrix_w = w_col1
                    matrix_x = col1
                elif j == 2:
                    matrix_w = w_col2
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
            elif KERNEL_WIDTH == 4:
                if j == 1:
                    matrix_w = w_col1
                    matrix_x = col1
                elif j == 2:
                    matrix_w = w_col2
                    matrix_x = col2
                elif j == 3:
                    matrix_w = w_col3
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)

            acc += matrix_x * matrix_w  # [BLOCK_N]

        # Slide the in-register window by one token for the next iteration.
        if KERNEL_WIDTH == 2:
            col0 = matrix_x
        elif KERNEL_WIDTH == 3:
            col0 = col1
            col1 = matrix_x
        elif KERNEL_WIDTH == 4:
            col0 = col1
            col1 = col2
            col2 = matrix_x

        if SILU_ACTIVATION:
            acc = acc / (1 + tl.exp(-acc))
        # Equivalent to (idx_token < seqlen) & (idx_feats < dim), built with
        # tl.where instead of a direct scalar/vector AND.
        maskL = idx_feats < dim
        maskR = tl.full(maskL.shape, False, tl.int1)
        mask_1d = tl.where(idx_token < seqlen, maskL, maskR)

        o_ptrs = (o_ptr + (idx_seq) * stride_o_seq + idx_token * stride_o_token + (idx_feats * stride_o_dim))

        tl.store(o_ptrs, acc, mask=mask_1d)

        if SAVE_INTERMEDIATE:
            # Save the window state after consuming this token
            # Layout: [seq(cache line), step, dim, win(K-1)]
            base_ptr = (intermediate_conv_window_ptr + conv_state_batch_coord * stride_inter_seq +
                        idx_token * stride_inter_step + idx_feats * stride_inter_dim)
            if KERNEL_WIDTH >= 2:
                tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w)
            if KERNEL_WIDTH >= 3:
                tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w)
            if KERNEL_WIDTH >= 4:
                tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w)


def causal_conv1d_update_npu(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Union[bool, str, None] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
    conv_state_indices: Optional[torch.Tensor] = None,
    num_accepted_tokens: Optional[torch.Tensor] = None,
    intermediate_conv_window: Optional[torch.Tensor] = None,
    pad_slot_id: int = PAD_SLOT_ID,
    metadata=None,
    validate_data=False,
):
    """Causal conv1d update for decode / speculative decode via a Triton kernel.

    x: (batch, dim) or (batch, dim, seqlen)
        [shape=2: single token prediction]
        [shape=3: single or multiple tokens prediction]
        Overwritten in place with the output (``out`` aliases ``x``).
    conv_state: (num_cache_lines, dim, state_len), where state_len >= width - 1
        Updated in place.
    weight: (dim, width), with width in {2, 3, 4} (the only widths the kernel
        implements; other widths raise ValueError).
    bias: (dim,)
    activation: True / "silu" / "swish" apply SiLU; False / None apply none.
    cache_seqlens: must be None. The kernel does not read it, so the
        circular-buffer mode is not implemented.
    conv_state_indices: (batch,), dtype int32
        If not None, the conv_state is a larger tensor along the batch dim,
        and we are selecting the batch coords specified by conv_state_indices.
        Useful for a continuous batching scenario.
    num_accepted_tokens: (batch,) if given, enables speculative-decoding
        rolling of the conv state.
    intermediate_conv_window: if given, the window after every token is saved
        here, indexed [cache line, step, dim, window].
    pad_slot_id: int
        if conv_state_indices is passed, lets the kernel identify padded
        entries that will not be processed. For example, with
        conv_state_indices = [pad_slot_id, 1, 20, pad_slot_id],
        the kernel will not process entries at indices 0 and 3.
    metadata: unused; kept for interface compatibility.
    validate_data: if True, run extra assertions on shapes and strides.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if validate_data:
        assert cache_seqlens is None  # not implemented yet - ok for vLLM
        assert pad_slot_id is not None
        assert x.stride(1) == 1
    if isinstance(activation, bool):
        activation = "silu" if activation is True else None
    elif activation is not None:
        assert activation in ["silu", "swish"]
    unsqueeze = x.dim() == 2
    if unsqueeze:
        # make it (batch, dim, seqlen) with seqlen == 1
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    _, width = weight.shape
    # The kernel only has code paths for widths 2-4. Any other width would
    # either fail to compile or silently compute a wrong convolution.
    if not 2 <= width <= 4:
        raise ValueError(
            f"causal_conv1d_update_npu supports kernel width 2, 3 or 4, got {width}")
    # conv_state: (..., dim, state_len), where state_len >= width - 1
    num_cache_lines, _, state_len = conv_state.size()

    if validate_data:
        assert dim == weight.size(0)
        assert (
            conv_state.stride(-2) == 1
        ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
        assert state_len >= width - 1
        # when above happens, we don't shift-left to keep any records in conv_state
        assert dim == conv_state.size(1)
        if conv_state_indices is None:
            assert conv_state.size(0) >= batch
        else:
            assert (batch, ) == conv_state_indices.shape

        assert num_cache_lines >= batch
        assert weight.stride(1) == 1  # Need this
        assert cache_seqlens is None  # not needed for vLLM - circular buffer

    # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
    out = x
    stride_w_dim, stride_w_width = weight.stride()

    stride_x_seq, stride_x_dim, stride_x_token = x.stride()  # X (batch, dim, seqlen)

    stride_o_seq, stride_o_dim, stride_o_token = out.stride()
    stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride()
    stride_state_indices = (conv_state_indices.stride(0)
                            if conv_state_indices is not None else 0)
    state_len = width - 1 + (seqlen - 1)  # effective state_len needed
    np2_statelen = triton.next_power_of_2(state_len)

    def grid(META):
        return (
            batch,
            triton.cdiv(dim, META["BLOCK_N"]),
        )

    # prepare intermediate buffer strides if provided
    if intermediate_conv_window is not None:
        stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = (
            intermediate_conv_window.stride(0),
            intermediate_conv_window.stride(1),
            intermediate_conv_window.stride(2),
            intermediate_conv_window.stride(3),
        )
    else:
        stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0

    _causal_conv1d_update_kernel[grid](
        # Pointers to matrices
        x,
        weight,
        bias,
        conv_state,
        cache_seqlens,
        conv_state_indices,
        num_accepted_tokens,
        # Placeholder pointer when no intermediate buffer is requested; the
        # kernel only writes it when SAVE_INTERMEDIATE is set.
        intermediate_conv_window if intermediate_conv_window is not None else x,
        out,
        # Matrix dimensions
        batch,
        dim,
        seqlen,
        state_len,
        num_cache_lines,
        # stride
        stride_x_seq,
        stride_x_dim,
        stride_x_token,
        stride_w_dim,
        stride_w_width,
        stride_istate_seq,
        stride_istate_dim,
        stride_istate_token,
        stride_state_indices,
        stride_inter_seq,
        stride_inter_step,
        stride_inter_dim,
        stride_inter_win,
        stride_o_seq,
        stride_o_dim,
        stride_o_token,
        # others
        pad_slot_id,
        # META
        HAS_BIAS=bias is not None,
        KERNEL_WIDTH=width,
        SILU_ACTIVATION=activation in ["silu", "swish"],
        IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
        IS_SPEC_DECODING=num_accepted_tokens is not None,
        NP2_STATELEN=np2_statelen,
        USE_PAD_SLOT=pad_slot_id is not None,
        BLOCK_N=128,
        SAVE_INTERMEDIATE=intermediate_conv_window is not None,
    )
    if unsqueeze:
        out = out.squeeze(-1)
    return out