From 2312458e06294ea49acf72915d7129c28794bbc4 Mon Sep 17 00:00:00 2001 From: Ethan Ng Date: Thu, 2 Jul 2026 18:46:38 -0700 Subject: [PATCH] Fix generic im2row NHWC layout to match Python reference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: The generic Cadence `im2row` kernel wrote its NHWC (`channel_last=True`) output in kernel-position-major order `[kp][c]`, but the operator contract — defined by the Python reference in `ref_implementations.py` via `torch.nn.functional.unfold` — is channel-major `[c][kp]`, i.e. column index `c*(kH*kW) + kh*kW + kw`. The conv-lowering pass `ReplaceConvWithIm2RowAndLinear` packs the matmul weights in the same `[c][kp]` order (permute `[OC,kH,kW,IC]` -> `[OC,IC,kH,kW]` -> `[OC,K]`), so the generic kernel's `[kp][c]` output was transposed relative to the weights. This rewrites the generic NHWC branch to write `[c][kp]` (per-channel scatter `data_col[i_col*channels_col + c*num_kp + kp]`) Differential Revision: D110508326 --- .../cadence/generic/operators/op_im2row.cpp | 32 ++++++++----------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/backends/cadence/generic/operators/op_im2row.cpp b/backends/cadence/generic/operators/op_im2row.cpp index 8c939c7ad5c..cbfd4499535 100644 --- a/backends/cadence/generic/operators/op_im2row.cpp +++ b/backends/cadence/generic/operators/op_im2row.cpp @@ -8,8 +8,6 @@ #include -#include - #include #ifndef DISABLE_ALWAYS_INLINE @@ -59,34 +57,32 @@ ALWAYS_INLINE void im2row_( // array of size (out_height * out_width) x channels_col const int32_t channels_col = channels * kernel_h * kernel_w; - // If the layout is NHWC, we can copy 'channels' worth of contiguous data - // points when performing im2row. + // If the layout is NHWC, the input data is contiguous per-pixel (H, W, C). + // The output layout must match torch.nn.functional.unfold, which is [c][kp]: + // output[c * num_kp + kp] for each output position. if (channels_last) { + const int32_t num_kp = kernel_h * kernel_w; // Iterate over the output domain for (int _h = 0; _h < out_height; ++_h) { for (int _w = 0; _w < out_width; ++_w) { int32_t i_col = _h * out_width + _w; - // Each point in the output domain is the result of applying a filter of - // size kernel_h x kernel_w x channels on the input. But since channels - // is contiguous, we will not explicitly have a loop for it. for (int _kh = 0; _kh < kernel_h; ++_kh) { int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; for (int _kw = 0; _kw < kernel_w; ++_kw) { int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; + int32_t kp = _kh * kernel_w + _kw; - // h_im and w_im are the actual height and width coordinates of the - // input tensor from where we need to copy 'channels' points. - const T* __restrict__ slice_im = - data_im + (h_im * width + w_im) * channels; - T* __restrict__ slice_col = data_col + i_col * channels_col + - (_kh * kernel_w + _kw) * channels; - // If the coordinates were within the input domain, we copy - // 'channels' contiguous values. Otherwise we will fill the output - // with 0's. if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { - memcpy(slice_col, slice_im, channels * sizeof(T)); + const T* __restrict__ pixel = + data_im + (h_im * width + w_im) * channels; + for (int _c = 0; _c < channels; ++_c) { + data_col[i_col * channels_col + _c * num_kp + kp] = pixel[_c]; + } } else { - std::fill_n(slice_col, channels, T(in_zero_point)); + for (int _c = 0; _c < channels; ++_c) { + data_col[i_col * channels_col + _c * num_kp + kp] = + static_cast(in_zero_point); + } } } }