@ethaniunu

Summary

This PR fixes the core contraction in Inv2d and aligns the implementation more closely with the Involution paper. It also adds basic validation and a small stabilization step in the kernel generator path.


What was wrong before

  • The original implementation used kernel @ x_unfolded where:

    • kernel had shape (B, groups, 1, K, H, W)
    • x_unfolded had shape (B, groups, group_ch, K, H, W)
  • torch.matmul always treats the last two dimensions as matrix dimensions, so this was effectively multiplying over the spatial dimensions (H, W) instead of over the kernel dimension K (where K = kernel_size * kernel_size).

  • As a result:

    • The contraction did not match the involution operation described in the paper.
    • The code only ran without shape errors when H == W (square feature maps), because matmul requires those inner spatial dimensions to match (see the short repro sketch after this list).
  • The code also assumed channels % group_ch == 0 when reshaping, but never validated this, making it easy to misconfigure.

  • The kernel generator (the reduce → span path) had no normalization or nonlinearity, which makes stacked dynamic kernels brittle during training.
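
To make the failure mode concrete, here is a minimal repro sketch (not part of the repository) using the shapes from the bullets above:

```python
import torch

B, groups, group_ch, K, H, W = 2, 4, 16, 9, 8, 8   # H == W hides the bug
kernel = torch.randn(B, groups, 1, K, H, W)
x_unfolded = torch.randn(B, groups, group_ch, K, H, W)

# matmul contracts the last dim of `kernel` (W) with the second-to-last
# dim of `x_unfolded` (H) -- a spatial contraction, not one over K.
out = kernel @ x_unfolded
print(out.shape)  # torch.Size([2, 4, 16, 9, 8, 8]) -- runs, but the math is wrong

# With a non-square feature map, the hidden shape mismatch surfaces as an error:
kernel_r = torch.randn(B, groups, 1, K, 8, 10)
x_r = torch.randn(B, groups, group_ch, K, 8, 10)
try:
    kernel_r @ x_r  # inner dimensions 10 and 8 do not match
except RuntimeError as e:
    print("matmul failed as expected:", e)
```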


What this PR changes

  • Replaces the incorrect kernel @ x_unfolded with a proper contraction over the kernel dimension using torch.einsum, without mixing spatial dimensions:

    • Kernels are reshaped to (B, groups, K, H_out, W_out).
    • Unfolded patches are reshaped to (B, groups, group_ch, K, H_out, W_out).
    • The output is computed with einsum("bgkij,bgckij->bgcij", kernel, patches) and then reshaped back to (B, C, H_out, W_out).
  • Adds a lightweight "sigma" mapping in the kernel generator:

    • reduce (1x1 conv) → BatchNorm2d → ReLU → span (1x1 conv).
    • This follows common Involution reference implementations and helps keep the dynamically generated kernels better behaved.
  • Validates that:

    • The input channel count matches the channels argument passed to Inv2d.
    • channels is divisible by group_ch.
    • The unfolded tensor has the expected number of channels and spatial positions.
  • Uses reshape instead of view where needed so the code works correctly with non-contiguous tensors (for example, tensors produced by einsum). A consolidated sketch of the updated forward path follows this list.
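
For orientation, here is a minimal sketch, assembled from the description above, of how these pieces could fit together. It is an illustration of the described fix, not the merged diff; the layer names (reduce, sigma, span, pool) and the average-pool-by-stride choice in the generator path are assumptions beyond what this PR text states:

```python
import torch
import torch.nn as nn

class Inv2d(nn.Module):
    def __init__(self, channels, kernel_size, stride, group_ch=16, red_ratio=2, **kwargs):
        super().__init__()
        if channels % group_ch != 0:
            raise ValueError("channels must be divisible by group_ch")
        self.channels = channels
        self.kernel_size = kernel_size
        self.group_ch = group_ch
        self.groups = channels // group_ch
        # Kernel generator: reduce (1x1) -> BatchNorm2d -> ReLU -> span (1x1).
        self.reduce = nn.Conv2d(channels, channels // red_ratio, kernel_size=1)
        self.sigma = nn.Sequential(
            nn.BatchNorm2d(channels // red_ratio),
            nn.ReLU(inplace=True),
        )
        self.span = nn.Conv2d(channels // red_ratio,
                              kernel_size * kernel_size * self.groups, kernel_size=1)
        # Downsample the generator input so kernels align with the strided output grid
        # (assumed here; the PR text does not spell out the generator's stride handling).
        self.pool = nn.AvgPool2d(stride, stride) if stride > 1 else nn.Identity()
        self.unfold = nn.Unfold(kernel_size, padding=(kernel_size - 1) // 2, stride=stride)

    def forward(self, x):
        B, C, H, W = x.shape
        if C != self.channels:
            raise ValueError(f"expected {self.channels} input channels, got {C}")
        K = self.kernel_size * self.kernel_size

        # One K-vector per group and output position: (B, groups * K, H_out, W_out).
        kernel = self.span(self.sigma(self.reduce(self.pool(x))))
        H_out, W_out = kernel.shape[-2:]
        kernel = kernel.reshape(B, self.groups, K, H_out, W_out)

        # Unfolded patches: (B, C * K, H_out * W_out), validated before reshaping.
        patches = self.unfold(x)
        if patches.shape[1] != C * K or patches.shape[2] != H_out * W_out:
            raise ValueError("unfolded tensor does not match the expected shape")
        patches = patches.reshape(B, self.groups, self.group_ch, K, H_out, W_out)

        # Contract over the kernel axis K only; spatial axes stay untouched.
        out = torch.einsum("bgkij,bgckij->bgcij", kernel, patches)
        # reshape, not view: einsum output may be non-contiguous.
        return out.reshape(B, C, H_out, W_out)
```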


Behavior and API

  • The public API of Inv2d is unchanged:

    • Constructor: Inv2d(channels, kernel_size, stride, group_ch=16, red_ratio=2, **kwargs)
    • Input shape: (B, C, H, W)
    • Output shape: (B, C, H_out, W_out) with the same stride and padding behavior as before.
  • Existing code that imports and constructs Inv2d should continue to work, but the operation is now:

    • Mathematically consistent with the intended involution operator.
    • More numerically stable when stacking multiple Inv2d layers.
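
As a quick sanity check against the sketch above, a hedged usage example (hypothetical values, shapes only):

```python
import torch

inv = Inv2d(channels=64, kernel_size=3, stride=1, group_ch=16)
x = torch.randn(2, 64, 32, 32)    # (B, C, H, W)
print(inv(x).shape)               # torch.Size([2, 64, 32, 32]) -- (B, C, H_out, W_out)

inv2 = Inv2d(channels=64, kernel_size=3, stride=2)
print(inv2(x).shape)              # torch.Size([2, 64, 16, 16]) -- strided output grid
```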

@ethaniunu ethaniunu changed the title Clean up code Fix Inv2d contraction and align implementation with Involution paper Dec 5, 2025
@ethaniunu (Author)

Hi, I found that the current implementation does not work correctly and significantly reduces accuracy on bounding-box and segmentation tasks. This PR reflects a modified version of your code that aligns more closely with the paper and has been validated extensively in production at our company. We've also added an einsum implementation that reduces memory use on high-resolution layers.
