Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions api/ccapi/include/layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ enum LayerType {
LAYER_TRANSPOSE = ML_TRAIN_LAYER_TYPE_TRANSPOSE, /**< Transpose Layer type */
LAYER_CHANNEL_SHUFFLE =
ML_TRAIN_LAYER_TYPE_CHANNEL_SHUFFLE, /**< Channel Shuffle Layer type */
LAYER_TOPK = ML_TRAIN_TYPE_TOPK, /**< Topk Layer type */
LAYER_UNKNOWN = ML_TRAIN_LAYER_TYPE_UNKNOWN /**< Unknown */
};

Expand Down Expand Up @@ -705,6 +706,14 @@ Upsample2D(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_UPSAMPLE2D, properties);
}

/**
* @brief Helper function to create Topk layer
*/
inline std::unique_ptr<Layer>
TopkLayer(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_TOPK, properties);
}

/**
* @brief Helper function to create activation layer
*/
Expand Down
1 change: 1 addition & 0 deletions api/nntrainer-api-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ typedef enum {
ML_TRAIN_LAYER_TYPE_CHANNEL_SHUFFLE =
45, /**< Channel Shuffle Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_NEG = 46, /**< Neg Layer type (Since 9.0)*/
ML_TRAIN_TYPE_TOPK = 49, /**< Topk Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP =
300, /**< Preprocess flip Layer (Since 6.5) */
ML_TRAIN_LAYER_TYPE_PREPROCESS_TRANSLATE =
Expand Down
4 changes: 4 additions & 0 deletions nntrainer/app_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
#include <tangent_layer.h>
#include <tensor_layer.h>
#include <time_dist.h>
#include <topk_layer.h>
#include <upsample2d_layer.h>
#include <weight_layer.h>
#include <zoneout_lstmcell.h>
Expand Down Expand Up @@ -373,6 +374,9 @@ void AppContext::add_default_object() {
registerFactory(nntrainer::createLayer<ChannelShuffle>, ChannelShuffle::type,
LayerType::LAYER_CHANNEL_SHUFFLE);

registerFactory(nntrainer::createLayer<TopkLayer>, TopkLayer::type,
LayerType::LAYER_TOPK);

#ifdef ENABLE_NNSTREAMER_BACKBONE
registerFactory(nntrainer::createLayer<NNStreamerLayer>,
NNStreamerLayer::type, LayerType::LAYER_BACKBONE_NNSTREAMER);
Expand Down
10 changes: 10 additions & 0 deletions nntrainer/layers/common_properties.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,16 @@ class Axis : public nntrainer::PositiveIntegerProperty {
bool isValid(const unsigned int &value) const override;
};

/**
* @brief K property, select k elements in topk layer.
*
*/
class K : public nntrainer::PositiveIntegerProperty {
public:
static constexpr const char *key = "k"; /**< unique key to access */
using prop_tag = uint_prop_tag; /**< property type */
};

/**
* @brief StartDimension property, start dimension to be flatten
*
Expand Down
3 changes: 2 additions & 1 deletion nntrainer/layers/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ layer_sources = [
'positional_encoding_layer.cpp',
'identity_layer.cpp',
'upsample2d_layer.cpp',
'channel_shuffle.cpp'
'channel_shuffle.cpp',
'topk_layer.cpp'
]

layer_headers = [
Expand Down
93 changes: 93 additions & 0 deletions nntrainer/layers/topk_layer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2025 Sachin Singh <[email protected]>
*
* @file topk_layer.cpp
* @date 28 July 2025
* @see https://github.com/nnstreamer/nntrainer
* @author Sachin Singh <[email protected]>
* @bug No known bugs except for NYI items
* @brief This is Topk Layer Class for Neural Network
*
*/

#include <layer_context.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <node_exporter.h>
#include <topk_layer.h>
namespace nntrainer {

static constexpr size_t SINGLE_INOUT_IDX = 0;

void TopkLayer::finalize(InitLayerContext &context) {
NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument)
<< "Topk only supports 1 input";

auto &KCount = std::get<props::K>(topk_props);

NNTR_THROW_IF(KCount.empty(), std::invalid_argument)
<< "k value not set in Topk layer";

unsigned int k = KCount.get();

NNTR_THROW_IF(k == 0 || k > context.getInputDimensions()[0].width(),
std::invalid_argument)
<< "k value is invalid in Topk layer. k is " << k
<< ". It should be in range [1," << context.getInputDimensions()[0].width()
<< "]";

TensorDim out_dim = context.getInputDimensions()[0];
TensorDim idx_dim = context.getInputDimensions()[0];

out_dim.width(k);
idx_dim.width(k);

out_dim.setDataType(context.getActivationDataType());
context.setOutputDimensions({out_dim, idx_dim});
}

void TopkLayer::forwarding(RunLayerContext &context, bool training) {

unsigned int k = std::get<props::K>(topk_props).get();

auto [output, indices] = context.getInput(0).topK(k);

context.getOutput(0).copy(output);
context.getOutput(1).copy(indices);
}

void TopkLayer::calcDerivative(RunLayerContext &context) {

auto output = context.getIncomingDerivative(0);
auto indices = context.getOutput(1);

for (unsigned int b = 0; b < output.batch(); ++b) {
for (unsigned int c = 0; c < output.channel(); ++c) {
for (unsigned int h = 0; h < output.height(); ++h) {
for (unsigned int w = 0; w < output.width(); ++w) {

auto u = indices.getValue<uint32_t>(b, c, h, w);
auto val = output.getValue(b, c, h, w);
context.getOutgoingDerivative(0).setValue(b, c, h, u, val);
}
}
}
}
}

void TopkLayer::setProperty(const std::vector<std::string> &values) {
auto remain_props = loadProperties(values, topk_props);
if (!remain_props.empty()) {
std::string msg = "[TopkLayer] Unknown Layer Properties count " +
std::to_string(remain_props.size());
throw exception::not_supported(msg);
}
}

void TopkLayer::exportTo(Exporter &exporter,
const ml::train::ExportMethods &method) const {
exporter.saveResult(topk_props, method, this);
}

} /* namespace nntrainer */
107 changes: 107 additions & 0 deletions nntrainer/layers/topk_layer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2025 Sachin Singh <[email protected]>
*
* @file topk_layer.h
* @date 28 July 2025
* @see https://github.com/nnstreamer/nntrainer
* @author Sachin Singh <[email protected]>
* @bug No known bugs except for NYI items
* @brief This is Topk Layer Class for Neural Network
*
*/

#ifndef __TOPK_LAYER_H__
#define __TOPK_LAYER_H__
#ifdef __cplusplus

#include <common_properties.h>
#include <layer_devel.h>

namespace nntrainer {

/**
* @class Topk Layer
* @brief Topk Layer
*/
class TopkLayer : public Layer {
public:
/**
* @brief Constructor of topk Layer
*/
TopkLayer() : Layer() {}

/**
* @brief Destructor of topk Layer
*/
~TopkLayer() = default;

/**
* @brief Move constructor of TopkLayer.
* @param[in] Topk &&
*/
TopkLayer(TopkLayer &&rhs) noexcept = default;

/**
* @brief Move assignment operator.
* @parma[in] rhs TopkLayer to be moved.
*/
TopkLayer &operator=(TopkLayer &&rhs) = default;

/**
* @copydoc Layer::finalize(InitLayerContext &context)
*/
void finalize(InitLayerContext &context) override;

/**
* @copydoc Layer::forwarding(RunLayerContext &context, bool training)
*/
void forwarding(RunLayerContext &context, bool training) override;

/**
* @copydoc Layer::calcDerivative(RunLayerContext &context)
*/
void calcDerivative(RunLayerContext &context) override;

/**
* @copydoc Layer::setProperty(const std::vector<std::string> &values)
*/
void setProperty(const std::vector<std::string> &values) override;

/**
* @copydoc bool supportBackwarding() const
*/
bool supportBackwarding() const override { return true; };

/**
* @brief Initialize the in-place settings of the layer
* @return InPlaceType
*/
InPlaceType initializeInPlace() override {
is_inplace = true;
return InPlaceType::RESTRICTING;
}

/**
* @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
* method)
*/
void exportTo(Exporter &exporter,
const ml::train::ExportMethods &method) const override;

/**
* @copydoc Layer::getType()
*/
const std::string getType() const override { return TopkLayer::type; };

static constexpr const char *type = "topk";

protected:
std::tuple<props::Print, props::K>
topk_props; /**< topk properties : k for topk */
};

} // namespace nntrainer

#endif /* __cplusplus */
#endif /* __TOPK_LAYER_H__ */
1 change: 1 addition & 0 deletions nntrainer/utils/node_exporter.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ class LossScaleForMixed;
class InPlaceProp;
class InPlaceDirectionProp;
class Exponent;
class K;
} // namespace props

class LayerNode;
Expand Down
Binary file modified packaging/unittest_models_v2.tar.gz
Binary file not shown.
25 changes: 25 additions & 0 deletions test/input_gen/genModelTests_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,19 @@ def forward(self, inputs, labels):

return out, loss

class topkOperation(torch.nn.Module):
def __init__(self,k):
super().__init__()
self.fc=torch.nn.Linear(2,7)
self.k = k
self.loss=torch.nn.MSELoss()

def forward(self,input,labels):
out = self.fc(input[0])
out = torch.topk(out,self.k)
loss = self.loss(out[0],labels[0])
return out, loss


if __name__ == "__main__":
record_v2(
Expand Down Expand Up @@ -1042,3 +1055,15 @@ def forward(self, inputs, labels):

# Function to check the created golden test file
inspect_file("channel_shuffle.nnmodelgolden")

topk_operation = topkOperation(k=4)
record_v2(
topk_operation,
iteration=2,
input_dims=[(2,2)],
input_dtype=[float],
label_dims=[(2,4)],
name="topk_operation"
)

inspect_file("topk_operation.nnmodelgolden")
3 changes: 2 additions & 1 deletion test/unittest/layers/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ test_target = [
'unittest_layers_mol_attention.cpp',
'unittest_layers_multi_head_attention.cpp',
'unittest_layers_positional_encoding.cpp',
'unittest_layers_upsample2d.cpp'
'unittest_layers_upsample2d.cpp',
'unittest_layers_topk.cpp'
]

if get_option('enable-opencl')
Expand Down
23 changes: 23 additions & 0 deletions test/unittest/layers/unittest_layers_topk.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2025 Sachin Singh <[email protected]>
*
* @file unittest_layers_topk.cpp
* @date 28 July 2025
* @brief Topk Layer Test
* @see https://github.com/nnstreamer/nntrainer
* @author Sachin Singh <[email protected]>
* @bug No known bugs except for NYI items
*/
#include <tuple>

#include <gtest/gtest.h>

#include <layers_common_tests.h>
#include <topk_layer.h>

auto semantic_topk_k1 = LayerSemanticsParamType(
nntrainer::createLayer<nntrainer::TopkLayer>, nntrainer::TopkLayer::type,
{"k=1"}, LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1);

GTEST_PARAMETER_TEST(Topk, LayerSemantics, ::testing::Values(semantic_topk_k1));
23 changes: 23 additions & 0 deletions test/unittest/models/unittest_models.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,27 @@ static std::unique_ptr<NeuralNetwork> makeChannelShuffleOperation() {
return nn;
}

std::unique_ptr<NeuralNetwork> makeTopkOperation() {

std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());

auto outer_graph =

makeGraph({{"input", {"name=in", "input_shape=2:2"}},
{"fully_connected", {"name=fc", "unit=7", "input_layers=in"}},
{"topk", {"name=topk_layer", "k=4", "input_layers=fc"}},
{"mse", {"name=loss", "input_layers=topk_layer(0)"}}});

for (auto &node : outer_graph) {
nn->addLayer(node);
}

nn->setProperty({"batch_size=1"});
nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate=0.1"}));

return nn;
}

GTEST_PARAMETER_TEST(
model, nntrainerModelTest,
::testing::ValuesIn({
Expand Down Expand Up @@ -1191,6 +1212,8 @@ GTEST_PARAMETER_TEST(
ModelTestOption::ALL_V2),
mkModelTc_V2(makeChannelShuffleOperation, "channel_shuffle",
ModelTestOption::ALL_V2),
mkModelTc_V2(makeTopkOperation, "topk_operation",
ModelTestOption::COMPARE_V2),
}),
[](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info)
-> const auto & { return std::get<1>(info.param); });
Expand Down