Skip to content

Commit

Permalink
add: unittest (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
akirakoyasu authored Jan 13, 2018
1 parent 9d35cd8 commit ead5ccf
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 35 deletions.
32 changes: 3 additions & 29 deletions aws_paramstore_py/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import argparse
import json

import boto3
from .main import get

__version__ = '0.0.3'
__version__ = '0.0.4'


def main():
def cli():
parser = argparse.ArgumentParser(description='Query params from AWS System Manager Parameter Store')
parser.add_argument('paths', metavar='path', nargs='*', help='The hierarchy for the parameter')
parser.add_argument('--decryption', action='store_true', help='Decrypt secure values or not')
Expand All @@ -16,29 +16,3 @@ def main():
decryption = args.decryption
params = get(*paths, decryption=decryption)
print(json.dumps(params))


def get(*paths, decryption=False):
ssm = boto3.client('ssm')
path = '/'.join(paths)
path = _complement_slashes(path)
response = ssm.get_parameters_by_path(Path=path, Recursive=True, WithDecryption=decryption)
params = map(lambda p: _remove_prefix(p, path), response['Parameters'])
return _convert_to_dict(params)


def _complement_slashes(path):
if path[len(path):] != '/':
path = path + '/'
if path[:1] != '/':
path = '/' + path
return path


def _remove_prefix(param, prefix):
param['Name'] = param['Name'][len(prefix):]
return param


def _convert_to_dict(params):
return {p['Name']: p['Value'] for p in params}
42 changes: 42 additions & 0 deletions aws_paramstore_py/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import boto3


def get(*paths, decryption=False):
ssm = boto3.client('ssm')
path = _join_slashes(paths)
response = ssm.get_parameters_by_path(Path=path, Recursive=True, WithDecryption=decryption)
params = map(lambda p: _remove_prefix(p, path), response['Parameters'])
return _convert_to_dict(params)


def _join_slashes(paths):
path = '/'.join(filter(lambda e: len(e), map(lambda e: _remove_slashes_on_edge(e), paths)))
if not path:
return '/'
else:
return '/' + path + '/'


def _remove_slashes_on_edge(string):
if _is_led_by_slash(string):
string = string[1:]
if _is_followed_by_slash(string):
string = string[:-1]
return string


def _is_led_by_slash(string):
return string[:1] == '/'


def _is_followed_by_slash(string):
return string[-1:] == '/'


def _remove_prefix(param, prefix):
param['Name'] = param['Name'][len(prefix):]
return param


def _convert_to_dict(params):
return {p['Name']: p['Value'] for p in params}
8 changes: 7 additions & 1 deletion scripts/spec.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ set -eu
here=$(cd $(dirname "$0") && pwd)
project_root=$(cd "${here}/.." && pwd)

spec() {
cd "${project_root}"
python3 setup.py test
cd -
}

set -x

echo "Test"
spec
22 changes: 17 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
here = path.abspath(path.dirname(__file__))

# Get the long description from the README file
with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
long_description = f.read()
readme = path.join(here, 'README.rst')
if path.exists(readme):
with open(readme, encoding='utf-8') as f:
long_description = f.read()
else:
long_description = ''

setup(
# This is the name of your project. The first time you publish this
Expand All @@ -28,7 +32,7 @@
# For a discussion on single-sourcing the version across setup.py and the
# project code, see
# https://packaging.python.org/en/latest/single_source_version.html
version='0.0.3', # Required
version='0.0.4', # Required

# This is a one-line description or tagline of what your project does. This
# corresponds to the "Summary" metadata field:
Expand Down Expand Up @@ -100,7 +104,7 @@
#
# py_modules=["my_module"],
#
packages=find_packages(exclude=['contrib', 'docs', 'tests']), # Required
packages=find_packages(exclude=['spec']), # Required

# If your project only runs on certain Python versions, setting the python
# requires argument to the appropriate PEP 440 version specifier string will
Expand Down Expand Up @@ -156,7 +160,15 @@
# executes the function `main` from this package when invoked:
entry_points={ # Optional
'console_scripts': [
'aws-pspy=aws_paramstore_py:main',
'aws-pspy=aws_paramstore_py:cli',
],
},

# A string naming a unittest.TestCase subclass (or a package or module containing
# one or more of them, or a method of such a subclass), or naming a function that
# can be called with no arguments and returns a unittest.TestSuite. If the named
# suite is a module, and the module has an additional_tests() function, it is called
# and the results are added to the tests to be run. If the named suite is a package,
# any submodules and subpackages are recursively added to the overall test suite.
test_suite='spec',
)
Empty file added spec/__init__.py
Empty file.
54 changes: 54 additions & 0 deletions spec/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import unittest
from unittest.mock import patch

import aws_paramstore_py as paramstore


@patch('aws_paramstore_py.main.boto3')
class TestMain(unittest.TestCase):
def test_get(self, mock):
method = mock.client('ssm').get_parameters_by_path
method.return_value = {'Parameters': [
{'Name': '/path/to/params/key1', 'Value': "value1"},
{'Name': '/path/to/params/key2', 'Value': "value2"}
]}

params = paramstore.get('path', 'to', 'params')

method.assert_called_with(Path='/path/to/params/', Recursive=True, WithDecryption=False)
self.assertDictEqual({"key1": "value1", "key2": "value2"}, params)

def test_get_root(self, mock):
method = mock.client('ssm').get_parameters_by_path

paramstore.get()

method.assert_called_with(Path='/', Recursive=True, WithDecryption=False)

def test_get_slash(self, mock):
method = mock.client('ssm').get_parameters_by_path

paramstore.get('/')

method.assert_called_with(Path='/', Recursive=True, WithDecryption=False)

def test_get_leading_slash(self, mock):
method = mock.client('ssm').get_parameters_by_path

paramstore.get('/path/to')

method.assert_called_with(Path='/path/to/', Recursive=True, WithDecryption=False)

def test_get_following_slash(self, mock):
method = mock.client('ssm').get_parameters_by_path

paramstore.get('path/to/')

method.assert_called_with(Path='/path/to/', Recursive=True, WithDecryption=False)

def test_get_with_decryption(self, mock):
method = mock.client('ssm').get_parameters_by_path

paramstore.get('path/to/params', decryption=True)

method.assert_called_with(Path='/path/to/params/', Recursive=True, WithDecryption=True)

0 comments on commit ead5ccf

Please sign in to comment.