CUTLASS GEMM

CUTLASS code

https://github.com/ginsongsong/CUDA_cutlass_benchmark/blob/master/allgemm.cu

Code walkthrough

The first piece is the MmaPolicy struct, which bundles the warp-level GEMM operator, the shared-memory padding for the A and B operands, and the number of K-dimension partitions into a single compile-time policy:
/// Policy object describing MmaTensorOp
template <
    /// Warp-level GEMM operator (concept: gemm::warp::Mma)
    typename Operator_,
    /// Padding used for A operand in shared memory (concept: MatrixShape)
    typename SmemPaddingA_,
    /// Padding used for B operand in shared memory (concept: MatrixShape)
    typename SmemPaddingB_,
    /// Number of partitions of K dimension of GEMM
    int PartitionsK = 1>
struct MmaPolicy {
  /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt)
  using Operator = Operator_;

  /// Padding used for A operand in shared memory
  using SmemPaddingA = SmemPaddingA_;

  /// Padding used for B operand in shared memory
  using SmemPaddingB = SmemPaddingB_;

  /// Number of partitions of K dimension
  static int const kPartitionsK = PartitionsK;
};
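To see how such a policy is consumed, here is a minimal standalone sketch. DummyWarpMma and the MatrixShape stub are hypothetical stand-ins for the real CUTLASS types, and the one-column padding on A is just an illustrative choice:

#include <cstdio>

// Hypothetical stand-ins for CUTLASS types, just to exercise the pattern.
template <int Row, int Column>
struct MatrixShape {
  static int const kRow = Row;
  static int const kColumn = Column;
};

struct DummyWarpMma {};  // placeholder for e.g. gemm::warp::MmaSimt

// MmaPolicy exactly as defined above.
template <typename Operator_, typename SmemPaddingA_,
          typename SmemPaddingB_, int PartitionsK = 1>
struct MmaPolicy {
  using Operator = Operator_;
  using SmemPaddingA = SmemPaddingA_;
  using SmemPaddingB = SmemPaddingB_;
  static int const kPartitionsK = PartitionsK;
};

// A policy with one extra padding column on the A operand.
using Policy = MmaPolicy<DummyWarpMma, MatrixShape<0, 1>, MatrixShape<0, 0>>;

int main() {
  printf("A padding: %d x %d, partitions of K: %d\n",
         Policy::SmemPaddingA::kRow, Policy::SmemPaddingA::kColumn,
         Policy::kPartitionsK);
  return 0;
}

The padding members exist because offsetting each row of a shared-memory tile by a few elements is the standard way to avoid bank conflicts when warps read the operand back out.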
The second piece is MmaBase, the threadblock-scoped base class from which the pipelined and multistage mainloops derive. It computes the warp count and iteration counts from the tile shapes, declares the shared-memory buffers for the A and B operands, and holds the warp-level iterators that read tiles back out of shared memory:
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages
    int Stages,
    /// Used for partial specialization
    typename Enable = bool>
class MmaBase {
 public:
  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape = Shape_;

  ///< Policy describing tuning details
  using Policy = Policy_;

  //
  // Dependent types
  //

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Shape describing the overall GEMM computed from shared memory
  /// by each warp.
  using WarpGemm = typename Policy::Operator::Shape;

  /// Shape describing the number of warps filling the CTA
  using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
                              Shape::kN / WarpGemm::kN,
                              Shape::kK / WarpGemm::kK>;

  /// Number of warp-level GEMM operations
  static int const kWarpGemmIterations =
      (WarpGemm::kK / Operator::Policy::MmaShape::kK);

  /// Number of stages
  static int const kStages = Stages;

  /// Tensor reference to the A operand
  using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;

  /// Tensor reference to the B operand
  using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

  static_assert(kWarpGemmIterations > 1,
                "The pipelined structure requires at least two warp-level "
                "GEMM operations.");

  static_assert((kWarpGemmIterations % 2) == 0,
                "Inner loop iteration must be an even number.");

  //
  // Nested structs
  //

  /// Shared storage object needed by threadblock-scoped GEMM
  class SharedStorage {
   public:
    //
    // Type definitions
    //

    /// Shape of the A matrix operand in shared memory
    using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
                               Shape::kK * kStages +
                                   Policy::SmemPaddingA::kColumn>;

    /// Shape of the B matrix operand in shared memory
    using ShapeB =
        MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
                    Shape::kN + Policy::SmemPaddingB::kColumn>;

   public:
    //
    // Data members
    //

    /// Buffer for A operand
    AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;

    /// Buffer for B operand
    AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;

   public:

    //
    // Methods
    //

    /// Returns a layout object for the A matrix
    CUTLASS_DEVICE
    static typename Operator::LayoutA LayoutA() {
      return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
    }

    /// Returns a layout object for the B matrix
    CUTLASS_HOST_DEVICE
    static typename Operator::LayoutB LayoutB() {
      return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
    }

    /// Returns a TensorRef to the A operand
    CUTLASS_HOST_DEVICE
    TensorRefA operand_A_ref() {
      return TensorRefA{operand_A.data(), LayoutA()};
    }

    /// Returns a TensorRef to the B operand
    CUTLASS_HOST_DEVICE
    TensorRefB operand_B_ref() {
      return TensorRefB{operand_B.data(), LayoutB()};
    }
  };

 protected:

  //
  // Data members
  //

  /// Iterator to load a warp-scoped tile of A operand from shared memory
  typename Operator::IteratorA warp_tile_iterator_A_;

  /// Iterator to load a warp-scoped tile of B operand from shared memory
  typename Operator::IteratorB warp_tile_iterator_B_;

 public:

  /// Construct from tensor references
  CUTLASS_DEVICE
  MmaBase(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      SharedStorage &shared_storage,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx
  ):
      warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
      warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {

  }
};
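The interesting part of MmaBase is the compile-time arithmetic: WarpCount divides the threadblock tile by the warp tile, kWarpGemmIterations divides the warp tile's K extent by the instruction's K extent, and ShapeA/ShapeB stretch the K dimension by kStages so that every pipeline stage lives in one contiguous buffer. Below is a minimal standalone sketch of that arithmetic, using a simplified GemmShape and hypothetical tile sizes (128x128x32 threadblock tile, 64x64x32 warp tile, instruction K of 8):

#include <cstdio>

// Simplified stand-in for cutlass::gemm::GemmShape, for illustration only.
template <int M, int N, int K>
struct GemmShape {
  static int const kM = M;
  static int const kN = N;
  static int const kK = K;
};

// Hypothetical tile sizes: a 128x128x32 threadblock tile computed by
// 64x64x32 warp tiles, with a warp-level instruction covering K = 8.
using ThreadblockShape = GemmShape<128, 128, 32>;
using WarpGemm         = GemmShape<64, 64, 32>;
static int const kInstructionK = 8;

// The same arithmetic MmaBase performs on its template parameters.
using WarpCount = GemmShape<ThreadblockShape::kM / WarpGemm::kM,   // 2
                            ThreadblockShape::kN / WarpGemm::kN,   // 2
                            ThreadblockShape::kK / WarpGemm::kK>;  // 1

static int const kWarpGemmIterations = WarpGemm::kK / kInstructionK;  // 4

// Mirrors the two static_asserts in MmaBase.
static_assert(kWarpGemmIterations > 1, "pipeline needs >= 2 iterations");
static_assert(kWarpGemmIterations % 2 == 0, "iterations must be even");

int main() {
  printf("warps per CTA: %d x %d x %d = %d\n",
         WarpCount::kM, WarpCount::kN, WarpCount::kK,
         WarpCount::kM * WarpCount::kN * WarpCount::kK);
  printf("warp-level GEMM iterations per mainloop step: %d\n",
         kWarpGemmIterations);
  return 0;
}

With these numbers the CTA holds 2 x 2 x 1 = 4 warps (128 threads), and each mainloop step issues 4 warp-level GEMMs, satisfying both static_asserts in MmaBase.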