Skip to content

Commit

Permalink
and example
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewldesousa committed Jul 3, 2024
1 parent 750197f commit 28a810a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
19 changes: 12 additions & 7 deletions apps/and.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
#include <iostream>
#include "../backprop.h"
#include <memory>
#include <vector>
#include "../backprop.h"

// move operator
#include <utility>


int main(int argc, char** argv) {
if (argc > 1 && std::string(argv[1]) == "debug") Logger::get_instance().set_debug_mode(true);
else if (argc > 1) throw std::runtime_error("Invalid argument in main function. Use 'debug' to enable debug mode.");

int num_epochs = 10000, num_samples = 4;
int num_epochs = 500000, num_samples = 4;
float learning_rate = 0.001;

// weights shared pointer
Expand Down Expand Up @@ -39,26 +42,28 @@ int main(int argc, char** argv) {
std::shared_ptr<Scalar<double>> loss = Scalar<double>::make(0);
for (int j = 0; j < num_samples; j++) {
auto z = Scalar<double>::make(0);

for (int k = 0; k < weights.size(); k++) {
if (k == 0) z = z + weights[k];
if (k == weights.size() - 1) z = z + weights[k];
else z = z + weights[k] * X[j][k];
}
auto a = sigmoid(z);
loss = cross_entropy(Y[j], a) + loss;
}

loss = loss / Scalar<double>::make(num_samples);
// loss = loss / Scalar<double>::make(num_samples);
loss->backward();

// update weights
for (int j = 0; j < weights.size(); j++) weights[j]->value -= learning_rate * weights[j]->grad;
for (int j = 0; j < weights.size(); j++) {
weights[j]->value -= learning_rate * weights[j]->grad;
weights[j]->grad = 0;
}

if (i % 1000 == 0) Logger::get_instance().log(
"Epoch: " + std::to_string(i) + "/Loss: " + std::to_string(loss->value),
Logger::LogLevel::INFO
);

break;
}

return 0;
Expand Down
2 changes: 2 additions & 0 deletions scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ class Scalar: public std::enable_shared_from_this<Scalar<T>> {
return result;
}



// dont allow inplace operations
Scalar<T>& operator+=(const Scalar<T>& rhs) = delete;
Scalar<T>& operator-=(const Scalar<T>& rhs) = delete;
Expand Down

0 comments on commit 28a810a

Please sign in to comment.