class OpLiteRegistrar {
 public:
  OpLiteRegistrar(const std::string& op_type,
                  std::function<std::shared_ptr<OpLite>()> fun) {
    OpLiteFactory::Global().RegisterCreator(op_type, fun);
  }
  // Touch function is used to guarantee registrar was initialized.
  void touch() {}
};

class KernelFactory {
 public:
  // Register a function to create kernels
  void RegisterCreator(const std::string& op_type,
                       TargetType target,
                       PrecisionType precision,
                       DataLayoutType layout,
                       std::function<std::unique_ptr<KernelBase>()> fun) {
    op_registry_[op_type][std::make_tuple(target, precision, layout)].push_back(
        fun);
  }

  static KernelFactory& Global() {
    static KernelFactory* x = new KernelFactory;
    return *x;
  }

  /**
   * Create all kernels that belong to an op.
   */
  std::list<std::unique_ptr<KernelBase>> Create(const std::string& op_type) {
    std::list<std::unique_ptr<KernelBase>> res;
    if (op_registry_.find(op_type) == op_registry_.end()) return res;
    auto& kernel_registry = op_registry_[op_type];
    for (auto it = kernel_registry.begin(); it != kernel_registry.end(); ++it) {
      for (auto& fun : it->second) {
        res.emplace_back(fun());
      }
    }
    return res;
  }

  /**
   * Create a specific kernel. Returns a list for API compatibility.
   */
  std::list<std::unique_ptr<KernelBase>> Create(const std::string& op_type,
                                                TargetType target,
                                                PrecisionType precision,
                                                DataLayoutType layout) {
    std::list<std::unique_ptr<KernelBase>> res;
    if (op_registry_.find(op_type) == op_registry_.end()) return res;
    auto& kernel_registry = op_registry_[op_type];
    auto it = kernel_registry.find(std::make_tuple(target, precision, layout));
    if (it == kernel_registry.end()) return res;
    for (auto& fun : it->second) {
      res.emplace_back(fun());
    }
    return res;
  }

 protected:
  // Outer map: op -> a map of kernels.
  // Inner map: kernel -> list of creator functions.
  // Each kernel is represented by a combination of <TargetType, PrecisionType,
  // DataLayoutType>.
  std::map<std::string,
           std::map<std::tuple<TargetType, PrecisionType, DataLayoutType>,
                    std::list<std::function<std::unique_ptr<KernelBase>()>>>>
      op_registry_;
};

// Register Kernel by initializing a static KernelRegistrar instance
class KernelRegistrar {
 public:
  KernelRegistrar(const std::string& op_type,
                  TargetType target,
                  PrecisionType precision,
                  DataLayoutType layout,
                  std::function<std::unique_ptr<KernelBase>()> fun) {
    KernelFactory::Global().RegisterCreator(
        op_type, target, precision, layout, fun);
  }
  // Touch function is used to guarantee registrar was initialized.
  void touch() {}
};

class ParamTypeDummyRegistry {
 public:
  struct NewInstance {
    NewInstance() {}
    NewInstance& BindInput(const std::string& arg_name,
                           const ParamType& ptype) {
      return *this;
    }
    NewInstance& BindOutput(const std::string& arg_name,
                            const ParamType& ptype) {
      return *this;
    }
    NewInstance& SetVersion(const std::string& version) { return *this; }
    NewInstance& BindPaddleOpVersion(const std::string& op_type,
                                     int32_t version_id) {
      return *this;
    }
    bool Finalize() { return true; }
  };

 private:
  ParamTypeDummyRegistry() = default;
};
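
The registrar classes above register a creator function with a global factory as a side effect of constructing a static object. The following is a minimal, self-contained sketch of that pattern, not the library's actual code: `DemoKernel`, `DemoFactory`, `DemoRegistrar`, `ReluKernel`, and `relu_registrar` are hypothetical names introduced for illustration, and the factory is keyed by a plain string rather than by the `<TargetType, PrecisionType, DataLayoutType>` tuple used in the listing.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-in for KernelBase; not the real class.
struct DemoKernel {
  virtual ~DemoKernel() = default;
  virtual std::string name() const = 0;
};

// Simplified factory: one creator per string key.
class DemoFactory {
 public:
  using Creator = std::function<std::unique_ptr<DemoKernel>()>;

  // Function-local static: constructed on first use, so it already exists
  // when static registrars run. The instance is never destroyed.
  static DemoFactory& Global() {
    static DemoFactory* instance = new DemoFactory;
    return *instance;
  }

  void RegisterCreator(const std::string& key, Creator fun) {
    assert(fun && "creator must be callable");
    creators_[key] = std::move(fun);
  }

  // Returns nullptr when no creator is registered under `key`.
  std::unique_ptr<DemoKernel> Create(const std::string& key) const {
    auto it = creators_.find(key);
    if (it == creators_.end()) return nullptr;
    return it->second();
  }

 private:
  std::map<std::string, Creator> creators_;
};

// Registers a creator as a side effect of construction.
class DemoRegistrar {
 public:
  DemoRegistrar(const std::string& key, DemoFactory::Creator fun) {
    DemoFactory::Global().RegisterCreator(key, std::move(fun));
  }
  void touch() {}
};

// Hypothetical kernel used only for this example.
struct ReluKernel : DemoKernel {
  std::string name() const override { return "relu"; }
};

// Static instance: its constructor performs the registration.
static DemoRegistrar relu_registrar("relu", [] {
  return std::unique_ptr<DemoKernel>(new ReluKernel);
});
```

With these definitions in one translation unit, `DemoFactory::Global().Create("relu")` yields a `ReluKernel`, while an unregistered key yields `nullptr`. The sketch returns a single object for brevity; the listing above instead returns a `std::list` so that several creators can share one key.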