diff --git a/scalar.h b/scalar.h index b1bf070..0fd338b 100644 --- a/scalar.h +++ b/scalar.h @@ -20,6 +20,14 @@ class Scalar: public std::enable_shared_from_this> { } public: + enum class NodeType { + INPUT, + WEIGHT, + COMPUTED + }; + + + const NodeType node_type; const int id = id_counter++; T value, grad = 0; int in_degrees = 0; @@ -27,8 +35,10 @@ class Scalar: public std::enable_shared_from_this> { std::set>> children; std::function _backward; - Scalar() : Scalar(0) {} - Scalar(T value) : value{value}, in_degrees{0} { _backward = []() {}; } + Scalar() : Scalar(0, NodeType::COMPUTED) {} + Scalar(T value, NodeType node_type) : value{value}, in_degrees{0}, node_type{node_type} { _backward = []() {}; } + Scalar(T value) : Scalar(value, NodeType::COMPUTED) {} + static std::shared_ptr> make(T value) { auto s = std::make_shared>(value);