From 04769f7d3d22433ab34fe3b511b03bf188f1b5d1 Mon Sep 17 00:00:00 2001 From: Andrew Date: Mon, 8 Jul 2024 21:23:45 -0700 Subject: [PATCH] add node_type enum --- scalar.h | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/scalar.h b/scalar.h index b1bf070..0fd338b 100644 --- a/scalar.h +++ b/scalar.h @@ -20,6 +20,14 @@ class Scalar: public std::enable_shared_from_this> { } public: + enum class NodeType { + INPUT, + WEIGHT, + COMPUTED + }; + + + const NodeType node_type; const int id = id_counter++; T value, grad = 0; int in_degrees = 0; @@ -27,8 +35,10 @@ class Scalar: public std::enable_shared_from_this> { std::set>> children; std::function _backward; - Scalar() : Scalar(0) {} - Scalar(T value) : value{value}, in_degrees{0} { _backward = []() {}; } + Scalar() : Scalar(0, NodeType::COMPUTED) {} + Scalar(T value, NodeType node_type) : value{value}, in_degrees{0}, node_type{node_type} { _backward = []() {}; } + Scalar(T value) : Scalar(value, NodeType::COMPUTED) {} + static std::shared_ptr> make(T value) { auto s = std::make_shared>(value);