diff --git a/CHANGELOG.md b/CHANGELOG.md index f8844012..01a66daa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ All notable changes to this project will be documented in this file. ### Additions - New rules: `PublicELBCheckerRule`, `StackNameMatchesRegexRule`, and `StorageEncryptedRule` - New regex: `REGEX_ALPHANUMERICAL_OR_HYPHEN` to check if stack name only consists of alphanumerical characters and hyphens. +- Config has a few extra methods that should make handling Filters easier ## [1.14.0] ### Additions diff --git a/cfripper/__version__.py b/cfripper/__version__.py index ae85267d..ae0e2c35 100644 --- a/cfripper/__version__.py +++ b/cfripper/__version__.py @@ -1,3 +1,3 @@ -VERSION = (1, 14, 0) +VERSION = (1, 15, 0) __version__ = ".".join(map(str, VERSION)) diff --git a/cfripper/config/config.py b/cfripper/config/config.py index 8bb811fa..61c4fa0c 100644 --- a/cfripper/config/config.py +++ b/cfripper/config/config.py @@ -3,6 +3,7 @@ import logging import sys from collections import defaultdict +from importlib.util import module_from_spec, spec_from_file_location from io import TextIOWrapper from pathlib import Path from typing import DefaultDict, Dict, List @@ -117,6 +118,8 @@ class Config: "directconnect:", "trustedadvisor:", ] + RULES_CONFIG_MODULE_NAME = "__rules_config__" + FILTER_CONFIG_MODULE_NAME = "__filter_config__" def __init__( self, @@ -189,7 +192,7 @@ def load_rules_config_file(self, rules_config_file: TextIOWrapper): try: ext = Path(filename).suffix - module_name = "__rules_config__" + module_name = self.RULES_CONFIG_MODULE_NAME if ext not in [".py", ".pyc"]: raise RuntimeError("Configuration file should have a valid Python extension.") spec = importlib.util.spec_from_file_location(module_name, filename) @@ -205,32 +208,41 @@ def load_rules_config_file(self, rules_config_file: TextIOWrapper): raise def add_filters_from_dir(self, path: str): + self.add_filters(filters=self.get_filters_from_dir(path)) + + @classmethod + def get_filters_from_dir(cls, path: str) -> List[Filter]: + filters = [] + for filename in cls.get_filenames_from_dir(path): + try: + filters.extend(cls.get_filters_from_filename_path(filename)) + except Exception: + logger.exception(f"Failed to read files in path: {path} ({filename})") + raise + return filters + + @classmethod + def get_filenames_from_dir(cls, path: str) -> List[Path]: if not Path(path).is_dir(): raise RuntimeError(f"{path} doesn't exist") - - try: - module_name = "__rules_config__" - filenames = sorted(itertools.chain(Path(path).glob("*.py"), Path(path).glob("*.pyc"))) - for filename in filenames: - spec = importlib.util.spec_from_file_location(module_name, filename.absolute()) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - filters = vars(module).get("FILTERS") - if not filters: - continue - # Validate filters format - RulesFiltersMapping(__root__=filters) - self.add_filters(filters=filters) - logger.debug(f"{filename} loaded") - except Exception: - logger.exception(f"Failed to read files in path: {path}") - raise + filenames = sorted(itertools.chain(Path(path).glob("*.py"), Path(path).glob("*.pyc"))) + return filenames + + @classmethod + def get_filters_from_filename_path(cls, filename: Path) -> List[Filter]: + spec = spec_from_file_location(cls.FILTER_CONFIG_MODULE_NAME, filename.absolute()) + module = module_from_spec(spec) + sys.modules[cls.FILTER_CONFIG_MODULE_NAME] = module + spec.loader.exec_module(module) + filters = vars(module).get("FILTERS") or [] + # Validate filters format + RulesFiltersMapping(__root__=filters) + return filters def add_filters(self, filters: List[Filter]): - for filter in filters: - for rule in filter.rules: - self.rules_filters[rule].append(filter) + for rule_filter in filters: + for rule in rule_filter.rules: + self.rules_filters[rule].append(rule_filter) class RulesConfigMapping(BaseModel):