1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
// Warp-wide tensor-core MMA atom computing D = A * B + C with
//   A : 16x8 f16 (row), B : 8x8 f16 (col), C and D : 16x8 f32,
// issued as one `mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32` instruction.
// Because the instruction is `.sync.aligned`, all 32 lanes of the warp must
// execute it together.
//
// Per-lane register fragments (see the *Registers aliases):
//   D, C : 4 x f32
//   A    : 2 x 32-bit registers, each packing two f16 values (16*8 / 32 = 4 halves)
//   B    : 1 x 32-bit register packing two f16 values        ( 8*8 / 32 = 2 halves)
struct SM80_16x8x8_F32F16F16F32_TN
{
  using DRegisters = float[4];
  using ARegisters = uint32_t[2];
  using BRegisters = uint32_t[1];
  using CRegisters = float[4];

  // d0..d3 : output accumulator fragment (written by the instruction).
  // a0, a1 : A fragment registers.
  // b0     : B fragment register.
  // c0..c3 : input accumulator fragment (read by the instruction).
  CUTE_HOST_DEVICE static void
  fma(float         & d0, float         & d1, float         & d2, float         & d3,
      uint32_t const& a0, uint32_t const& a1,
      uint32_t const& b0,
      float const   & c0, float const   & c1, float const   & c2, float const   & c3)
  {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
    // PTX only supports mma.sync m16n8k8 with .f16 inputs on sm_75 and newer.
    // Emit it only in device passes that can assemble it.
    asm volatile(
      "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
      "{%0, %1, %2, %3},"
      "{%4, %5},"
      "{%6},"
      "{%7, %8, %9, %10};\n"
      : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
      :  "r"(a0),  "r"(a1),
         "r"(b0),
         "f"(c0),  "f"(c1),  "f"(c2),  "f"(c3));
#else
    // The host pass and device passes below sm_75 (e.g. in a multi-arch build)
    // have no tensor-core path. Emitting the PTX there would make the
    // assembler reject the translation unit, so it is compiled out here.
    // On the host this leaves d0..d3 untouched.
    (void)d0; (void)d1; (void)d2; (void)d3;
    (void)a0; (void)a1; (void)b0;
    (void)c0; (void)c1; (void)c2; (void)c3;
#  if defined(__CUDA_ARCH__)
    // Reaching this on an unsupported device arch is a dispatch bug. Abort the
    // kernel loudly rather than hand back an unwritten accumulator.
    __trap();
#  endif
#endif
  }
};
// Thread-value (TV) layouts referenced as ALayout/BLayout/CLayout in the SM80
// MMA_Traits specializations. Every alias has the form
// Layout<Shape<ThrShape, ValShape>, Stride<ThrStride, ValStride>>:
//   * the first mode, Shape<_4,_8>, has 4*8 = 32 elements, one per warp lane
//     (lane = t0 + 4*t1 with t0 = lane % 4, t1 = lane / 4, following CuTe's
//     colexicographic coordinate order);
//   * the second mode enumerates the values each lane holds.
// Each layout maps (lane, value) one-to-one onto [0, 32 * number_of_values).
// NOTE(review): the alias names suggest this index is read as a column-major
// MxN tile position; confirm against the atom documentation before relying on
// a particular orientation.
// 32 lanes x 1 value -> [0, 32): index = 8*(lane%4) + lane/4 (value stride 0: single value).
using SM80_8x4 = Layout<Shape <Shape < _4,_8>,_1>, Stride<Stride< _8,_1>,_0>>;
// 32 lanes x 2 values -> [0, 64): index = 16*(lane%4) + lane/4 + 8*v.
using SM80_8x8_Row = Layout<Shape <Shape < _4,_8>,_2>, Stride<Stride<_16,_1>,_8>>;
// 32 lanes x 4 values -> [0, 128): index = 32*(lane%4) + lane/4 + 8*v.
using SM80_8x16_Row = Layout<Shape <Shape < _4,_8>,_4>, Stride<Stride<_32,_1>,_8>>;
// 32 lanes x (2,2) values -> [0, 128): index = 32*(lane%4) + lane/4 + 16*v0 + 8*v1.
// Read as a column-major 16x8 tile (index = m + 16*n), lane L holds
// rows L/4 and L/4 + 8 (selected by v1) in columns 2*(L%4) and 2*(L%4) + 1
// (selected by v0).
using SM80_16x8_Row = Layout<Shape <Shape < _4,_8>,Shape < _2,_2>>, Stride<Stride<_32,_1>,Stride<_16,_8>>>;
// Traits of the all-half-precision m16n8k8 tensor-core atom (f16 inputs and
// f16 accumulators). They record the tile extents, the participating threads,
// the element type of every operand, and the per-lane (thread, value) layouts
// of the A, B and C operands.
template <>
struct MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
{
  // One instruction covers an (M, N, K) = (16, 8, 8) tile using 32 threads.
  using Shape_MNK = Shape<_16,_8,_8>;
  using ThrID     = Layout<_32>;

  // Every operand is half precision.
  using ValTypeA = half_t;
  using ValTypeB = half_t;
  using ValTypeC = half_t;
  using ValTypeD = half_t;

  // A and C share the 16x8 thread-value layout; B uses the 8x8 one.
  using ALayout = SM80_16x8_Row;
  using BLayout = SM80_8x8_Row;
  using CLayout = SM80_16x8_Row;
};
// Mixed-precision variant of the m16n8k8 atom. The tile shape, thread set and
// A/B/C layouts are inherited unchanged from the all-f16 traits; only the
// accumulator element types are widened to float.
template <>
struct MMA_Traits<SM80_16x8x8_F32F16F16F32_TN>
  : MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
{
  // Multiplicands remain half precision...
  using ValTypeA = half_t;
  using ValTypeB = half_t;
  // ...while the accumulator input (C) and output (D) are single precision.
  using ValTypeC = float;
  using ValTypeD = float;
};
|