template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages
    int Stages,
    /// Used for partial specialization
    typename Enable = bool>
class MmaBase {
 public:
  /// Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape = Shape_;

  /// Policy describing tuning details
  using Policy = Policy_;

  //
  // Dependent types
  //

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Shape describing the overall GEMM computed from shared memory
  /// by each warp
  using WarpGemm = typename Policy::Operator::Shape;

  /// Shape describing the number of warps filling the CTA
  using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
                              Shape::kN / WarpGemm::kN,
                              Shape::kK / WarpGemm::kK>;

  /// Number of warp-level GEMM operations per mainloop iteration
  static int const kWarpGemmIterations =
      (WarpGemm::kK / Operator::Policy::MmaShape::kK);

  /// Number of stages
  static int const kStages = Stages;

  /// Tensor reference to the A operand
  using TensorRefA =
      TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;

  /// Tensor reference to the B operand
  using TensorRefB =
      TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

  static_assert(kWarpGemmIterations > 1,
                "The pipelined structure requires at least two warp-level "
                "GEMM operations.");

  static_assert((kWarpGemmIterations % 2) == 0,
                "Inner loop iteration must be an even number.");

  //
  // Nested structs
  //

  /// Shared storage object needed by threadblock-scoped GEMM
  class SharedStorage {
   public:
    //
    // Type definitions
    //

    /// Shape of the A matrix operand in shared memory
    using ShapeA =
        MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
                    Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;

    /// Shape of the B matrix operand in shared memory
    using ShapeB =
        MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
                    Shape::kN + Policy::SmemPaddingB::kColumn>;

   public:
    //
    // Data members
    //

    /// Buffer for A operand
    AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;

    /// Buffer for B operand
    AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;

   public:
    //
    // Methods
    //

    /// Returns a layout object for the A matrix
    CUTLASS_DEVICE
    static typename Operator::LayoutA LayoutA() {
      return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
    }

    /// Returns a layout object for the B matrix
    CUTLASS_HOST_DEVICE
    static typename Operator::LayoutB LayoutB() {
      return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
    }

    /// Returns a TensorRef to the A operand
    CUTLASS_HOST_DEVICE
    TensorRefA operand_A_ref() {
      return TensorRefA{operand_A.data(), LayoutA()};
    }

    /// Returns a TensorRef to the B operand
    CUTLASS_HOST_DEVICE
    TensorRefB operand_B_ref() {
      return TensorRefB{operand_B.data(), LayoutB()};
    }
  };

 protected:
  //
  // Data members
  //

  /// Iterator to load a warp-scoped tile of A operand from shared memory
  typename Operator::IteratorA warp_tile_iterator_A_;

  /// Iterator to load a warp-scoped tile of B operand from shared memory
  typename Operator::IteratorB warp_tile_iterator_B_;

 public:
  /// Construct from tensor references
  CUTLASS_DEVICE
  MmaBase(
      /// Shared storage needed for internal use by threadblock-scoped GEMM
      SharedStorage &shared_storage,
      /// ID within the threadblock
      int thread_idx,
      /// ID of warp
      int warp_idx,
      /// ID of each thread within a warp
      int lane_idx)
      : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
        warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
};