Fix Inv2d contraction and align implementation with Involution paper #18
## Summary

This PR fixes the core contraction in `Inv2d` and aligns the implementation more closely with the Involution paper. It also adds basic validation and a small stabilization step in the kernel-generator path.

## What was wrong before
The original implementation computed `kernel @ x_unfolded`, where:

- `kernel` had shape `(B, groups, 1, K, H, W)`
- `x_unfolded` had shape `(B, groups, group_ch, K, H, W)`

`torch.matmul` always treats the last two dimensions as matrix dimensions, so this effectively multiplied over the spatial dimensions `(H, W)` instead of contracting over the kernel dimension `K` (where `K = kernel_size * kernel_size`). As a result, the code only worked when `H == W` (square feature maps), because matmul required those spatial dimensions to match, and even then it mixed spatial positions rather than aggregating over the local kernel window.

The code also assumed `channels % group_ch == 0` when reshaping, but never validated this, making it easy to misconfigure.

The kernel generator (`reduce` → `span`) had no normalization or nonlinearity, which makes stacked dynamic kernels brittle during training.
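To make the failure concrete, here is a minimal reproduction of the old behavior, using the shapes listed above (the dimension sizes are illustrative):

```python
import torch

B, groups, group_ch, K = 2, 4, 16, 9
H = W = 8  # the old code only ran when H == W

kernel = torch.randn(B, groups, 1, K, H, W)
x_unfolded = torch.randn(B, groups, group_ch, K, H, W)

# torch.matmul treats the LAST TWO dims as matrix dims, so this multiplies
# (H, W) @ (H, W) spatial matrices and merely broadcasts over K -- the
# kernel dimension is never contracted.
wrong = kernel @ x_unfolded
print(wrong.shape)  # torch.Size([2, 4, 16, 9, 8, 8]) -- K survives, spatial mixed

# With a rectangular feature map the old code crashes outright:
try:
    torch.randn(B, groups, 1, K, 8, 10) @ torch.randn(B, groups, group_ch, K, 8, 10)
except RuntimeError as exc:
    print("rectangular maps fail:", exc)
```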
## What this PR changes

Replaces the incorrect `kernel @ x_unfolded` with a proper contraction over the kernel dimension using `torch.einsum`, without mixing spatial dimensions (a runnable sketch follows this list):

- The generated kernel is reshaped to `(B, groups, K, H_out, W_out)`.
- The unfolded input is reshaped to `(B, groups, group_ch, K, H_out, W_out)`.
- The contraction is `einsum("bgkij,bgckij->bgcij", kernel, patches)`, and the result is reshaped back to `(B, C, H_out, W_out)`.
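A minimal sketch of the new contraction under the shapes above (the dimension sizes are illustrative):

```python
import torch

B, C, groups, group_ch = 2, 64, 4, 16  # C == groups * group_ch
K, H_out, W_out = 9, 8, 10             # K = kernel_size * kernel_size

kernel = torch.randn(B, groups, K, H_out, W_out)
patches = torch.randn(B, groups, group_ch, K, H_out, W_out)

# Contract ONLY over the kernel dimension k; b, g, i, j are carried through
# and the channel dimension c comes from the patches.
out = torch.einsum("bgkij,bgckij->bgcij", kernel, patches)
out = out.reshape(B, C, H_out, W_out)  # reshape: einsum output may be non-contiguous
print(out.shape)  # torch.Size([2, 64, 8, 10]) -- rectangular maps now work
```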
Adds a lightweight "sigma" mapping in the kernel generator: `reduce` (1x1 conv) → `BatchNorm2d` → `ReLU` → `span` (1x1 conv).
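For illustration, the generator path sketched as a standalone `nn.Sequential`; the actual module structure and names in the repo may differ, and the `kernel_size**2 * groups` output width is the layout implied by the kernel shape above:

```python
import torch.nn as nn

channels, red_ratio, groups, kernel_size = 64, 2, 4, 3  # illustrative sizes
reduced = channels // red_ratio

kernel_gen = nn.Sequential(
    nn.Conv2d(channels, reduced, kernel_size=1),                 # reduce (1x1 conv)
    nn.BatchNorm2d(reduced),                                     # sigma: BatchNorm2d
    nn.ReLU(inplace=True),                                       # sigma: ReLU
    nn.Conv2d(reduced, kernel_size**2 * groups, kernel_size=1),  # span (1x1 conv)
)
```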
Validates that (see the sketch below):

- the input tensor's channel count matches the `channels` passed to `Inv2d`.
- `channels` is divisible by `group_ch`.
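A sketch of what these checks might look like; the helper names `_check_config` and `_check_input` are hypothetical, not the PR's actual code:

```python
import torch

def _check_config(channels: int, group_ch: int) -> None:
    # Run once in __init__: groups = channels // group_ch must be exact.
    if channels % group_ch != 0:
        raise ValueError(
            f"channels ({channels}) must be divisible by group_ch ({group_ch})"
        )

def _check_input(x: torch.Tensor, channels: int) -> None:
    # Run in forward(): the input must carry the configured channel count.
    if x.shape[1] != channels:
        raise ValueError(f"expected {channels} input channels, got {x.shape[1]}")
```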
Uses `reshape` instead of `view` where needed so the code works correctly with non-contiguous tensors (for example, tensors produced by `einsum`).
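For context, a tiny demonstration of why `view` is unsafe here:

```python
import torch

t = torch.randn(2, 3, 4).permute(0, 2, 1)  # non-contiguous, like many einsum results
try:
    t.view(2, 12)       # view requires a compatible (contiguous) memory layout
except RuntimeError as exc:
    print("view fails:", exc)
u = t.reshape(2, 12)    # reshape copies when necessary, so it always succeeds
```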
## Behavior and API

The public API of `Inv2d` is unchanged:

- Constructor: `Inv2d(channels, kernel_size, stride, group_ch=16, red_ratio=2, **kwargs)`
- Input: `(B, C, H, W)`
- Output: `(B, C, H_out, W_out)`, with the same stride and padding behavior as before.

Existing code that imports and constructs `Inv2d` should continue to work, but the operation is now a correct contraction over the kernel dimension, so outputs will differ numerically from the old version, and the added normalization should make training more stable when stacking multiple `Inv2d` layers.
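Illustrative usage under the unchanged API (the import path is repo-specific and not spelled out here):

```python
import torch
# from involution import Inv2d  # hypothetical import path; depends on repo layout

layer = Inv2d(channels=64, kernel_size=3, stride=1, group_ch=16, red_ratio=2)
x = torch.randn(2, 64, 32, 32)  # (B, C, H, W)
y = layer(x)                    # (B, C, H_out, W_out)
print(y.shape)  # (2, 64, 32, 32) if padding preserves spatial size at stride=1
```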