1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
/*
 * Output extent of a convolution along one spatial axis:
 *
 *   out = (in_size + 2*pad - dilation*(ksize - 1) - 1) / stride + 1
 *
 * in_size  : input extent along the axis
 * pad      : zero padding added on EACH side
 * dilation : spacing between kernel taps (1 = dense kernel)
 * ksize    : number of kernel taps
 * stride   : step between consecutive output positions
 *
 * Returns 0 when the dilated kernel does not fit inside the padded input,
 * or when ksize or stride is 0. In those cases the plain formula would wrap
 * around in size_t arithmetic (or divide by zero) and yield a huge bogus size.
 */
static inline size_t naive_conv_out_size(size_t in_size, size_t pad, size_t dilation, size_t ksize, size_t stride)
{
    if (ksize == 0 || stride == 0)
        return 0;

    size_t padded = in_size + 2 * pad;
    /* Receptive field (in input elements) covered by one dilated kernel. */
    size_t span = dilation * (ksize - 1) + 1;

    if (span > padded)
        return 0;

    return (padded - span) / stride + 1;
}
/*
 * Naive (reference) grouped 2-D forward convolution in NCHW layout.
 *
 * Tensor layouts (row-major, innermost dimension last):
 *   src    : [n][c][h][w]
 *   filter : [k][c / group][fy][fx]
 *   dst    : [n][k][oh][ow], oh/ow as returned by naive_conv_out_size()
 *
 * Parameters:
 *   n, c, h, w : batch size, input channels, input height, input width
 *   k          : output channels
 *   fx, fy     : filter width / height
 *   px, py     : zero padding along width / height (applied on both sides)
 *   sx, sy     : stride along width / height (must be >= 1)
 *   dx, dy     : dilation along width / height
 *   group      : number of channel groups; c and k must both be divisible
 *                by it. Output channel ok reads only input channels
 *                [ (ok / (k/group)) * (c/group), +c/group ).
 *
 * Every element of dst is overwritten. Taps that land in the padding
 * region are skipped, i.e. they contribute zero.
 */
static inline void naive_conv_fwd_nchw(const float *src, const float *filter, float *dst,
                                       size_t n, size_t w, size_t h, size_t c, size_t k,
                                       size_t fx, size_t fy, size_t px, size_t py,
                                       size_t sx, size_t sy, size_t dx, size_t dy,
                                       size_t group)
{
    assert((group >= 1) && (c % group == 0) && (k % group == 0));
    /* Validate strides before naive_conv_out_size() divides by them. */
    assert((sx >= 1) && (sy >= 1));

    const size_t oh = naive_conv_out_size(h, py, dy, fy, sy);
    const size_t ow = naive_conv_out_size(w, px, dx, fx, sx);
    const size_t k_per_group = k / group;
    const size_t c_per_group = c / group;

    for (size_t ig = 0; ig < group; ig++) {
        for (size_t in = 0; in < n; in++) {
            /* First input channel of group ig for batch element in. */
            const float *src_g = src + (in * c + ig * c_per_group) * h * w;

            for (size_t ik = 0; ik < k_per_group; ik++) {
                const size_t ok = ig * k_per_group + ik; /* absolute output channel */
                const float *filter_k = filter + ok * c_per_group * fy * fx;
                float *dst_k = dst + (in * k + ok) * oh * ow;

                for (size_t ioh = 0; ioh < oh; ioh++) {
                    for (size_t iow = 0; iow < ow; iow++) {
                        /* Sliding-window dot product for this output pixel. */
                        float value = 0.0f;

                        for (size_t ic = 0; ic < c_per_group; ic++) {
                            for (size_t ir = 0; ir < fy; ir++) {
                                /* Row in the padded image. Compare against the padding
                                 * before subtracting so the size_t value never wraps. */
                                const size_t ph = sy * ioh + dy * ir;
                                if (ph < py || ph - py >= h)
                                    continue;
                                const size_t cur_h = ph - py;

                                for (size_t is = 0; is < fx; is++) {
                                    const size_t pw = sx * iow + dx * is;
                                    if (pw < px || pw - px >= w)
                                        continue;
                                    const size_t cur_w = pw - px;

                                    value += src_g[(ic * h + cur_h) * w + cur_w] *
                                             filter_k[(ic * fy + ir) * fx + is];
                                }
                            }
                        }
                        dst_k[ioh * ow + iow] = value;
                    }
                }
            }
        }
    }
}
// group = 1 (ordinary, ungrouped convolution): call naive_conv_fwd_nchw() with
// group = 1. Then c_per_group == c and k_per_group == k, so the grouped routine
// already covers this case and no separate implementation is needed.
|