Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plugin system for Diff* engine #18

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
build
dist
*.egg-info
.idea
47 changes: 23 additions & 24 deletions ssm-diff
Original file line number Diff line number Diff line change
@@ -1,58 +1,57 @@
#!/usr/bin/env python
from __future__ import print_function
from states import *
import states.helpers as helpers

import argparse
import os

from states import *


def configure_endpoints(args):
# configure() returns a DiffBase class (whose constructor may be wrapped in `partial` to pre-configure it)
diff_class = DiffBase.get_plugin(args.engine).configure(args)
return storage.ParameterStore(args.profile, diff_class, paths=args.path), storage.YAMLFile(args.filename, paths=args.path)


def init(args):
r, l = RemoteState(args.profile), LocalState(args.filename)
l.save(r.get(flat=False, paths=args.path))
"""Create a local YAML file from the SSM Parameter Store (per configs in args)"""
remote, local = configure_endpoints(args)
local.save(remote.clone())


def pull(args):
dictfilter = lambda x, y: dict([ (i,x[i]) for i in x if i in set(y) ])
r, l = RemoteState(args.profile), LocalState(args.filename)
diff = helpers.FlatDictDiffer(r.get(paths=args.path), l.get(paths=args.path))
if args.force:
ref_set = diff.changed().union(diff.removed()).union(diff.unchanged())
target_set = diff.added()
else:
ref_set = diff.unchanged().union(diff.removed())
target_set = diff.added().union(diff.changed())
state = dictfilter(diff.ref, ref_set)
state.update(dictfilter(diff.target, target_set))
l.save(helpers.unflatten(state))
"""Update local YAML file with changes in the SSM Parameter Store (per configs in args)"""
remote, local = configure_endpoints(args)
local.save(remote.pull(local.get()))


def apply(args):
r, _, diff = plan(args)

"""Apply local changes to the SSM Parameter Store"""
remote, local = configure_endpoints(args)
print("\nApplying changes...")
try:
r.apply(diff)
remote.push(local.get())
except Exception as e:
print("Failed to apply changes to remote:", e)
print("Done.")


def plan(args):
r, l = RemoteState(args.profile), LocalState(args.filename)
diff = helpers.FlatDictDiffer(r.get(paths=args.path), l.get(paths=args.path))
"""Print a representation of the changes that would be applied to SSM Parameter Store if applied (per config in args)"""
remote, local = configure_endpoints(args)
diff = remote.dry_run(local.get())

if diff.differ:
diff.print_state()
print(DiffBase.describe_diff(diff.plan))
else:
print("Remote state is up to date.")

return r, l, diff


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-f', help='local state yml file', action='store', dest='filename', default='parameters.yml')
parser.add_argument('--path', '-p', action='append', help='filter SSM path')
parser.add_argument('--engine', '-e', help='diff engine to use when interacting with SSM', action='store', dest='engine', default='DiffResolver')
parser.add_argument('--profile', help='AWS profile name', action='store', dest='profile')
subparsers = parser.add_subparsers(dest='func', help='commands')
subparsers.required = True
Expand Down
3 changes: 2 additions & 1 deletion states/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .states import *
from .storage import YAMLFile, ParameterStore
from .engine import DiffBase
154 changes: 154 additions & 0 deletions states/engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import collections
import logging
from functools import partial

from termcolor import colored

from .helpers import add


class DiffMount(type):
"""Metaclass for Diff plugin system"""
# noinspection PyUnusedLocal,PyMissingConstructor
def __init__(cls, *args, **kwargs):
if not hasattr(cls, 'plugins'):
cls.plugins = dict()
else:
cls.plugins[cls.__name__] = cls


class DiffBase(metaclass=DiffMount):
"""Superclass for diff plugins"""
def __init__(self, remote, local):
self.logger = logging.getLogger(self.__module__)
self.remote_flat, self.local_flat = self._flatten(remote), self._flatten(local)
self.remote_set, self.local_set = set(self.remote_flat.keys()), set(self.local_flat.keys())

# noinspection PyUnusedLocal
@classmethod
def get_plugin(cls, name):
if name in cls.plugins:
return cls.plugins[name]

@classmethod
def configure(cls, args):
"""Extract class-specific configurations from CLI args and pre-configure the __init__ method using functools.partial"""
return cls

@classmethod
def _flatten(cls, d, current_path='', sep='/'):
"""Convert a nested dict structure into a "flattened" dict i.e. {"full/path": "value", ...}"""
items = []
for k in d:
new = current_path + sep + k if current_path else k
if isinstance(d[k], collections.MutableMapping):
items.extend(cls._flatten(d[k], new, sep=sep).items())
else:
items.append((sep + new, d[k]))
return dict(items)

@classmethod
def _unflatten(cls, d, sep='/'):
"""Converts a "flattened" dict i.e. {"full/path": "value", ...} into a nested dict structure"""
output = {}
for k in d:
add(
obj=output,
path=k,
value=d[k],
sep=sep,
)
return output

@classmethod
def describe_diff(cls, plan):
"""Return a (multi-line) string describing all differences"""
description = ""
for k, v in plan['add'].items():
# { key: new_value }
description += colored("+", 'green'), "{} = {}".format(k, v) + '\n'

for k in plan['delete']:
# { key: old_value }
description += colored("-", 'red'), k + '\n'

for k, v in plan['change'].items():
# { key: {'old': value, 'new': value} }
description += colored("~", 'yellow'), "{}:\n\t< {}\n\t> {}".format(k, v['old'], v['new']) + '\n'

return description

@property
def plan(self):
"""Returns a `dict` of operations for updating the remote storage i.e. {'add': {...}, 'change': {...}, 'delete': {...}}"""
raise NotImplementedError

def merge(self):
"""Generate a merge of the local and remote dicts, following configurations set during __init__"""
raise NotImplementedError


class DiffResolver(DiffBase):
"""Determines diffs between two dicts, where the remote copy is considered the baseline"""
def __init__(self, remote, local, force=False):
super().__init__(remote, local)
self.intersection = self.remote_set.intersection(self.local_set)
self.force = force

if self.added() or self.removed() or self.changed():
self.differ = True
else:
self.differ = False

@classmethod
def configure(cls, args):
kwargs = {}
if hasattr(args, 'force'):
kwargs['force'] = args.force
return partial(cls, **kwargs)

def added(self):
"""Returns a (flattened) dict of added leaves i.e. {"full/path": value, ...}"""
return self.local_set - self.intersection

def removed(self):
"""Returns a (flattened) dict of removed leaves i.e. {"full/path": value, ...}"""
return self.remote_set - self.intersection

def changed(self):
"""Returns a (flattened) dict of changed leaves i.e. {"full/path": value, ...}"""
return set(k for k in self.intersection if self.remote_flat[k] != self.local_flat[k])

def unchanged(self):
"""Returns a (flattened) dict of unchanged leaves i.e. {"full/path": value, ...}"""
return set(k for k in self.intersection if self.remote_flat[k] == self.local_flat[k])

@property
def plan(self):
return {
'add': {
k: self.local_flat[k] for k in self.added()
},
'delete': {
k: self.remote_flat[k] for k in self.removed()
},
'change': {
k: {'old': self.remote_flat[k], 'new': self.local_flat[k]} for k in self.changed()
}
}

def merge(self):
dictfilter = lambda original, keep_keys: dict([(i, original[i]) for i in original if i in set(keep_keys)])
if self.force:
# Overwrite local changes (i.e. only preserve added keys)
# NOTE: Currently the system cannot tell the difference between a remote delete and a local add
prior_set = self.changed().union(self.removed()).union(self.unchanged())
current_set = self.added()
else:
# Preserve added keys and changed keys
# NOTE: Currently the system cannot tell the difference between a remote delete and a local add
prior_set = self.unchanged().union(self.removed())
current_set = self.added().union(self.changed())
state = dictfilter(original=self.remote_flat, keep_keys=prior_set)
state.update(dictfilter(original=self.local_flat, keep_keys=current_set))
return self._unflatten(state)
64 changes: 4 additions & 60 deletions states/helpers.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,9 @@
from termcolor import colored
from copy import deepcopy
import collections


class FlatDictDiffer(object):
def __init__(self, ref, target):
self.ref, self.target = ref, target
self.ref_set, self.target_set = set(ref.keys()), set(target.keys())
self.isect = self.ref_set.intersection(self.target_set)

if self.added() or self.removed() or self.changed():
self.differ = True
else:
self.differ = False

def added(self):
return self.target_set - self.isect

def removed(self):
return self.ref_set - self.isect

def changed(self):
return set(k for k in self.isect if self.ref[k] != self.target[k])

def unchanged(self):
return set(k for k in self.isect if self.ref[k] == self.target[k])

def print_state(self):
for k in self.added():
print(colored("+", 'green'), "{} = {}".format(k, self.target[k]))

for k in self.removed():
print(colored("-", 'red'), k)

for k in self.changed():
print(colored("~", 'yellow'), "{}:\n\t< {}\n\t> {}".format(k, self.ref[k], self.target[k]))


def flatten(d, pkey='', sep='/'):
items = []
for k in d:
new = pkey + sep + k if pkey else k
if isinstance(d[k], collections.MutableMapping):
items.extend(flatten(d[k], new, sep=sep).items())
else:
items.append((sep + new, d[k]))
return dict(items)


def add(obj, path, value):
parts = path.strip("/").split("/")
def add(obj, path, value, sep='/'):
"""Add value to the `obj` dict at the specified path"""
parts = path.strip(sep).split(sep)
last = len(parts) - 1
for index, part in enumerate(parts):
if index == last:
Expand All @@ -61,7 +15,7 @@ def add(obj, path, value):
def search(state, path):
result = state
for p in path.strip("/").split("/"):
if result.get(p):
if result.clone(p):
result = result[p]
else:
result = {}
Expand All @@ -71,16 +25,6 @@ def search(state, path):
return output


def unflatten(d):
output = {}
for k in d:
add(
obj=output,
path=k,
value=d[k])
return output


def merge(a, b):
if not isinstance(b, dict):
return b
Expand Down
Loading