Skip to content

Commit

Permalink
add node_type enum
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewldesousa committed Jul 9, 2024
1 parent eba97de commit 04769f7
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,25 @@ class Scalar: public std::enable_shared_from_this<Scalar<T>> {
}

public:
enum class NodeType {
INPUT,
WEIGHT,
COMPUTED
};


const NodeType node_type;
const int id = id_counter++;
T value, grad = 0;
int in_degrees = 0;

std::set<std::shared_ptr<Scalar<T>>> children;
std::function<void()> _backward;

Scalar() : Scalar(0) {}
Scalar(T value) : value{value}, in_degrees{0} { _backward = []() {}; }
Scalar() : Scalar(0, NodeType::COMPUTED) {}
Scalar(T value, NodeType node_type) : value{value}, in_degrees{0}, node_type{node_type} { _backward = []() {}; }
Scalar(T value) : Scalar(value, NodeType::COMPUTED) {}


static std::shared_ptr<Scalar<T>> make(T value) {
auto s = std::make_shared<Scalar<T>>(value);
Expand Down

0 comments on commit 04769f7

Please sign in to comment.