@@ -355,7 +355,7 @@ int main(int argc, char **argv) {
355355 * @brief get action with input State with mainNet
356356 */
357357 nntrainer::Tensor in_tensor;
358- nntrainer::sharedTensor test;
358+ nntrainer::sharedConstTensor test;
359359 try {
360360 in_tensor = nntrainer::Tensor ({input});
361361 } catch (...) {
@@ -372,7 +372,7 @@ int main(int argc, char **argv) {
372372 targetNet.finalize ();
373373 return 0 ;
374374 }
375- float *data = test->getData ();
375+ const float *data = test->getData ();
376376 unsigned int len = test->getDim ().getDataLen ();
377377 std::vector<float > temp (data, data + len);
378378 action.push_back (argmax (temp));
@@ -474,7 +474,7 @@ int main(int argc, char **argv) {
474474 /* *
475475 * @brief run forward propagation with mainNet
476476 */
477- nntrainer::sharedTensor Q;
477+ nntrainer::sharedConstTensor Q;
478478 try {
479479 Q = mainNet.forwarding (MAKE_SHARED_TENSOR (q_in));
480480 } catch (...) {
@@ -487,7 +487,7 @@ int main(int argc, char **argv) {
487487 /* *
488488 * @brief run forward propagation with targetNet
489489 */
490- nntrainer::sharedTensor NQ;
490+ nntrainer::sharedConstTensor NQ;
491491 try {
492492 NQ = targetNet.forwarding (MAKE_SHARED_TENSOR (nq_in));
493493 } catch (...) {
@@ -496,22 +496,23 @@ int main(int argc, char **argv) {
496496 targetNet.finalize ();
497497 return -1 ;
498498 }
499- float *nqa = NQ->getData ();
499+ const float *nqa = NQ->getData ();
500500
501501 /* *
502502 * @brief Update Q values & udpate mainNetwork
503503 */
504+ nntrainer::Tensor tempQ = *Q;
504505 for (unsigned int i = 0 ; i < in_Exp.size (); i++) {
505506 if (in_Exp[i].done ) {
506- Q-> setValue (i, 0 , 0 , (int )in_Exp[i].action [0 ],
507- (float )in_Exp[i].reward );
507+ tempQ. setValue (i, 0 , 0 , (int )in_Exp[i].action [0 ],
508+ (float )in_Exp[i].reward );
508509 } else {
509510 float next = (nqa[i * NQ->getWidth ()] > nqa[i * NQ->getWidth () + 1 ])
510511 ? nqa[i * NQ->getWidth ()]
511512 : nqa[i * NQ->getWidth () + 1 ];
512513 try {
513- Q-> setValue (i, 0 , 0 , (int )in_Exp[i].action [0 ],
514- (float )in_Exp[i].reward + DISCOUNT * next);
514+ tempQ. setValue (i, 0 , 0 , (int )in_Exp[i].action [0 ],
515+ (float )in_Exp[i].reward + DISCOUNT * next);
515516 } catch (...) {
516517 std::cerr << " Error during set value" << std::endl;
517518 mainNet.finalize ();
0 commit comments