Skip to content

Commit 3f187ea

Browse files
Akshaya Purohitcopybara-github
authored andcommitted
No public description
PiperOrigin-RevId: 709162113 Change-Id: I0f6e2ae742dd1b574d1b6d0fb2f8e9807c685fe1
1 parent 884aac9 commit 3f187ea

File tree

2 files changed

+155
-0
lines changed

2 files changed

+155
-0
lines changed

qkeras/registry.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
"""General purpose registy for registering classes or functions.
17+
18+
The registry can be used along with decorators to record any class/function.
19+
20+
Sample usage:
21+
# Setup registry with decorator.
22+
_REGISTRY = registry.Registry()
23+
def register(cls):
24+
_REGISTRY.register(cls)
25+
def lookup(name):
26+
return _REGISTRY.lookup(name)
27+
28+
# Register instances.
29+
@register
30+
def foo_task():
31+
...
32+
33+
@register
34+
def bar_task():
35+
...
36+
37+
# Retrieve instances.
38+
def my_executor():
39+
...
40+
my_task = lookup("foo_task")
41+
...
42+
"""
43+
44+
45+
class Registry(object):
46+
"""A registry class to record class representations or function objects."""
47+
48+
def __init__(self):
49+
"""Initializes the registry."""
50+
self._container = {}
51+
52+
def register(self, item, name=None):
53+
"""Register an item.
54+
55+
Args:
56+
item: Python item to be recorded.
57+
name: Optional name to be used for recording item. If not provided,
58+
item.__name__ is used.
59+
"""
60+
if not name:
61+
name = item.__name__
62+
self._container[name] = item
63+
64+
def lookup(self, name):
65+
"""Retrieves an item from the registry.
66+
67+
Args:
68+
name: Name of the item to lookup.
69+
70+
Returns:
71+
Registered item from the registry.
72+
"""
73+
return self._container[name]

tests/registry_test.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
"""Unit tests for registry."""
17+
18+
from numpy.testing import assert_equal
19+
from numpy.testing import assert_raises
20+
import pytest
21+
22+
from qkeras import registry
23+
24+
25+
def sample_function(arg):
26+
"""Sample function for testing."""
27+
return arg
28+
29+
30+
class SampleClass(object):
31+
"""Sample class for testing."""
32+
33+
def __init__(self, arg):
34+
self._arg = arg
35+
36+
def get_arg(self):
37+
return self._arg
38+
39+
40+
def test_register_function():
41+
reg = registry.Registry()
42+
reg.register(sample_function)
43+
registered_function = reg.lookup('sample_function')
44+
# Call the function to validate.
45+
assert_equal(registered_function, sample_function)
46+
47+
48+
def test_register_class():
49+
reg = registry.Registry()
50+
reg.register(SampleClass)
51+
registered_class = reg.lookup('SampleClass')
52+
# Create and call class object to validate.
53+
assert_equal(SampleClass, registered_class)
54+
55+
56+
def test_register_with_name():
57+
reg = registry.Registry()
58+
name = 'NewSampleClass'
59+
reg.register(SampleClass, name=name)
60+
registered_class = reg.lookup(name)
61+
# Create and call class object to validate.
62+
assert_equal(SampleClass, registered_class)
63+
64+
65+
def test_lookup_missing_item():
66+
reg = registry.Registry()
67+
assert_raises(KeyError, reg.lookup, 'foo')
68+
69+
70+
def test_lookup_missing_name():
71+
reg = registry.Registry()
72+
sample_class = SampleClass(arg=1)
73+
# objects don't have a default __name__ attribute.
74+
assert_raises(AttributeError, reg.register, sample_class)
75+
76+
# check that the object can be retrieved with a registered name.
77+
reg.register(sample_class, 'sample_class')
78+
assert_equal(sample_class, reg.lookup('sample_class'))
79+
80+
81+
if __name__ == '__main__':
82+
pytest.main([__file__])

0 commit comments

Comments
 (0)