Skip to content

Commit

Permalink
and example
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewldesousa committed Jun 23, 2024
1 parent 985a88e commit a918e18
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 45 deletions.
70 changes: 30 additions & 40 deletions examples/and.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,65 +4,55 @@


int main() {
// 4x2 dataset maxtrix shared pointer
std::shared_ptr<Scalar<double>> X[4][2];
X[0][0] = Scalar<double>::make(0);
X[0][1] = Scalar<double>::make(0);
X[1][0] = Scalar<double>::make(0);
X[1][1] = Scalar<double>::make(1);
X[2][0] = Scalar<double>::make(1);
X[2][1] = Scalar<double>::make(0);
X[3][0] = Scalar<double>::make(1);
X[3][1] = Scalar<double>::make(1);

// labels shared pointer
std::shared_ptr<Scalar<double>> Y[4];
Y[0] = Scalar<double>::make(0);
Y[1] = Scalar<double>::make(0);
Y[2] = Scalar<double>::make(0);
Y[3] = Scalar<double>::make(1);

// weights shared pointer
std::shared_ptr<Scalar<double>> w1 = Scalar<double>::make(-.09);
std::shared_ptr<Scalar<double>> w2 = Scalar<double>::make(.02);
std::shared_ptr<Scalar<double>> b = Scalar<double>::make(0);

int num_epochs = 10000, num_samples = 4;
float learning_rate = 0.001;
float learning_rate = 0.0001;
auto& graph = ComputationalGraph<double>::get_instance();

for (int i = 0; i < num_epochs; i++) {
std::shared_ptr<Scalar<double>> loss = std::make_shared<Scalar<double>>(0);
graph.clear();

std::shared_ptr<Scalar<double>> X[4][2];
X[0][0] = Scalar<double>::make(0);
X[0][1] = Scalar<double>::make(0);
X[1][0] = Scalar<double>::make(0);
X[1][1] = Scalar<double>::make(1);
X[2][0] = Scalar<double>::make(1);
X[2][1] = Scalar<double>::make(0);
X[3][0] = Scalar<double>::make(1);
X[3][1] = Scalar<double>::make(1);

// labels shared pointer
std::shared_ptr<Scalar<double>> Y[4];
Y[0] = Scalar<double>::make(0);
Y[1] = Scalar<double>::make(0);
Y[2] = Scalar<double>::make(0);
Y[3] = Scalar<double>::make(1);


std::shared_ptr<Scalar<double>> loss = std::make_shared<Scalar<double>>(0);
for (int j = 0; j < num_samples; j++) {
// forward
auto z = w1 * X[j][0];
loss = z;

break;
auto z = w1 * X[j][0] + w2 * X[j][1] + b;
auto a = sigmoid(z);
loss = cross_entropy(Y[j], a) + loss;
}

std::cout << "Epoch: " << i << " Loss: " << loss->value << std::endl;
loss->backward();

// print updating weights

// update weights
w1->value = w1->value - learning_rate * w1->grad;
w2->value = w2->value - learning_rate * w2->grad;
b->value = b->value - learning_rate * b->grad;

// print grads
// printing

std::cout << "Printing grads:\n";

std::cout << "w1 grad: " << w1->grad << std::endl;
std::cout << "w2 grad: " << w2->grad << std::endl;
std::cout << "b grad: " << b->grad << std::endl;

// reset gradients
w1->grad = 0;
w2->grad = 0;
b->grad = 0;
// print epoch loss
if (i % 1000 == 0) {
std::cout << "Epoch: " << i << " Loss: " << loss->value << std::endl;
}
}

return 0;
Expand Down
14 changes: 9 additions & 5 deletions scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@ class ComputationalGraph {
return instance;
}

void backward() {
for (auto node : nodes) { if (node->in_degrees == 0) { node->backward(); } }
}

void add_node(std::shared_ptr<Scalar<T>> node) { nodes.insert(node); }

void clear() { nodes.clear(); }
void clear() {
for (auto node : nodes) {
node->grad = 0;
node->in_degrees = 0;
node->children.clear();
}

nodes.clear();
}
};


Expand Down

0 comments on commit a918e18

Please sign in to comment.