diff --git a/.github/workflows/release-drafter.yml b/.github/workflows/release-drafter.yml index f0d6f9e..0c3398a 100644 --- a/.github/workflows/release-drafter.yml +++ b/.github/workflows/release-drafter.yml @@ -2,7 +2,7 @@ name: Release drafter on: push: - branches: [main, master, dev] + branches: [main, master, develop] pull_request: types: [opened, reopened, synchronize] diff --git a/doc/howto/message.rst b/doc/howto/message.rst index fc509ac..f113848 100644 --- a/doc/howto/message.rst +++ b/doc/howto/message.rst @@ -107,7 +107,7 @@ simple symmetric one: >>> from oidcmsg.message import Message >>> from cryptojwt.jwk.hmac import SYMKey >>> msg = Message(key='value', another=2) - >>> keys = [SYMKey(key="A1B2C3D4")] + >>> keys = [SYMKey(key="A1B2C3D4E5F6G7H8")] >>> jws = msg.to_jwt(keys, "HS256") >>> print(jws) diff --git a/setup.py b/setup.py index e149b92..1a5b110 100755 --- a/setup.py +++ b/setup.py @@ -63,7 +63,8 @@ def run_tests(self): install_requires=[ "cryptojwt>=1.5.0", "pyOpenSSL", - "filelock>=3.0.12" + "filelock>=3.0.12", + 'pyyaml>=5.1.2' ], zip_safe=False, cmdclass={'test': PyTest}, diff --git a/src/oidcmsg/__init__.py b/src/oidcmsg/__init__.py index a97c41f..a3f9a90 100755 --- a/src/oidcmsg/__init__.py +++ b/src/oidcmsg/__init__.py @@ -1,7 +1,8 @@ __author__ = "Roland Hedberg" -__version__ = "1.3.1" +__version__ = "1.3.2" import os +from typing import Dict VERIFIED_CLAIM_PREFIX = "__verified" @@ -33,8 +34,14 @@ def proper_path(path): return path -# This is for adding a base path to path specified in a configuration -def add_base_path(conf, item_paths, base_path): +def add_base_path(conf: Dict[str, str], item_paths: dict, base_path: str): + """ + This is for adding a base path to path specified in a configuration + + :param conf: Configuration + :param item_paths: The relative item path + :param base_path: An absolute path to add to the relative + """ for section, items in item_paths.items(): if section == "": part = conf diff --git a/src/oidcmsg/configure.py b/src/oidcmsg/configure.py new file mode 100644 index 0000000..715df32 --- /dev/null +++ b/src/oidcmsg/configure.py @@ -0,0 +1,161 @@ +import importlib +import json +import logging +import os +from typing import Dict +from typing import List +from typing import Optional + +from oidcmsg.logging import configure_logging +from oidcmsg.util import load_yaml_config + +DEFAULT_FILE_ATTRIBUTE_NAMES = ['server_key', 'server_cert', 'filename', 'template_dir', + 'private_path', 'public_path', 'db_file'] + +URIS = ["redirect_uris", 'issuer', 'base_url'] + + +def lower_or_upper(config, param, default=None): + res = config.get(param.lower(), default) + if not res: + res = config.get(param.upper(), default) + return res + + +def add_base_path(conf: dict, base_path: str, file_attributes: List[str]): + for key, val in conf.items(): + if key in file_attributes: + if val.startswith("/"): + continue + elif val == "": + conf[key] = "./" + val + else: + conf[key] = os.path.join(base_path, val) + if isinstance(val, dict): + conf[key] = add_base_path(val, base_path, file_attributes) + + return conf + + +def set_domain_and_port(conf: dict, uris: List[str], domain: str, port: int): + for key, val in conf.items(): + if key in uris: + if not val: + continue + + if isinstance(val, list): + _new = [v.format(domain=domain, port=port) for v in val] + else: + _new = val.format(domain=domain, port=port) + conf[key] = _new + elif isinstance(val, dict): + conf[key] = set_domain_and_port(val, uris, domain, port) + return conf + + +class Base: + """ Configuration base class """ + + def __init__(self, + conf: Dict, + base_path: str = '', + file_attributes: Optional[List[str]] = None, + ): + + if file_attributes is None: + file_attributes = DEFAULT_FILE_ATTRIBUTE_NAMES + + if base_path and file_attributes: + # this adds a base path to all paths in the configuration + add_base_path(conf, base_path, file_attributes) + + def __getitem__(self, item): + if item in self.__dict__: + return self.__dict__[item] + else: + raise KeyError + + def get(self, item, default=None): + return getattr(self, item, default) + + def __contains__(self, item): + return item in self.__dict__ + + def items(self): + for key in self.__dict__: + if key.startswith('__') and key.endswith('__'): + continue + yield key, getattr(self, key) + + def extend(self, entity_conf, conf, base_path, file_attributes, domain, port): + for econf in entity_conf: + _path = econf.get("path") + _cnf = conf + if _path: + for step in _path: + _cnf = _cnf[step] + _attr = econf["attr"] + _cls = econf["class"] + setattr(self, _attr, + _cls(_cnf, base_path=base_path, file_attributes=file_attributes, + domain=domain, port=port)) + + +class Configuration(Base): + """Server Configuration""" + + def __init__(self, + conf: Dict, + base_path: str = '', + entity_conf: Optional[List[dict]] = None, + file_attributes: Optional[List[str]] = None, + domain: Optional[str] = "", + port: Optional[int] = 0, + ): + Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes) + + log_conf = conf.get('logging') + if log_conf: + self.logger = configure_logging(config=log_conf).getChild(__name__) + else: + self.logger = logging.getLogger('oidcrp') + + self.web_conf = lower_or_upper(conf, "webserver") + + # entity info + if not domain: + domain = conf.get("domain", "127.0.0.1") + + if not port: + port = conf.get("port", 80) + + if entity_conf: + self.extend(entity_conf=entity_conf, conf=conf, base_path=base_path, + file_attributes=file_attributes, domain=domain, port=port) + + +def create_from_config_file(cls, + filename: str, + base_path: Optional[str] = '', + entity_conf: Optional[List[dict]] = None, + file_attributes: Optional[List[str]] = None, + domain: Optional[str] = "", + port: Optional[int] = 0): + if filename.endswith(".yaml"): + """Load configuration as YAML""" + _cnf = load_yaml_config(filename) + elif filename.endswith(".json"): + _str = open(filename).read() + _cnf = json.loads(_str) + elif filename.endswith(".py"): + head, tail = os.path.split(filename) + tail = tail[:-3] + module = importlib.import_module(tail) + _cnf = getattr(module, "CONFIG") + else: + raise ValueError("Unknown file type") + + return cls(_cnf, + entity_conf=entity_conf, + base_path=base_path, file_attributes=file_attributes, + domain=domain, port=port) diff --git a/src/oidcmsg/impexp.py b/src/oidcmsg/impexp.py index c0b01c5..90d96ee 100644 --- a/src/oidcmsg/impexp.py +++ b/src/oidcmsg/impexp.py @@ -47,10 +47,14 @@ def dump(self, exclude_attributes: Optional[List[str]] = None) -> dict: info[attr] = self.dump_attr(cls, item, exclude_attributes) - for attr, d in self.special_load_dump.items(): + for attr, func in self.special_load_dump.items(): item = getattr(self, attr, None) if item: - info[attr] = d["dump"](item, exclude_attributes=exclude_attributes) + if "dump" in func: + info[attr] = func["dump"](item, exclude_attributes=exclude_attributes) + else: + cls = self.parameter[attr] + info[attr] = self.dump_attr(cls, item, exclude_attributes) return info @@ -127,7 +131,11 @@ def load(self, item: dict, init_args: Optional[dict] = None, load_args: Optional for attr, func in self.special_load_dump.items(): if attr in item: - setattr(self, attr, func["load"](item[attr], **_kwargs)) + if "load" in func: + setattr(self, attr, func["load"](item[attr], **_kwargs)) + else: + cls = self.parameter[attr] + setattr(self, attr, self.load_attr(cls, item[attr], **_kwargs)) self.local_load_adjustments(**_load_args) return self diff --git a/src/oidcmsg/logging.py b/src/oidcmsg/logging.py new file mode 100755 index 0000000..c628e07 --- /dev/null +++ b/src/oidcmsg/logging.py @@ -0,0 +1,52 @@ +"""Common logging functions""" +import logging +import os +from logging.config import dictConfig +from typing import Optional + +import yaml + +LOGGING_CONF = 'logging.yaml' + +LOGGING_DEFAULT = { + 'version': 1, + 'formatters': { + 'default': { + 'format': '%(asctime)s %(name)s %(levelname)s %(message)s' + } + }, + 'handlers': { + 'default': { + 'class': 'logging.StreamHandler', + 'formatter': 'default' + } + }, + 'root': { + 'handlers': ['default'], + 'level': 'INFO' + } +} + + +def configure_logging(debug: Optional[bool] = False, + config: Optional[dict] = None, + filename: Optional[str] = LOGGING_CONF) -> logging.Logger: + """Configure logging""" + + if config is not None: + config_dict = config + config_source = 'dictionary' + elif filename is not None and os.path.exists(filename): + with open(filename, "rt") as file: + config_dict = yaml.load(file) + config_source = 'file' + else: + config_dict = LOGGING_DEFAULT + config_source = 'default' + + if debug: + config_dict['root']['level'] = 'DEBUG' + + dictConfig(config_dict) + logging.debug("Configured logging using %s", config_source) + return logging.getLogger() diff --git a/src/oidcmsg/util.py b/src/oidcmsg/util.py new file mode 100644 index 0000000..3aa6275 --- /dev/null +++ b/src/oidcmsg/util.py @@ -0,0 +1,20 @@ +import secrets + +import yaml + + +def rndstr(size=16): + """ + Returns a string of random url safe characters + + :param size: The length of the string + :return: string + """ + return secrets.token_urlsafe(size) + + +def load_yaml_config(filename): + """Load a YAML configuration file.""" + with open(filename, "rt", encoding='utf-8') as file: + config_dict = yaml.safe_load(file) + return config_dict diff --git a/tests/entity_conf.json b/tests/entity_conf.json new file mode 100644 index 0000000..b7312e4 --- /dev/null +++ b/tests/entity_conf.json @@ -0,0 +1,29 @@ +{ + "port": 8090, + "domain": "127.0.0.1", + "base_url": "https://{domain}:{port}", + "httpc_params": { + "verify": false + }, + "keys": { + "private_path": "private/jwks.json", + "key_defs": [ + { + "type": "RSA", + "key": "", + "use": [ + "sig" + ] + }, + { + "type": "EC", + "crv": "P-256", + "use": [ + "sig" + ] + } + ], + "public_path": "static/jwks.json", + "read_only": false + } +} diff --git a/tests/entity_conf.py b/tests/entity_conf.py new file mode 100644 index 0000000..b7c3ba7 --- /dev/null +++ b/tests/entity_conf.py @@ -0,0 +1,29 @@ +CONFIG = { + "port": 8090, + "domain": "127.0.0.1", + "base_url": "https://{domain}:{port}", + "httpc_params": { + "verify": False + }, + "keys": { + "private_path": "private/jwks.json", + "key_defs": [ + { + "type": "RSA", + "key": "", + "use": [ + "sig" + ] + }, + { + "type": "EC", + "crv": "P-256", + "use": [ + "sig" + ] + } + ], + "public_path": "static/jwks.json", + "read_only": False + } +} diff --git a/tests/server_conf.json b/tests/server_conf.json new file mode 100644 index 0000000..d0e7858 --- /dev/null +++ b/tests/server_conf.json @@ -0,0 +1,69 @@ +{ + "logging": { + "version": 1, + "disable_existing_loggers": false, + "root": { + "handlers": [ + "console", + "file" + ], + "level": "DEBUG" + }, + "loggers": { + "idp": { + "level": "DEBUG" + } + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "stream": "ext://sys.stdout", + "formatter": "default" + }, + "file": { + "class": "logging.FileHandler", + "filename": "debug.log", + "formatter": "default" + } + }, + "formatters": { + "default": { + "format": "%(asctime)s %(name)s %(levelname)s %(message)s" + } + } + }, + "port": 8090, + "domain": "127.0.0.1", + "base_url": "https://{domain}:{port}", + "httpc_params": { + "verify": false + }, + "keys": { + "private_path": "private/jwks.json", + "key_defs": [ + { + "type": "RSA", + "key": "", + "use": [ + "sig" + ] + }, + { + "type": "EC", + "crv": "P-256", + "use": [ + "sig" + ] + } + ], + "public_path": "static/jwks.json", + "read_only": false + }, + "webserver": { + "port": 8090, + "domain": "127.0.0.1", + "server_cert": "certs/cert.pem", + "server_key": "certs/key.pem", + "debug": true + } +} diff --git a/tests/test_20_config.py b/tests/test_20_config.py new file mode 100644 index 0000000..e5ac3a8 --- /dev/null +++ b/tests/test_20_config.py @@ -0,0 +1,78 @@ +import os +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +import pytest + +from oidcmsg.configure import Base +from oidcmsg.configure import Configuration +from oidcmsg.configure import create_from_config_file +from oidcmsg.configure import lower_or_upper +from oidcmsg.configure import set_domain_and_port +from oidcmsg.util import rndstr + +_dirname = os.path.dirname(os.path.abspath(__file__)) + +URIS = ["base_url"] + + +class EntityConfiguration(Base): + def __init__(self, + conf: Dict, + entity_conf: Optional[Any] = None, + base_path: Optional[str] = '', + domain: Optional[str] = "", + port: Optional[int] = 0, + file_attributes: Optional[List[str]] = None, + uris: Optional[List[str]] = None + ): + + Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes) + + self.keys = lower_or_upper(conf, 'keys') + + if not domain: + domain = conf.get("domain", "127.0.0.1") + + if not port: + port = conf.get("port", 80) + + if uris is None: + uris = URIS + conf = set_domain_and_port(conf, uris, domain, port) + + self.hash_seed = lower_or_upper(conf, 'hash_seed', rndstr(32)) + self.base_url = conf.get("base_url") + self.httpc_params = conf.get("httpc_params", {"verify": False}) + + +def test_server_config(): + configuration = create_from_config_file(Configuration, + entity_conf=[ + {"class": EntityConfiguration, "attr": "entity"}], + filename=os.path.join(_dirname, 'server_conf.json'), + base_path=_dirname) + assert configuration + assert set(configuration.web_conf.keys()) == {'port', 'domain', 'server_cert', 'server_key', + 'debug'} + + entity_config = configuration.entity + assert entity_config.base_url == "https://127.0.0.1:8090" + assert entity_config.httpc_params == {"verify": False} + + +@pytest.mark.parametrize("filename", ['entity_conf.json', 'entity_conf.py']) +def test_entity_config(filename): + configuration = create_from_config_file(EntityConfiguration, + filename=os.path.join(_dirname, filename), + base_path=_dirname) + assert configuration + + assert configuration.base_url == "https://127.0.0.1:8090" + assert configuration.httpc_params == {"verify": False} + assert configuration['keys'] + ni = dict(configuration.items()) + assert len(ni) == 4 + assert set(ni.keys()) == {'keys', 'base_url', 'httpc_params', 'hash_seed'}