diff --git a/src/idpyoidc/claims.py b/src/idpyoidc/claims.py index afa6680f..45f54bea 100644 --- a/src/idpyoidc/claims.py +++ b/src/idpyoidc/claims.py @@ -18,6 +18,7 @@ logger = logging.getLogger(__name__) + def claims_dump(info, exclude_attributes): return {qualified_name(info.__class__): info.dump(exclude_attributes=exclude_attributes)} @@ -321,3 +322,29 @@ def get_client_metadata(self, return {entity_type: metadata} else: return metadata + + def get_registration_metadata(self, + entity_type: Optional[str] = "", + metadata_schema: Optional[Message] = None, + extra_claims: Optional[List[str]] = None, + supported: Optional[dict] = None, + **kwargs): + + metadata = self.prefer + + # the claims that can appear in the metadata + if metadata_schema: + attr = list(metadata_schema.c_param.keys()) + else: + attr = [] + + if extra_claims: + attr.extend(extra_claims) + + if attr: + metadata = {k: v for k, v in metadata.items() if k in attr} + + if entity_type: + return {entity_type: metadata} + else: + return metadata diff --git a/src/idpyoidc/client/claims/__init__.py b/src/idpyoidc/client/claims/__init__.py index 1427005b..d495881b 100644 --- a/src/idpyoidc/client/claims/__init__.py +++ b/src/idpyoidc/client/claims/__init__.py @@ -13,6 +13,8 @@ def get_client_authn_methods(): class Claims(claims.Claims): + _supports = {} + def get_base_url(self, configuration: dict, entity_id: Optional[str] = ""): _base = configuration.get("base_url") if not _base: @@ -54,9 +56,10 @@ def get_jwks(self, keyjar): # if only one key under the id == "", that key being a SYMKey I assume it's # and I have a client_secret then don't publish a JWKS if ( - len(_own_keys) == 1 - and isinstance(_own_keys[0], SYMKey) - and self.prefer["client_secret"] + len(_own_keys) == 1 + and isinstance(_own_keys[0], SYMKey) + and self.prefer["client_secret"] + and self.prefer.get("client_secret", None) ): pass else: diff --git a/src/idpyoidc/client/claims/oauth2.py b/src/idpyoidc/client/claims/oauth2.py index 16e90475..543d619c 100644 --- a/src/idpyoidc/client/claims/oauth2.py +++ b/src/idpyoidc/client/claims/oauth2.py @@ -2,20 +2,34 @@ from idpyoidc.client import claims from idpyoidc.transform import create_registration_request +from idpyoidc.transform import create_registration_request +REGISTER2PREFERRED = { + "scope": "scopes_supported", + "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", + "response_types": "response_types_supported", + # "response_modes": "response_modes_supported", + "grant_types": "grant_types_supported", + "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", + "token_auth_signing_algs": "token_auth_signing_algs_supported", + # 'ui_locales': 'ui_locales_supported', +} class Claims(claims.Claims): + register2preferred = REGISTER2PREFERRED + _supports = { "redirect_uris": None, "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], "response_types_supported": ["code"], "client_id": None, - "client_secret": None, "client_name": None, + "client_secret": None, "client_uri": None, "logo_uri": None, + "scope": None, "contacts": None, - "scopes_supported": [], + # "scopes_supported": [], "tos_uri": None, "policy_uri": None, "jwks_uri": None, diff --git a/src/idpyoidc/client/claims/oauth2resource.py b/src/idpyoidc/client/claims/oauth2resource.py index 537e1391..2bec3c8e 100644 --- a/src/idpyoidc/client/claims/oauth2resource.py +++ b/src/idpyoidc/client/claims/oauth2resource.py @@ -2,7 +2,7 @@ from idpyoidc.client import claims from idpyoidc.message.oauth2 import OAuthProtectedResourceRequest -from idpyoidc.client.claims.transform import array_or_singleton +from idpyoidc.transform import array_or_singleton class Claims(claims.Claims): _supports = { diff --git a/src/idpyoidc/client/claims/oidc.py b/src/idpyoidc/client/claims/oidc.py index d2b1b0b0..85cf9242 100644 --- a/src/idpyoidc/client/claims/oidc.py +++ b/src/idpyoidc/client/claims/oidc.py @@ -95,13 +95,13 @@ def __init__(self, prefer: Optional[dict] = None, callback_path: Optional[dict] client_claims.Claims.__init__(self, prefer=prefer, callback_path=callback_path) def verify_rules(self, supports): - if self.get_preference("request_parameter_supported") and self.get_preference( - "request_uri_parameter_supported" - ): - raise ValueError( - "You have to chose one of 'request_parameter_supported' and " - "'request_uri_parameter_supported'. You can't have both." - ) + # if self.get_preference("request_parameter_supported") and self.get_preference( + # "request_uri_parameter_supported" + # ): + # raise ValueError( + # "You have to chose one of 'request_parameter_supported' and " + # "'request_uri_parameter_supported'. You can't have both." + # ) if self.get_preference("request_parameter_supported") or self.get_preference( "request_uri_parameter_supported" diff --git a/src/idpyoidc/client/client_auth.py b/src/idpyoidc/client/client_auth.py index a8830cd4..baf03d91 100755 --- a/src/idpyoidc/client/client_auth.py +++ b/src/idpyoidc/client/client_auth.py @@ -10,6 +10,7 @@ from cryptojwt.jws.utils import alg2keytype from cryptojwt.utils import importer +from idpyoidc.client.request_object import construct_request_parameter from idpyoidc.defaults import DEF_SIGN_ALG from idpyoidc.defaults import JWT_BEARER from idpyoidc.message import Message @@ -31,6 +32,7 @@ DEFAULT_ACCESS_TOKEN_TYPE = "Bearer" + class AuthnFailure(Exception): """Unspecified Authentication failure""" @@ -46,7 +48,7 @@ def assertion_jwt(client_id, keys, audience, algorithm, lifetime=600): :param client_id: The Client ID :param keys: Signing keys - :param audience: Who is the receivers for this assertion + :param audience: Who's the receivers for this assertion :param algorithm: Signing algorithm :param lifetime: The lifetime of the signed Json Web Token :return: A Signed Json Web Token @@ -628,6 +630,12 @@ def get_signing_key_from_keyjar(self, algorithm, keyjar): return keyjar.get_signing_key(alg2keytype(algorithm), "", alg=algorithm) +class RequestParam(ClientAuthnMethod): + def construct(self, request, service=None, http_args=None, **kwargs): + request_object = construct_request_parameter(service, request, **kwargs) + request["request"] = request_object + + # Map from client authentication identifiers to corresponding class CLIENT_AUTHN_METHOD = { "client_secret_basic": ClientSecretBasic, @@ -637,6 +645,7 @@ def get_signing_key_from_keyjar(self, algorithm, keyjar): "client_secret_jwt": ClientSecretJWT, "private_key_jwt": PrivateKeyJWT, # "client_notification_authn": ClientNotificationAuthn + "request_param": RequestParam } TYPE_METHOD = [(JWT_BEARER, JWSAuthnMethod)] diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index 197d5d73..2a1b0a67 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -103,8 +103,9 @@ def __init__( if config is None: config = {} + # Client ID is set through configuration or at registration _id = config.get("client_id") - self.client_id = self.entity_id = entity_id or config.get("entity_id", _id) + self.entity_id = entity_id or config.get("entity_id", _id) Unit.__init__( self, @@ -114,7 +115,7 @@ def __init__( httpc_params=httpc_params, config=config, key_conf=key_conf, - client_id=self.client_id, + client_id=_id, ) if services: diff --git a/src/idpyoidc/client/oauth2/__init__.py b/src/idpyoidc/client/oauth2/__init__.py index 620608b0..0ecfd79f 100755 --- a/src/idpyoidc/client/oauth2/__init__.py +++ b/src/idpyoidc/client/oauth2/__init__.py @@ -6,6 +6,7 @@ from cryptojwt.key_jar import KeyJar + from idpyoidc.client.entity import Entity from idpyoidc.client.exception import ConfigurationError from idpyoidc.client.exception import OidcServiceError @@ -254,12 +255,13 @@ def parse_request_response(self, service, reqresp, response_body_type="", state= if reqresp.status_code in SUCCESSFUL: logger.debug('response_body_type: "{}"'.format(response_body_type)) + _content_type = reqresp.headers.get("content-type") _deser_method = get_deserialization_method(reqresp) - if _deser_method != response_body_type: + if _content_type != response_body_type: logger.warning( "Not the body type I expected: {} != {}".format( - _deser_method, response_body_type + _content_type, response_body_type ) ) if _deser_method in ["json", "jwt", "urlencoded"]: diff --git a/src/idpyoidc/client/oauth2/add_on/dpop.py b/src/idpyoidc/client/oauth2/add_on/dpop.py index d8a058ef..9b728837 100644 --- a/src/idpyoidc/client/oauth2/add_on/dpop.py +++ b/src/idpyoidc/client/oauth2/add_on/dpop.py @@ -1,8 +1,10 @@ +import base64 import logging import uuid from hashlib import sha256 from typing import Optional +from cryptojwt import as_unicode from cryptojwt.jwk.jwk import key_from_jwk_dict from cryptojwt.jws.jws import factory from cryptojwt.jws.jws import JWS @@ -149,7 +151,7 @@ def dpop_header( } if token: - header_dict["ath"] = sha256(token.encode("utf8")).hexdigest() + header_dict["ath"] = as_unicode(base64.urlsafe_b64encode(sha256(token.encode("utf8")).digest())) if nonce: header_dict["nonce"] = nonce @@ -168,14 +170,19 @@ def dpop_header( def add_support(services, dpop_signing_alg_values_supported, with_dpop_header=None): """ - Add the necessary pieces to make pushed authorization happen. + Add the necessary pieces to make DPoP happen. :param services: A dictionary with all the services the client has access to. :param signing_algorithms: Allowed signing algorithms, there is no default algorithms + :param dpop_signing_alg_values_supported: Allowed signing algorithms, there is no default algorithms + :param with_dpop_header: If a services should add a DPoP header to a request """ - # Access token request should use DPoP header _service = services["accesstoken"] + if with_dpop_header is None: + with_dpop_header = ["accesstoken", "userinfo"] + _service = services[with_dpop_header[0]] + # Add to Context _context = _service.upstream_get("context") _algs_supported = [ alg for alg in dpop_signing_alg_values_supported if alg in get_signing_algs() @@ -186,20 +193,8 @@ def add_support(services, dpop_signing_alg_values_supported, with_dpop_header=No } _context.set_preference("dpop_signing_alg_values_supported", _algs_supported) - _service.construct_extra_headers.append(dpop_header) - - # The same for userinfo requests - _userinfo_service = services.get("userinfo") - if _userinfo_service: - _userinfo_service.construct_extra_headers.append(dpop_header) - # To be backward compatible - if with_dpop_header is None: - with_dpop_header = ["userinfo"] - - # Add dpop HTTP header to these + # Add dpop HTTP header to requests by these services for _srv in with_dpop_header: - if _srv == "accesstoken": - continue _service = services.get(_srv) if _service: _service.construct_extra_headers.append(dpop_header) diff --git a/src/idpyoidc/client/oauth2/add_on/jar.py b/src/idpyoidc/client/oauth2/add_on/jar.py index 209c0627..3c2b83f7 100644 --- a/src/idpyoidc/client/oauth2/add_on/jar.py +++ b/src/idpyoidc/client/oauth2/add_on/jar.py @@ -1,11 +1,10 @@ import logging from typing import Optional +from idpyoidc.client.request_object import construct_request_parameter + from idpyoidc import alg_info -from idpyoidc.client.oidc.utils import construct_request_uri -from idpyoidc.client.oidc.utils import request_object_encryption -from idpyoidc.message.oidc import make_openid_request -from idpyoidc.time_util import utc_time_sans_frac +from idpyoidc.client.util import construct_request_uri logger = logging.getLogger(__name__) @@ -34,94 +33,6 @@ def store_request_on_file(service, req, **kwargs): return _webname -def get_request_object_signing_alg(service, **kwargs): - alg = "" - for arg in ["request_object_signing_alg", "algorithm"]: - try: # Trumps everything - alg = kwargs[arg] - except KeyError: - pass - else: - break - - if not alg: - _context = service.upstream_get("context") - alg = _context.add_on["jar"].get("request_object_signing_alg") - if alg is None: - alg = "RS256" - return alg - - -def construct_request_parameter(service, req, audience=None, **kwargs): - """Construct a request parameter""" - alg = get_request_object_signing_alg(service, **kwargs) - kwargs["request_object_signing_alg"] = alg - - _context = service.upstream_get("context") - if "keys" not in kwargs and alg and alg != "none": - kwargs["keys"] = service.upstream_get("attribute", "keyjar") - - if alg == "none": - kwargs["keys"] = [] - - # This is the issuer of the JWT, that is me ! - _issuer = kwargs.get("issuer") - if _issuer is None: - kwargs["issuer"] = _context.get_client_id() - - if kwargs.get("recv") is None: - try: - kwargs["recv"] = _context.provider_info["issuer"] - except KeyError: - kwargs["recv"] = _context.issuer - - try: - del kwargs["service"] - except KeyError: - pass - - _jar_conf = _context.add_on["jar"] - expires_in = _jar_conf.get("expires_in", DEFAULT_EXPIRES_IN) - if expires_in: - req["exp"] = utc_time_sans_frac() + int(expires_in) - - if _jar_conf.get("with_jti", False): - kwargs["with_jti"] = True - - _enc_enc = _jar_conf.get("request_object_encryption_enc", "") - if _enc_enc: - kwargs["request_object_encryption_enc"] = _enc_enc - kwargs["request_object_encryption_alg"] = _jar_conf.get("request_object_encryption_alg") - - # Filter out only the arguments I want - _mor_args = { - k: kwargs[k] - for k in [ - "keys", - "issuer", - "request_object_signing_alg", - "recv", - "with_jti", - "lifetime", - ] - if k in kwargs - } - - if audience: - _mor_args["aud"] = audience - - _req_jwt = make_openid_request(req, **_mor_args) - - if "target" not in kwargs: - kwargs["target"] = _context.provider_info.get("issuer", _context.issuer) - - # Should the request be encrypted - _req_jwte = request_object_encryption( - _req_jwt, _context, service.upstream_get("attribute", "keyjar"), **kwargs - ) - return _req_jwte - - def jar_post_construct(request_args, service, **kwargs): """ Modify the request arguments. diff --git a/src/idpyoidc/client/oauth2/add_on/par.py b/src/idpyoidc/client/oauth2/add_on/par.py index afa94058..353dc7b6 100644 --- a/src/idpyoidc/client/oauth2/add_on/par.py +++ b/src/idpyoidc/client/oauth2/add_on/par.py @@ -3,6 +3,8 @@ from cryptojwt.utils import importer from idpyoidc.client.client_auth import CLIENT_AUTHN_METHOD +from idpyoidc.client.oauth2.utils import set_request_object +from idpyoidc.client.service import Service from idpyoidc.message import Message from idpyoidc.message.oauth2 import JWTSecuredAuthorizationRequest from idpyoidc.server.util import execute @@ -13,7 +15,7 @@ HTTP_METHOD = "POST" -def push_authorization(request_args, service, **kwargs): +def push_authorization(request_args: Message, service: Service, **kwargs): """ :param request_args: All the request arguments as a AuthorizationRequest instance :param service: The service to which this post construct method is applied. @@ -50,17 +52,31 @@ def push_authorization(request_args, service, **kwargs): # construct the message body _body = request_args.to_urlencoded() + if isinstance(request_args, Message): + _required_params = request_args.to_dict() + else: + _required_params = request_args + + _add_request_object = kwargs.get("add_request_object", False) + if _add_request_object: + _required_params["request"] = set_request_object(service, request_args) + + _req = service.msg_type(**_required_params) + _body = _req.to_urlencoded() _http_client = method_args.get("http_client", None) if not _http_client: _http_client = service.upstream_get("unit").httpc _httpc_params = service.upstream_get("unit").httpc_params + _par_endpoint = kwargs.get("pushed_authorization_request_endpoint", None) + if not _par_endpoint: + _par_endpoint = _context.provider_info["pushed_authorization_request_endpoint"] # Send it to the Pushed Authorization Request Endpoint using POST resp = _http_client( method=HTTP_METHOD, - url=_context.provider_info["pushed_authorization_request_endpoint"], + url=_par_endpoint, data=_body, headers=_headers, **_httpc_params @@ -73,10 +89,7 @@ def push_authorization(request_args, service, **kwargs): _req[param] = request_args.get(param) request_args = _req else: - raise ConnectionError( - f"Could not connect to " - f'{_context.provider_info["pushed_authorization_request_endpoint"]}' - ) + raise ConnectionError(f"Could not connect to {_par_endpoint}") return request_args diff --git a/src/idpyoidc/client/oauth2/authorization.py b/src/idpyoidc/client/oauth2/authorization.py index 9d85f1fd..04ae98d6 100644 --- a/src/idpyoidc/client/oauth2/authorization.py +++ b/src/idpyoidc/client/oauth2/authorization.py @@ -31,7 +31,7 @@ class Authorization(Service): _supports = { "response_types_supported": ["code"], - "response_modes_supported": ["query", "fragment"], + "grant_types": None } _callback_path = { diff --git a/src/idpyoidc/client/oauth2/pushed_authorization.py b/src/idpyoidc/client/oauth2/pushed_authorization.py new file mode 100644 index 00000000..20eb299d --- /dev/null +++ b/src/idpyoidc/client/oauth2/pushed_authorization.py @@ -0,0 +1,89 @@ +"""The service that talks to the OAuth2 Authorization endpoint.""" +import logging + +from idpyoidc.client.oauth2.utils import get_state_parameter +from idpyoidc.client.oauth2.utils import pre_construct_pick_redirect_uri +from idpyoidc.client.oauth2.utils import set_request_object +from idpyoidc.client.oauth2.utils import set_state_parameter +from idpyoidc.client.service import Service +from idpyoidc.exception import MissingParameter +from idpyoidc.message import oauth2 +from idpyoidc.message.oauth2 import ResponseMessage +from idpyoidc.time_util import time_sans_frac + +LOGGER = logging.getLogger(__name__) + + +class PushedAuthorization(Service): + """The service that talks to the OAuth2 Pushed Authorization endpoint.""" + + msg_type = oauth2.PushedAuthorizationRequest + response_cls = oauth2.PushedAuthorizationResponse + error_msg = ResponseMessage + endpoint_name = "pushed_authorization_request_endpoint" + service_name = "pushed_authorization" + response_body_type = "json" + http_method = "POST" + + _supports = { + "response_types_supported": ["code"], + "grant_types": None + } + + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) + self.pre_construct.extend([pre_construct_pick_redirect_uri, set_state_parameter]) + self.post_construct.append(self.store_auth_request) + + def add_(self, request_args=None, **kwargs): + _add_request_object = kwargs.get("add_request_object", False) + if _add_request_object: + request_args["request"] = set_request_object(self, request_args) + + def update_service_context(self, resp, key="", **kwargs): + if "expires_in" in resp: + resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) + self.upstream_get("context").cstate.update(key, resp) + + def store_auth_request(self, request_args=None, **kwargs): + """Store the authorization request in the state DB.""" + _key = get_state_parameter(request_args, kwargs) + self.upstream_get("context").cstate.update(_key, request_args) + return request_args + + def gather_request_args(self, **kwargs): + ar_args = Service.gather_request_args(self, **kwargs) + + if "redirect_uri" not in ar_args: + try: + ar_args["redirect_uri"] = self.upstream_get("context").get_usage("redirect_uris")[0] + except (KeyError, AttributeError): + raise MissingParameter("redirect_uri") + + return ar_args + + def post_parse_response(self, response, **kwargs): + """ + Add scope claim to response, from the request, if not present in the + response + + :param response: The response + :param kwargs: Extra Keyword arguments + :return: A possibly augmented response + """ + + if "scope" not in response: + try: + _key = kwargs["state"] + except KeyError: + pass + else: + if _key: + item = self.upstream_get("context").cstate.get_set( + _key, message=oauth2.AuthorizationRequest + ) + try: + response["scope"] = item["scope"] + except KeyError: + pass + return response diff --git a/src/idpyoidc/client/oauth2/registration.py b/src/idpyoidc/client/oauth2/registration.py index 19da4982..ba2ecab0 100644 --- a/src/idpyoidc/client/oauth2/registration.py +++ b/src/idpyoidc/client/oauth2/registration.py @@ -4,6 +4,7 @@ from idpyoidc.client.entity import response_types_to_grant_types from idpyoidc.client.service import Service +from idpyoidc.key_import import store_under_other_id from idpyoidc.message import oauth2 from idpyoidc.message.oauth2 import ResponseMessage @@ -75,7 +76,7 @@ def update_service_context(self, resp, key="", **kwargs): _keyjar = self.upstream_get("attribute", "keyjar") if _keyjar: if _client_id not in _keyjar: - _keyjar.import_jwks(_keyjar.export_jwks(True, ""), issuer_id=_client_id) + _keyjar = store_under_other_id(_keyjar, "", _client_id, True) _client_secret = _context.get_usage("client_secret") if _client_secret: if not _keyjar: diff --git a/src/idpyoidc/client/oauth2/stand_alone_client.py b/src/idpyoidc/client/oauth2/stand_alone_client.py index 8652f56d..c456176e 100644 --- a/src/idpyoidc/client/oauth2/stand_alone_client.py +++ b/src/idpyoidc/client/oauth2/stand_alone_client.py @@ -18,6 +18,8 @@ from idpyoidc.exception import MessageException from idpyoidc.exception import MissingRequiredAttribute from idpyoidc.exception import NotForMe +from idpyoidc.key_import import add_kb +from idpyoidc.key_import import import_jwks_from_file from idpyoidc.message import Message from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.message.oauth2 import is_error_message @@ -90,10 +92,10 @@ def do_provider_info( elif typ == "file": for kty, _name in _spec.items(): if kty == "jwks": - _kj.import_jwks_from_file(_name, _context.get("issuer")) + _kj = import_jwks_from_file(_kj, _name, _context.get("issuer")) elif kty == "rsa": # PEM file _kb = keybundle_from_local_file(_name, "der", ["sig"]) - _kj.add_kb(_context.get("issuer"), _kb) + _kj = add_kb(_kj, _context.get("issuer"), _kb) else: raise ValueError("Unknown provider JWKS type: {}".format(typ)) @@ -746,7 +748,12 @@ def load_registration_response(client, request_args=None): :param client: A :py:class:`idpyoidc.client.oidc.Client` instance """ - if not client.get_context().get_client_id(): + _client_id = getattr(client, "client_id", None) + if not _client_id: + _context = client.get_context() + _client_id = getattr(_context, "client_id", None) + + if not _client_id: try: response = client.do_request("registration", request_args=request_args) except KeyError: diff --git a/src/idpyoidc/client/oauth2/utils.py b/src/idpyoidc/client/oauth2/utils.py index 254e1bd2..819ff00a 100644 --- a/src/idpyoidc/client/oauth2/utils.py +++ b/src/idpyoidc/client/oauth2/utils.py @@ -2,6 +2,8 @@ from typing import Optional from typing import Union +from cryptojwt import JWT + from idpyoidc.client.defaults import DEFAULT_RESPONSE_MODE from idpyoidc.client.service import Service from idpyoidc.exception import MissingParameter @@ -99,3 +101,19 @@ def set_state_parameter(request_args=None, **kwargs): """Assigned a state value.""" request_args["state"] = get_state_parameter(request_args, kwargs) return request_args, {"state": request_args["state"]} + +def set_request_object(service, request_args): + # construct a signed request object + _context = service.upstream_get("context") + if _context.keyjar: + _jwt = JWT(key_jar=_context.keyjar) + else: + _jwt = JWT(key_jar=service.upstream_get("attribute", "keyjar")) + + if isinstance(request_args, Message): + _request_object = _jwt.pack(request_args.to_dict()) + else: + _request_object = _jwt.pack(request_args) + + # construct the message body + return _request_object \ No newline at end of file diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index 2024612c..91736d5e 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -2,6 +2,7 @@ from typing import Optional from typing import Union +from idpyoidc.alg_info import get_signing_algs from idpyoidc.client.client_auth import get_client_authn_methods from idpyoidc.client.exception import ParameterError from idpyoidc.client.oauth2 import access_token @@ -9,7 +10,6 @@ from idpyoidc.message import Message from idpyoidc.message import oidc from idpyoidc.message.oidc import verified_claim_name -from idpyoidc.alg_info import get_signing_algs from idpyoidc.time_util import time_sans_frac __author__ = "Roland Hedberg" @@ -34,7 +34,8 @@ def __init__(self, upstream_get, conf: Optional[dict] = None): access_token.AccessToken.__init__(self, upstream_get, conf=conf) def gather_verify_arguments( - self, response: Optional[Union[dict, Message]] = None, behaviour_args: Optional[dict] = None + self, response: Optional[Union[dict, Message]] = None, + behaviour_args: Optional[dict] = None ): """ Need to add some information before running verify() diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index 73c56929..9eb8c658 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -7,18 +7,16 @@ from idpyoidc.client.oauth2 import authorization from idpyoidc.client.oauth2.utils import pre_construct_pick_redirect_uri from idpyoidc.client.oidc import IDT2REG -from idpyoidc.client.oidc.utils import construct_request_uri -from idpyoidc.client.oidc.utils import request_object_encryption +from idpyoidc.client.request_object import construct_request_parameter from idpyoidc.client.service_context import ServiceContext +from idpyoidc.client.util import construct_request_uri from idpyoidc.client.util import implicit_response_types from idpyoidc.exception import MissingRequiredAttribute from idpyoidc.message import Message from idpyoidc.message import oauth2 from idpyoidc.message import oidc -from idpyoidc.message.oidc import make_openid_request from idpyoidc.message.oidc import verified_claim_name from idpyoidc.time_util import time_sans_frac -from idpyoidc.time_util import utc_time_sans_frac from idpyoidc.util import rndstr __author__ = "Roland Hedberg" @@ -142,7 +140,7 @@ def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs): elif "openid" not in request_args["scope"]: request_args["scope"].append("openid") - # 'code' and/or 'id_token' in response_type means an ID Roken + # 'code' and/or 'id_token' in response_type means an ID Token # will eventually be returned, hence the need for a nonce if "code" in _response_types or "id_token" in _response_types: if "nonce" not in request_args: @@ -173,24 +171,6 @@ def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs): return request_args, post_args - def get_request_object_signing_alg(self, **kwargs): - alg = "" - for arg in ["request_object_signing_alg", "algorithm"]: - try: # Trumps everything - alg = kwargs[arg] - except KeyError: - pass - else: - break - - if not alg: - _context = self.upstream_get("context") - try: - alg = _context.claims.get_usage("request_object_signing_alg") - except KeyError: # Use default - alg = "RS256" - return alg - def store_request_on_file(self, req, **kwargs): """ Stores the request parameter in a file. @@ -212,63 +192,6 @@ def store_request_on_file(self, req, **kwargs): fid.close() return _webname - def construct_request_parameter( - self, req, request_param, audience=None, expires_in=0, **kwargs - ): - """Construct a request parameter""" - alg = self.get_request_object_signing_alg(**kwargs) - kwargs["request_object_signing_alg"] = alg - - _context = self.upstream_get("context") - if "keys" not in kwargs and alg and alg != "none": - kwargs["keys"] = self.upstream_get("attribute", "keyjar") - - if alg == "none": - kwargs["keys"] = [] - - # This is the issuer of the JWT, that is me ! - _issuer = kwargs.get("issuer") - if _issuer is None: - kwargs["issuer"] = _context.get_client_id() - - if kwargs.get("recv") is None: - try: - kwargs["recv"] = _context.provider_info["issuer"] - except KeyError: - kwargs["recv"] = _context.issuer - - try: - del kwargs["service"] - except KeyError: - pass - - if expires_in: - req["exp"] = utc_time_sans_frac() + int(expires_in) - - _mor_args = { - k: kwargs[k] - for k in [ - "keys", - "issuer", - "request_object_signing_alg", - "recv", - "with_jti", - "lifetime", - ] - if k in kwargs - } - - _req_jwt = make_openid_request(req, **_mor_args) - - if "target" not in kwargs: - kwargs["target"] = _context.provider_info.get("issuer", _context.issuer) - - # Should the request be encrypted - _req_jwte = request_object_encryption( - _req_jwt, _context, self.upstream_get("attribute", "keyjar"), **kwargs - ) - return _req_jwte - def oidc_post_construct(self, req, **kwargs): """ Modify the request arguments. @@ -300,13 +223,19 @@ def oidc_post_construct(self, req, **kwargs): _request_param = "request" _req = None # just a flag + kwargs["req"] = req + _service = kwargs.get("service", None) + if _service is None: + kwargs["service"] = self + if _request_param == "request_uri": kwargs["base_path"] = _context.get("base_url") + "/" + "requests" kwargs["local_dir"] = _context.get_usage("requests_dir", "./requests") - _req = self.construct_request_parameter(req, _request_param, **kwargs) + _req = construct_request_parameter(**kwargs) + del kwargs["req"] req["request_uri"] = self.store_request_on_file(_req, **kwargs) elif _request_param == "request": - _req = self.construct_request_parameter(req, _request_param, **kwargs) + _req = construct_request_parameter(**kwargs) req["request"] = _req if _req: diff --git a/src/idpyoidc/client/oidc/registration.py b/src/idpyoidc/client/oidc/registration.py index 49339053..e0a1363d 100644 --- a/src/idpyoidc/client/oidc/registration.py +++ b/src/idpyoidc/client/oidc/registration.py @@ -4,6 +4,7 @@ from idpyoidc.client.entity import response_types_to_grant_types from idpyoidc.client.service import Service +from idpyoidc.key_import import import_jwks from idpyoidc.message import oidc from idpyoidc.message.oauth2 import ResponseMessage @@ -75,7 +76,7 @@ def update_service_context(self, resp, key="", **kwargs): _keyjar = self.upstream_get("attribute", "keyjar") if _keyjar: if _client_id not in _keyjar: - _keyjar.import_jwks(_keyjar.export_jwks(True, ""), issuer_id=_client_id) + _keyjar = import_jwks(_keyjar, _keyjar.export_jwks(True, ""), _client_id) _client_secret = _context.get_usage("client_secret") if _client_secret: if not _keyjar: diff --git a/src/idpyoidc/client/oidc/utils.py b/src/idpyoidc/client/oidc/utils.py deleted file mode 100644 index 2b428feb..00000000 --- a/src/idpyoidc/client/oidc/utils.py +++ /dev/null @@ -1,85 +0,0 @@ -import os - -from cryptojwt.jwe.jwe import JWE -from cryptojwt.jwe.utils import alg2keytype - -from idpyoidc.exception import MissingRequiredAttribute -from idpyoidc.util import rndstr - - -def request_object_encryption(msg, service_context, keyjar, **kwargs): - """ - Created an encrypted JSON Web token with *msg* as body. - - :param msg: The mesaqg - :param service_context: - :param kwargs: - :return: - """ - try: - encalg = kwargs["request_object_encryption_alg"] - except KeyError: - try: - encalg = service_context.get_usage("request_object_encryption_alg") - except KeyError: - return msg - - if not encalg: - return msg - - try: - encenc = kwargs["request_object_encryption_enc"] - except KeyError: - try: - encenc = service_context.get_usage("request_object_encryption_enc") - except KeyError: - raise MissingRequiredAttribute("No request_object_encryption_enc specified") - - if not encenc: - raise MissingRequiredAttribute("No request_object_encryption_enc specified") - - _jwe = JWE(msg, alg=encalg, enc=encenc) - _kty = alg2keytype(encalg) - - try: - _kid = kwargs["enc_kid"] - except KeyError: - _kid = "" - - _target = kwargs.get("target", kwargs.get("recv", None)) - if _target is None: - raise MissingRequiredAttribute("No target specified") - - if _kid: - _keys = keyjar.get_encrypt_key(_kty, issuer_id=_target, kid=_kid) - _jwe["kid"] = _kid - else: - _keys = keyjar.get_encrypt_key(_kty, issuer_id=_target) - - return _jwe.encrypt(_keys) - - -def construct_request_uri(local_dir, base_path, **kwargs): - """ - Constructs a special redirect_uri to be used when communicating with - one OP. Each OP should get their own redirect_uris. - - :param local_dir: Local directory in which to place the file - :param base_path: Base URL to start with - :param kwargs: - :return: 2-tuple with (filename, url) - """ - _filedir = local_dir - if not os.path.isdir(_filedir): - os.makedirs(_filedir) - _webpath = base_path - _name = rndstr(10) + ".jwt" - filename = os.path.join(_filedir, _name) - while os.path.exists(filename): - _name = rndstr(10) - filename = os.path.join(_filedir, _name) - if _webpath.endswith("/"): - _webname = f"{_webpath}{_name}" - else: - _webname = f"{_webpath}/{_name}" - return filename, _webname diff --git a/src/idpyoidc/client/request_object.py b/src/idpyoidc/client/request_object.py new file mode 100644 index 00000000..35520639 --- /dev/null +++ b/src/idpyoidc/client/request_object.py @@ -0,0 +1,141 @@ +from typing import Optional +from typing import Union + +from cryptojwt.jwe.jwe import JWE +from cryptojwt.jwe.utils import alg2keytype +from cryptojwt.jwt import utc_time_sans_frac + +from idpyoidc.defaults import DEF_SIGN_ALG +from idpyoidc.exception import MissingRequiredAttribute +from idpyoidc.message import Message +from idpyoidc.message.oidc import make_openid_request + + +def request_object_encryption(msg, service_context, keyjar, **kwargs): + """ + Created an encrypted JSON Web token with *msg* as body. + + :param msg: The message + :param service_context: + :param kwargs: + :return: + """ + try: + encalg = kwargs["request_object_encryption_alg"] + except KeyError: + try: + encalg = service_context.get_usage("request_object_encryption_alg") + except KeyError: + return msg + + if not encalg: + return msg + + try: + encenc = kwargs["request_object_encryption_enc"] + except KeyError: + try: + encenc = service_context.get_usage("request_object_encryption_enc") + except KeyError: + raise MissingRequiredAttribute("No request_object_encryption_enc specified") + + if not encenc: + raise MissingRequiredAttribute("No request_object_encryption_enc specified") + + _jwe = JWE(msg, alg=encalg, enc=encenc) + _kty = alg2keytype(encalg) + + try: + _kid = kwargs["enc_kid"] + except KeyError: + _kid = "" + + _target = kwargs.get("target", kwargs.get("recv", None)) + if _target is None: + raise MissingRequiredAttribute("No target specified") + + if _kid: + _keys = keyjar.get_encrypt_key(_kty, issuer_id=_target, kid=_kid) + _jwe["kid"] = _kid + else: + _keys = keyjar.get_encrypt_key(_kty, issuer_id=_target) + + return _jwe.encrypt(_keys) + + +def get_request_object_signing_alg(self, **kwargs): + alg = "" + for arg in ["request_object_signing_alg", "algorithm"]: + try: # Trumps everything + alg = kwargs[arg] + except KeyError: + pass + else: + break + + if not alg: + _context = self.upstream_get("context") + try: + alg = _context.claims.get_usage("request_object_signing_alg") + except KeyError: # Use default + pass + + if not alg: + alg = DEF_SIGN_ALG["request_object"] + + return alg + + +def construct_request_parameter( + service, + req: Union[Message, dict], + expires_in: Optional[int] = 0, + **kwargs): + """Construct a request parameter""" + alg = get_request_object_signing_alg(service, **kwargs) + kwargs["request_object_signing_alg"] = alg + + _context = service.upstream_get("context") + if "keys" not in kwargs: + kwargs["keys"] = service.upstream_get("attribute", "keyjar") + + if alg == "none": + kwargs["keys"] = [] + + # This is the issuer of the JWT, that is me ! + _issuer = kwargs.get("issuer") + if _issuer is None: + kwargs["issuer"] = _context.get_client_id() + + if kwargs.get("recv") is None: + try: + kwargs["recv"] = _context.provider_info["issuer"] + except KeyError: + kwargs["recv"] = _context.issuer + + if expires_in: + req["exp"] = utc_time_sans_frac() + int(expires_in) + + _mor_args = { + k: kwargs[k] + for k in [ + "keys", + "issuer", + "request_object_signing_alg", + "recv", + "with_jti", + "lifetime", + ] + if k in kwargs + } + + _req_jwt = make_openid_request(req, **_mor_args) + + if "target" not in kwargs: + kwargs["target"] = _context.provider_info.get("issuer", _context.issuer) + + # Should the request be encrypted + _req_jwte = request_object_encryption( + _req_jwt, _context, service.upstream_get("attribute", "keyjar"), **kwargs + ) + return _req_jwte diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index b4d391c2..3959a43e 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -386,6 +386,8 @@ def prefer_or_support(self, claim): return None def map_supported_to_preferred(self, info: Optional[dict] = None): + # goes from what the entity can do to something the opponent could handle + # info is metadata for the opponent if known self.claims.prefer = supported_to_preferred( self.supports(), self.claims.prefer, base_url=self.base_url, info=info ) diff --git a/src/idpyoidc/client/util.py b/src/idpyoidc/client/util.py index 03084d20..ed738897 100755 --- a/src/idpyoidc/client/util.py +++ b/src/idpyoidc/client/util.py @@ -1,5 +1,6 @@ """Utilities""" import logging +import os import secrets from http.cookiejar import Cookie from http.cookiejar import http2time @@ -14,9 +15,9 @@ from idpyoidc.defaults import BASECHR from idpyoidc.exception import UnSupported from idpyoidc.util import importer - from .exception import TimeFormatError from .exception import WrongContentType +from ..util import rndstr logger = logging.getLogger(__name__) @@ -274,7 +275,7 @@ def get_deserialization_method(reqresp): deser_method = "jose" elif match_to_(URL_ENCODED, _ctype): deser_method = "urlencoded" - elif match_to_("text/plain", _ctype) or match_to_("test/html", _ctype): + elif match_to_("text/plain", _ctype) or match_to_("text/html", _ctype): deser_method = "" else: deser_method = "" # reasonable default ?? @@ -330,3 +331,29 @@ def implicit_response_types(a): def get_uri(base_url, path, hex): return f"{base_url}/{path}/{hex}" + + +def construct_request_uri(local_dir, base_path, **kwargs): + """ + Constructs a special redirect_uri to be used when communicating with + one OP. Each OP should get their own redirect_uris. + + :param local_dir: Local directory in which to place the file + :param base_path: Base URL to start with + :param kwargs: + :return: 2-tuple with (filename, url) + """ + _filedir = local_dir + if not os.path.isdir(_filedir): + os.makedirs(_filedir) + _webpath = base_path + _name = rndstr(10) + ".jwt" + filename = os.path.join(_filedir, _name) + while os.path.exists(filename): + _name = rndstr(10) + filename = os.path.join(_filedir, _name) + if _webpath.endswith("/"): + _webname = f"{_webpath}{_name}" + else: + _webname = f"{_webpath}/{_name}" + return filename, _webname diff --git a/src/idpyoidc/encrypter.py b/src/idpyoidc/encrypter.py index f9a2052a..844618f3 100644 --- a/src/idpyoidc/encrypter.py +++ b/src/idpyoidc/encrypter.py @@ -2,6 +2,7 @@ from typing import Optional from cryptojwt.key_jar import init_key_jar +from cryptojwt.utils import as_bytes from idpyoidc.util import instantiate @@ -98,6 +99,16 @@ def init_encrypter(conf: Optional[dict] = None): if attr == "keys": continue _kwargs[attr] = val + + _key = _kwargs.get("key") + if _key: + if isinstance(_key, bytes): + pass + elif isinstance(_key, str): + _kwargs["key"] = as_bytes(_key) + else: + raise ValueError("Raw key most be of type bytes") + return { "encrypter": instantiate(_class, **_kwargs), "conf": {"class": _class, "kwargs": _kwargs}, diff --git a/src/idpyoidc/message/__init__.py b/src/idpyoidc/message/__init__.py index 46d23440..67e56f84 100644 --- a/src/idpyoidc/message/__init__.py +++ b/src/idpyoidc/message/__init__.py @@ -388,7 +388,7 @@ def _add_value(self, skey, vtyp, key, val, _deser, null_allowed, sformat="urlenc else: self._dict[skey] = val else: - raise DecodeError(ERRTXT % (key, "type != %s" % vtype)) + raise DecodeError(ERRTXT % (key, f"type != {vtype}, val:{val}, type:{type(val)}")) else: if val is None: self._dict[skey] = None diff --git a/src/idpyoidc/message/oauth2/__init__.py b/src/idpyoidc/message/oauth2/__init__.py index 788fe8c5..105b2865 100644 --- a/src/idpyoidc/message/oauth2/__init__.py +++ b/src/idpyoidc/message/oauth2/__init__.py @@ -559,6 +559,10 @@ def verify(self, **kwargs): return True +class PushedAuthorizationResponse(ResponseMessage): + c_param = ResponseMessage.c_param.copy() + c_param.update({"request_uri": SINGLE_REQUIRED_STRING}) + class SecurityEventToken(Message): c_param = { diff --git a/src/idpyoidc/message/oidc/__init__.py b/src/idpyoidc/message/oidc/__init__.py index fc5d114c..8eeb2954 100644 --- a/src/idpyoidc/message/oidc/__init__.py +++ b/src/idpyoidc/message/oidc/__init__.py @@ -1025,7 +1025,7 @@ def verify(self, **kwargs): except KeyError: pass - if "iss" in kwargs and "iss" in self: + if "iss" in kwargs and kwargs["iss"] and "iss" in self: if kwargs["iss"] != self["iss"]: raise ValueError("Wrong issuer") @@ -1191,7 +1191,7 @@ def make_openid_request( :param request_object_signing_alg: Which signing algorithm to use :param recv: The intended receiver of the request :param with_jti: Whether a JTI should be included in the JWT. - :param lifetime: How long the JWT is expect to be live. + :param lifetime: How long the JWT is expected to be live. :return: JWT encoded OpenID request """ @@ -1200,7 +1200,9 @@ def make_openid_request( _jwt.with_jti = True if lifetime: _jwt.lifetime = lifetime - return _jwt.pack(arq.to_dict(), owner=issuer, recv=recv) + if isinstance(arq, Message): + arq = arq.to_dict() + return _jwt.pack(arq, owner=issuer, recv=recv) def claims_match(value, claimspec): diff --git a/src/idpyoidc/metadata.py b/src/idpyoidc/metadata.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/idpyoidc/server/__init__.py b/src/idpyoidc/server/__init__.py index 78c2370b..b8c9671e 100644 --- a/src/idpyoidc/server/__init__.py +++ b/src/idpyoidc/server/__init__.py @@ -6,6 +6,7 @@ from typing import Union from cryptojwt import KeyJar +from cryptojwt.utils import importer from idpyoidc.client.defaults import DEFAULT_KEY_DEFS from idpyoidc.node import Unit @@ -52,6 +53,9 @@ def __init__( if _conf: self.entity_id = _conf.get("entity_id", "") self.issuer = conf.get("issuer", self.entity_id) + if not self.entity_id and self.issuer: + self.entity_id = self.issuer + self.persistence = None if upstream_get is None: @@ -95,6 +99,20 @@ def __init__( _token_endp = self.endpoint.get("token") + if isinstance(conf, dict): + metadata_schema = conf.get("metadata_schema", None) + else: + metadata_schema = conf.conf.get("metadata_schema", None) + + if metadata_schema: + metadata_schema = importer(metadata_schema) + self.context.provider_info = self.context.claims.get_server_metadata( + endpoints = self.endpoint.values(), + metadata_schema = metadata_schema, + ) + self.context.provider_info["issuer"] = self.issuer + self.context.metadata = self.context.provider_info + self.context.map_supported_to_preferred() if _token_endp: _token_endp.allow_refresh = allow_refresh_token(self.context) diff --git a/src/idpyoidc/server/claims/oauth2.py b/src/idpyoidc/server/claims/oauth2.py index 86e969df..243e09b3 100644 --- a/src/idpyoidc/server/claims/oauth2.py +++ b/src/idpyoidc/server/claims/oauth2.py @@ -1,5 +1,6 @@ from typing import Optional +from idpyoidc.message import Message from idpyoidc.message.oauth2 import ASConfigurationResponse from idpyoidc.server import claims @@ -38,9 +39,12 @@ class Claims(claims.Claims): def __init__(self, prefer: Optional[dict] = None, callback_path: Optional[dict] = None): claims.Claims.__init__(self, prefer=prefer, callback_path=callback_path) - def provider_info(self, supports): + def metadata(self, supports, schema: Optional[Message] = None): _info = {} - for key in ASConfigurationResponse.c_param.keys(): + if schema is None: + schema = ASConfigurationResponse + + for key in schema.c_param.keys(): _val = self.get_preference(key, supports.get(key, None)) if _val and _val != []: _info[key] = _val diff --git a/src/idpyoidc/server/claims/oidc.py b/src/idpyoidc/server/claims/oidc.py index 6d5efd6a..91ea2154 100644 --- a/src/idpyoidc/server/claims/oidc.py +++ b/src/idpyoidc/server/claims/oidc.py @@ -1,6 +1,7 @@ from typing import Optional from idpyoidc import alg_info +from idpyoidc.message import Message from idpyoidc.message.oidc import ProviderConfigurationResponse from idpyoidc.message.oidc import RegistrationRequest from idpyoidc.message.oidc import RegistrationResponse @@ -91,9 +92,12 @@ def verify_rules(self, supports): self.set_preference("id_token_encryption_alg_values_supported", []) self.set_preference("id_token_encryption_enc_values_supported", []) - def provider_info(self, supports): + def metadata(self, supports, schema: Optional[Message] = None): _info = {} - for key in ProviderConfigurationResponse.c_param.keys(): + if schema is None: + schema = ProviderConfigurationResponse + + for key in schema.c_param.keys(): _val = self.get_preference(key, supports.get(key, None)) if _val not in [None, []]: _info[key] = _val diff --git a/src/idpyoidc/server/client_authn.py b/src/idpyoidc/server/client_authn.py index 8a0c72da..69a0a512 100755 --- a/src/idpyoidc/server/client_authn.py +++ b/src/idpyoidc/server/client_authn.py @@ -99,7 +99,7 @@ def basic_authn(authorization_token: str): _tok = as_bytes(authorization_token[6:]) # Will raise ValueError type exception if not base64 encoded _tok = base64.b64decode(_tok) - part = as_unicode(_tok).split(":", 1) + part = as_unicode(_tok).rsplit(":", 1) if len(part) != 2: raise ValueError("Illegal token") @@ -427,6 +427,30 @@ def _verify( return {"client_id": client_id, "jwt": _jwt} +class PushedAuthorization(ClientAuthnMethod): + # The premise here is that there has been a client authentication at the + # pushed authorization endpoint + tag = "pushed_authz" + + def is_usable(self, request=None, authorization_token=None, http_info: Optional[dict] = None): + _request_uri = request.get("request_uri", None) + if _request_uri: + _context = self.upstream_get("context") + if _request_uri.startswith("urn:uuid:") and _request_uri in _context.par_db: + return True + + def _verify( + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + http_info: Optional[dict] = None, + **kwargs, + ): + client_id = request["client_id"] + return {"client_id": client_id} + + CLIENT_AUTHN_METHOD = dict( client_secret_basic=ClientSecretBasic, client_secret_post=ClientSecretPost, @@ -437,6 +461,7 @@ def _verify( request_param=RequestParam, public=PublicAuthn, none=NoneAuthn, + pushed_authz=PushedAuthorization ) TYPE_METHOD = [(JWT_BEARER, JWSAuthnMethod)] diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index 07c0080b..038d234e 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -181,11 +181,11 @@ def verify_request(self, request, keyjar, client_id, verify_args, lap=0): return None def parse_request( - self, - request: Union[Message, dict, str], - http_info: Optional[dict] = None, - verify_args: Optional[dict] = None, - **kwargs + self, + request: Union[Message, dict, str], + http_info: Optional[dict] = None, + verify_args: Optional[dict] = None, + **kwargs ): """ @@ -195,8 +195,8 @@ def parse_request( :param kwargs: extra keyword arguments :return: """ - LOGGER.debug("- {} -".format(self.endpoint_name)) - LOGGER.info("Request: %s" % sanitize(request)) + LOGGER.debug(f"- {self.endpoint_name} -") + LOGGER.info(f"Request: {sanitize(request)}") _context = self.upstream_get("context") _keyjar = self.upstream_get("attribute", "keyjar") @@ -228,7 +228,6 @@ def parse_request( # Verify that the client is allowed to do this auth_info = self.client_authentication(req, http_info, endpoint=self, **kwargs) - LOGGER.debug(f"parse_request:auth_info:{auth_info}") _client_id = auth_info.get("client_id", "") if _client_id: @@ -274,17 +273,18 @@ def client_authentication(self, request: Message, http_info: Optional[dict] = No authn_info = verify_client(request=request, http_info=http_info, **kwargs) - LOGGER.debug("authn_info: %s", authn_info) + LOGGER.debug(f"authn_info: {authn_info}") if authn_info == {}: if self.client_authn_method and len(self.client_authn_method): - LOGGER.debug("client_authn_method: %s", self.client_authn_method) + LOGGER.debug(f"client_authn_method: {self.client_authn_method}") raise UnAuthorizedClient("Authorization failed") elif "client_id" not in authn_info and authn_info.get("method") != "none": + LOGGER.debug(f"No client ID") raise UnAuthorizedClient("Authorization failed") return authn_info def do_post_parse_request( - self, request: Message, client_id: Optional[str] = "", **kwargs + self, request: Message, client_id: Optional[str] = "", **kwargs ) -> Message: _context = self.upstream_get("context") for meth in self.post_parse_request: @@ -294,7 +294,7 @@ def do_post_parse_request( return request def do_pre_construct( - self, response_args: dict, request: Optional[Union[Message, dict]] = None, **kwargs + self, response_args: dict, request: Optional[Union[Message, dict]] = None, **kwargs ) -> dict: _context = self.upstream_get("context") for meth in self.pre_construct: @@ -303,10 +303,10 @@ def do_pre_construct( return response_args def do_post_construct( - self, - response_args: Union[Message, dict], - request: Optional[Union[Message, dict]] = None, - **kwargs + self, + response_args: Union[Message, dict], + request: Optional[Union[Message, dict]] = None, + **kwargs ) -> dict: _context = self.upstream_get("context") for meth in self.post_construct: @@ -315,10 +315,10 @@ def do_post_construct( return response_args def process_request( - self, - request: Optional[Union[Message, dict]] = None, - http_info: Optional[dict] = None, - **kwargs + self, + request: Optional[Union[Message, dict]] = None, + http_info: Optional[dict] = None, + **kwargs ) -> Union[Message, dict]: """ @@ -329,10 +329,10 @@ def process_request( return {} def construct( - self, - response_args: Optional[dict] = None, - request: Optional[Union[Message, dict]] = None, - **kwargs + self, + response_args: Optional[dict] = None, + request: Optional[Union[Message, dict]] = None, + **kwargs ): """ Construct the response @@ -350,19 +350,35 @@ def construct( return self.do_post_construct(response, request, **kwargs) def response_info( - self, - response_args: Optional[dict] = None, - request: Optional[Union[Message, dict]] = None, - **kwargs + self, + response_args: Optional[dict] = None, + request: Optional[Union[Message, dict]] = None, + **kwargs ) -> dict: return self.construct(response_args, request, **kwargs) + def _get_content_type(self, **kwargs): + content_type = kwargs.get("content_type", None) + + if content_type is None: + if self.response_content_type: + content_type = self.response_content_type + elif self.response_format == "json": + content_type = "application/json" + elif self.response_format in ["jws", "jwe", "jose"]: + content_type = "application/jose" + elif self.response_format == "text": + content_type = "text/plain" + else: + content_type = "application/x-www-form-urlencoded" + return content_type + def do_response( - self, - response_args: Optional[dict] = None, - request: Optional[Union[Message, dict]] = None, - error: Optional[str] = "", - **kwargs + self, + response_args: Optional[dict] = None, + request: Optional[Union[Message, dict]] = None, + error: Optional[str] = "", + **kwargs ) -> dict: """ :param response_args: Information to use when constructing the response @@ -380,6 +396,7 @@ def do_response( resp = None if error: + content_type = "text/html" _response = ResponseMessage(error=error) for attr in ["error_description", "error_uri", "state"]: if attr in kwargs: @@ -389,58 +406,46 @@ def do_response( _response_placement = kwargs.get("response_placement") do_placement = False _response = "" - content_type = kwargs.get("content_type") - if content_type is None: - if self.response_content_type: - content_type = self.response_content_type - elif self.response_format == "json": - content_type = "application/json" - elif self.response_format in ["jws", "jwe", "jose"]: - content_type = "application/jose" - elif self.response_format == "text": - content_type = "text/plain" - else: - content_type = "application/x-www-form-urlencoded" + content_type = self._get_content_type(**kwargs) else: + content_type = "" _response = self.response_info(response_args, request, **kwargs) if do_placement: - content_type = kwargs.get("content_type") - if content_type is None: - if self.response_placement == "body": - if self.response_format == "json": - content_type = "application/json; charset=utf-8" - if isinstance(_response, Message): - resp = _response.to_json() - else: - resp = json.dumps(_response) - elif self.response_format in ["jws", "jwe", "jose"]: - if self.response_content_type: - content_type = self.response_content_type - else: - content_type = "application/jose; charset=utf-8" - resp = _response + if not content_type: + content_type = self._get_content_type(**kwargs) + if self.response_placement == "body": + if self.response_format == "json": + if not content_type: + content_type = "application/json" + if isinstance(_response, Message): + resp = _response.to_json() else: + resp = json.dumps(_response) + elif self.response_format in ["jws", "jwe", "jose"]: + if not content_type: + content_type = "application/jose" + resp = _response + else: + if not content_type: content_type = "application/x-www-form-urlencoded" - resp = _response.to_urlencoded() - elif self.response_placement == "url": + resp = _response.to_urlencoded() + elif self.response_placement == "url": + if not content_type: content_type = "application/x-www-form-urlencoded" - fragment_enc = kwargs.get("fragment_enc") - if not fragment_enc: - _ret_type = kwargs.get("return_type") - if _ret_type: - fragment_enc = fragment_encoding(_ret_type) - else: - fragment_enc = False - - if fragment_enc: - resp = _response.request(kwargs["return_uri"], True) + + fragment_enc = kwargs.get("fragment_enc") + if not fragment_enc: + _ret_type = kwargs.get("return_type") + if _ret_type: + fragment_enc = fragment_encoding(_ret_type) else: - resp = _response.request(kwargs["return_uri"]) - else: - raise ValueError( - "Don't know where that is: '{}".format(self.response_placement) - ) + fragment_enc = False + + resp = _response.request(kwargs["return_uri"], fragment_enc=fragment_enc) + else: + raise ValueError( + f"Don't know how to handle response_placement='{self.response_placement}'") if content_type: try: diff --git a/src/idpyoidc/server/endpoint_context.py b/src/idpyoidc/server/endpoint_context.py index 3b46ef3e..e70c2d4d 100755 --- a/src/idpyoidc/server/endpoint_context.py +++ b/src/idpyoidc/server/endpoint_context.py @@ -11,6 +11,7 @@ from requests import request from idpyoidc.context import OidcContext +from idpyoidc.message import Message from idpyoidc.server import authz from idpyoidc.server.claims import Claims from idpyoidc.server.claims.oauth2 import Claims as OAUTH2_Claims @@ -173,6 +174,7 @@ def __init__( self.token_args_methods = [] self.userinfo = None self.client_authn_method = {} + self.client_known_as = {} for param in [ "issuer", @@ -186,8 +188,6 @@ def __init__( except KeyError: pass - self.token_handler_args = get_token_handler_args(conf) - # session db self._sub_func = {} self.do_sub_func() @@ -240,9 +240,6 @@ def __init__( conf = conf.conf _supports = self.supports() self.keyjar = self.claims.load_conf(conf, supports=_supports, keyjar=keyjar) - self.provider_info = self.claims.provider_info(_supports) - self.provider_info["issuer"] = self.issuer - self.provider_info.update(self._get_endpoint_info()) # INTERFACES @@ -250,23 +247,34 @@ def __init__( self.setup_authentication() - self.session_manager = SessionManager( - self.token_handler_args, - sub_func=self._sub_func, - conf=conf, - upstream_get=self.unit_get) + # default is to have session management + if self.conf.get("session_management", self.conf["conf"].get("session_management", True)): + self.token_handler_args = get_token_handler_args(self.conf) + + self.session_manager = SessionManager( + self.token_handler_args, + sub_func = self._sub_func, + conf = conf, + upstream_get = self.unit_get) + else: + self.session_manager = None self.do_userinfo() # Must be done after userinfo self.setup_login_hint_lookup() - self.set_remember_token() + if self.session_manager: + self.set_remember_token() self.setup_client_authn_methods() - # _id_token_handler = self.session_manager.token_handler.handler.get("id_token") - # if _id_token_handler: - # self.provider_info.update(_id_token_handler.provider_info) + def get_metadata(self, supports: Optional[dict] = None, schema: Optional[Message] = None): + if supports is None: + supports = self.supports() + + _metadata = self.claims.metadata(supports, schema) + _metadata.update(self._get_endpoint_info()) + return _metadata def setup_authz(self): authz_spec = self.conf.get("authz") @@ -407,7 +415,7 @@ def supports(self): return res def set_provider_info(self): - _info = self.claims.provider_info(self.supports()) + _info = self.claims.metadata(self.supports()) _info.update({"issuer": self.issuer, "version": "3.0"}) for endp in self.upstream_get("endpoints").values(): diff --git a/src/idpyoidc/server/oauth2/add_on/dpop.py b/src/idpyoidc/server/oauth2/add_on/dpop.py index 2e7ae1e5..24209629 100644 --- a/src/idpyoidc/server/oauth2/add_on/dpop.py +++ b/src/idpyoidc/server/oauth2/add_on/dpop.py @@ -1,3 +1,4 @@ +import base64 import logging from hashlib import sha256 from typing import Callable @@ -8,13 +9,14 @@ from cryptojwt import JWS from cryptojwt.jwk.jwk import key_from_jwk_dict from cryptojwt.jws.jws import factory +from cryptojwt.utils import add_padding +from idpyoidc.alg_info import get_signing_algs from idpyoidc.message import Message from idpyoidc.message import SINGLE_OPTIONAL_STRING from idpyoidc.message import SINGLE_REQUIRED_INT from idpyoidc.message import SINGLE_REQUIRED_JSON from idpyoidc.message import SINGLE_REQUIRED_STRING -from idpyoidc.alg_info import get_signing_algs from idpyoidc.server.client_authn import BearerHeader logger = logging.getLogger(__name__) @@ -107,7 +109,14 @@ def token_post_parse_request(request, client_id, context, **kwargs): if not _http_info: return request - _dpop = DPoPProof().verify_header(_http_info["headers"]["dpop"]) + _headers = _http_info['headers'] + logger.debug(f"http headers: {_headers}") + + _dpop_header = _headers.get("dpop", _headers.get("http_dpop", None)) + if not _dpop_header: + raise ValueError("Missing DPoP header") + + _dpop = DPoPProof().verify_header(_dpop_header) # The signature of the JWS is verified, now for checking the # content @@ -130,7 +139,7 @@ def userinfo_post_parse_request(request, client_id, context, auth_info, **kwargs """ Expect http_info attribute in kwargs. http_info should be a dictionary containing HTTP information. - This function is ment for DPoP-protected resources. + This function is meant for DPoP-protected resources. :param request: :param client_id: @@ -144,6 +153,18 @@ def userinfo_post_parse_request(request, client_id, context, auth_info, **kwargs return request _dpop = DPoPProof().verify_header(_http_info["headers"]["dpop"]) + _headers = _http_info.get("headers", "") + if _headers: + _dpop_header = _headers.get("dpop", "") + if not _dpop_header: + _dpop_header = _headers.get("http_dpop", "") + if not _dpop_header: + logger.debug(f"Request Headers: {_headers}") + raise ValueError("Expected DPoP header, none found") + else: + raise ValueError("Expected DPoP header, no headers found") + + _dpop = DPoPProof().verify_header(_dpop_header) # The signature of the JWS is verified, now for checking the # content @@ -158,9 +179,19 @@ def userinfo_post_parse_request(request, client_id, context, auth_info, **kwargs _dpop.key = key_from_jwk_dict(_dpop["jwk"]) ath = sha256(auth_info["token"].encode("utf8")).hexdigest() + _token = auth_info.get("token", None) + if _token: + ath = as_unicode(base64.urlsafe_b64encode(sha256(_token.encode("utf8")).digest())) if _dpop["ath"] != ath: - raise ValueError("'ath' in DPoP does not match the token hash") + _ath = _dpop.get("ath", None) + if _ath is None: + raise ValueError("'ath' missing from DPoP") + else: + _athb = _ath.rstrip("=") + _ath = add_padding(_athb) + if _ath != ath: + raise ValueError("'ath' in DPoP does not match the token hash") # Need something I can add as a reference when minting tokens request["dpop_jkt"] = as_unicode(_dpop.key.thumbprint("SHA-256")) @@ -184,34 +215,32 @@ def _add_to_context(endpoint, algs_supported): _context = endpoint.upstream_get("context") _context.provider_info["dpop_signing_alg_values_supported"] = algs_supported _context.add_on["dpop"] = {"algs_supported": algs_supported} - _context.client_authn_methods["dpop"] = DPoPClientAuth - + _context.client_authn_methods["dpop"] = DPoPClientAuth(endpoint.upstream_get) -def add_support(endpoint: dict, **kwargs): - # Pick the token endpoint - _endp = endpoint.get("token", None) - if _endp: - _endp.post_parse_request.append(token_post_parse_request) - _added_to_context = False - _algs_supported = kwargs.get("dpop_signing_alg_values_supported") - if not _algs_supported: +def add_support(endpoint: dict, dpop_signing_alg_values_supported=None, dpop_endpoints=None, **kwargs): + if dpop_signing_alg_values_supported is None: _algs_supported = ["RS256"] else: - _algs_supported = [alg for alg in _algs_supported if alg in get_signing_algs()] + # Pick out the ones I support + _algs_supported = [alg for alg in dpop_signing_alg_values_supported if alg in get_signing_algs()] + + _added_to_context = False - if _endp: - _add_to_context(_endp, _algs_supported) - _added_to_context = True + if dpop_endpoints is None: + dpop_endpoints = ["userinfo"] - for _dpop_endpoint in kwargs.get("dpop_endpoints", ["userinfo"]): + for _dpop_endpoint in dpop_endpoints: _endpoint = endpoint.get(_dpop_endpoint, None) if _endpoint: if not _added_to_context: - _add_to_context(_endp, _algs_supported) + _add_to_context(_endpoint, _algs_supported) _added_to_context = True - _endpoint.post_parse_request.append(userinfo_post_parse_request) + if _endpoint.name == "userinfo": + _endpoint.post_parse_request.append(userinfo_post_parse_request) + elif _endpoint.name == "token": + _endpoint.post_parse_request.append(token_post_parse_request) # DPoP-bound access token in the "Authorization" header and the DPoP proof in the "DPoP" header @@ -220,7 +249,7 @@ def add_support(endpoint: dict, **kwargs): class DPoPClientAuth(BearerHeader): tag = "dpop_client_auth" - def is_usable(self, request=None, authorization_token=None, http_headers=None): + def is_usable(self, request=None, authorization_token=None, http_info=None): if authorization_token is not None and authorization_token.startswith("DPoP "): return True return False @@ -231,6 +260,7 @@ def verify( authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] get_client_id_from_token: Optional[Callable] = None, + http_info: Optional[dict] = None, **kwargs, ): # info contains token and client_id diff --git a/src/idpyoidc/server/oauth2/authorization.py b/src/idpyoidc/server/oauth2/authorization.py index e2cd4fa7..6c2fc5aa 100755 --- a/src/idpyoidc/server/oauth2/authorization.py +++ b/src/idpyoidc/server/oauth2/authorization.py @@ -430,8 +430,10 @@ def verify_response_type(self, request: Union[Message, dict], cinfo: dict) -> bo # Checking response types _registered = [set(rt.split(" ")) for rt in cinfo.get("response_types_supported", [])] if not _registered: - # If no response_type is registered by the client then we'll use code. - _registered = [{"code"}] + _registered = [set(rt.split(" ")) for rt in cinfo.get("response_types", [])] + if not _registered: + # If no response_type is registered by the client then we'll use code. + _registered = [{"code"}] # Is the asked for response_type among those that are permitted return set(request["response_type"]) in _registered diff --git a/src/idpyoidc/server/oauth2/pushed_authorization.py b/src/idpyoidc/server/oauth2/pushed_authorization.py index 693b073f..0f54373d 100644 --- a/src/idpyoidc/server/oauth2/pushed_authorization.py +++ b/src/idpyoidc/server/oauth2/pushed_authorization.py @@ -40,11 +40,11 @@ def process_request(self, request: Optional[Union[Message, str]] = None, **kwarg _request.verify(keyjar=self.upstream_get("attribute", "keyjar")) - _urn = "urn:uuid:{}".format(uuid.uuid4()) + _urn = f"urn:uuid:{uuid.uuid4()}" # Store the parsed and verified request self.upstream_get("context").par_db[_urn] = _request return { - "http_response": {"request_uri": _urn, "expires_in": self.ttl}, + "response_args": {"request_uri": _urn, "expires_in": self.ttl}, "return_uri": _request["redirect_uri"], } diff --git a/src/idpyoidc/server/oidc/authorization.py b/src/idpyoidc/server/oidc/authorization.py index 29e93886..055774f3 100644 --- a/src/idpyoidc/server/oidc/authorization.py +++ b/src/idpyoidc/server/oidc/authorization.py @@ -83,8 +83,8 @@ class Authorization(authorization.Authorization): "claims_parameter_supported": True, "encrypt_request_object_supported": False, "request_object_signing_alg_values_supported": alg_info.get_signing_algs(), - "request_object_encryption_alg_values_supported": alg_info.get_encryption_algs(), - "request_object_encryption_enc_values_supported": alg_info.get_encryption_encs(), + "request_object_encryption_alg_values_supported": [], + "request_object_encryption_enc_values_supported": [], "request_parameter_supported": None, "request_uri_parameter_supported": None, "require_request_uri_registration": None, diff --git a/src/idpyoidc/server/oidc/backchannel_authentication.py b/src/idpyoidc/server/oidc/backchannel_authentication.py index b193e223..456fb251 100644 --- a/src/idpyoidc/server/oidc/backchannel_authentication.py +++ b/src/idpyoidc/server/oidc/backchannel_authentication.py @@ -86,10 +86,10 @@ def allowed_target_uris(self): return set(res) def process_request( - self, - request: Optional[Union[Message, dict]] = None, - http_info: Optional[dict] = None, - **kwargs, + self, + request: Optional[Union[Message, dict]] = None, + http_info: Optional[dict] = None, + **kwargs, ): try: request_user = self.do_request_user(request) @@ -125,6 +125,7 @@ def process_request( class CIBATokenHelper(AccessTokenHelper): + def _get_session_info(self, request, session_manager): _path = request["_session_path"] _grant = session_manager.get(_path) @@ -137,7 +138,7 @@ def _get_session_info(self, request, session_manager): return session_info, _grant def post_parse_request( - self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs ) -> Union[Message, dict]: _context = self.endpoint.upstream_get("context") _mngr = _context.session_manager @@ -303,10 +304,10 @@ def __init__(self, upstream_get: Callable, **kwargs): Endpoint.__init__(self, upstream_get, **kwargs) def process_request( - self, - request: Optional[Union[Message, dict]] = None, - http_info: Optional[dict] = None, - **kwargs, + self, + request: Optional[Union[Message, dict]] = None, + http_info: Optional[dict] = None, + **kwargs, ) -> Union[Message, dict]: return {} @@ -316,17 +317,17 @@ class ClientNotificationAuthn(ClientSecretBasic): tag = "client_notification_authn" - def is_usable(self, request=None, authorization_token=None): + def is_usable(self, request=None, authorization_token=None, http_info=None): if authorization_token is not None and authorization_token.startswith("Bearer "): return True return False def _verify( - self, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - get_client_id_from_token: Optional[Callable] = None, - **kwargs, + self, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + get_client_id_from_token: Optional[Callable] = None, + **kwargs, ): ttype, token = authorization_token.split(" ", 1) if ttype != "Bearer": diff --git a/src/idpyoidc/server/oidc/provider_config.py b/src/idpyoidc/server/oidc/provider_config.py index 819a6997..be399ae6 100755 --- a/src/idpyoidc/server/oidc/provider_config.py +++ b/src/idpyoidc/server/oidc/provider_config.py @@ -33,4 +33,13 @@ def add_endpoints(self, request, client_id, context, **kwargs): return request def process_request(self, request=None, **kwargs): - return {"response_args": self.upstream_get("context").provider_info} + _schema = self.upstream_get("attribute", "metadata_schema") + _args = self.upstream_get("context").claims.get_server_metadata(metadata_schema=_schema) + # add issuer + _args["issuer"] = self.upstream_get("attribute", "entity_id") + # add endpoints + for name, endpoint in self.upstream_get("unit").endpoint.items(): + if endpoint.endpoint_name: + _args[endpoint.endpoint_name] = endpoint.full_path + + return {"response_args": _args} diff --git a/src/idpyoidc/server/oidc/registration.py b/src/idpyoidc/server/oidc/registration.py index a363ebeb..b775b406 100644 --- a/src/idpyoidc/server/oidc/registration.py +++ b/src/idpyoidc/server/oidc/registration.py @@ -10,6 +10,9 @@ from cryptojwt.jws.utils import alg2keytype from cryptojwt.utils import as_bytes +from idpyoidc.key_import import import_jwks + +from idpyoidc.key_import import import_jwks_as_json from idpyoidc.exception import MessageException from idpyoidc.message.oauth2 import ResponseMessage @@ -143,7 +146,7 @@ def match_claim(self, claim, val): # Use my defaults _my_key = _context.claims.register2preferred.get(claim, claim) try: - _val = _context.provider_info[_my_key] + _val = _context.claims.get_preference(_my_key) except KeyError: return val @@ -279,9 +282,18 @@ def do_client_registration(self, request, client_id, ignore=None): t = {"jwks_uri": "", "jwks": None} - for item in ["jwks_uri", "jwks"]: - if item in request: - t[item] = request[item] + _jwks_uri = request.get("jwks_uri") + if _jwks_uri: + # if it can't load keys because the URL is false it will + # just silently fail. Waiting for better times. + _keyjar.add_url(issuer_id=client_id, url=_jwks_uri) + else: + _jwks = request.get("jwks", None) + if _jwks: + if isinstance(_jwks, str): + _keyjar = import_jwks_as_json(_keyjar, _jwks, client_id) + else: + _keyjar = import_jwks(_keyjar, _jwks, client_id) # if it can't load keys because the URL is false it will # just silently fail. Waiting for better times. @@ -437,7 +449,13 @@ def client_registration_setup(self, request, if not reserved_client_id: reserved_client_id = _context.cdb.keys() client_id = cid_generator(reserved=reserved_client_id, **cid_gen_kwargs) - if "client_id" in request: + _entity_id = request.get("client_id", None) + if _entity_id: + # Already registered + _old_id = _context.client_known_as.get(request["client_id"], None) + if _old_id: + del _context.cdb[_old_id] + _context.client_known_as[_entity_id] = client_id del request["client_id"] else: client_id = request.get("client_id") @@ -456,7 +474,7 @@ def client_registration_setup(self, request, if set_secret: client_secret = self.add_client_secret(_cinfo, client_id, _context) - logger.debug("Stored client info in CDB under cid={}".format(client_id)) + logger.debug(f"Stored client info in CDB under cid={client_id}") _context.cdb[client_id] = _cinfo _cinfo = self.do_client_registration( @@ -469,6 +487,12 @@ def client_registration_setup(self, request, args = dict([(k, v) for k, v in _cinfo.items() if k in self.response_cls.c_param]) + # Don't echo keys back + try: + del args["jwks"] + except KeyError: + pass + comb_uri(args) response = self.response_cls(**args) @@ -495,7 +519,7 @@ def process_request(self, request=None, new_id=True, set_secret=True, **kwargs): reg_resp = self.client_registration_setup(request, new_id, set_secret, reserved_client_id) except Exception as err: - logger.error("client_registration_setup: %s", request) + logger.exception(f"client_registration_setup: {request}") return ResponseMessage( error="invalid_configuration_request", error_description="%s" % err ) diff --git a/src/idpyoidc/server/oidc/token_helper/access_token.py b/src/idpyoidc/server/oidc/token_helper/access_token.py index 2594748e..eefc4e24 100755 --- a/src/idpyoidc/server/oidc/token_helper/access_token.py +++ b/src/idpyoidc/server/oidc/token_helper/access_token.py @@ -17,6 +17,7 @@ class AccessTokenHelper(TokenEndpointHelper): + def _get_session_info(self, request, session_manager): if request["grant_type"] != "authorization_code": return self.error_cls(error="invalid_request", error_description="Unknown grant_type") @@ -56,7 +57,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): if "grant_types_supported" in _context.cdb[client_id]: grant_types_supported = _context.cdb[client_id].get("grant_types_supported") else: - grant_types_supported = _context.provider_info["grant_types_supported"] + grant_types_supported = _context.provider_info.get("grant_types", []) grant = _session_info["grant"] token_type = "Bearer" @@ -166,7 +167,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): return _response def post_parse_request( - self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs ) -> Union[Message, dict]: """ This is where clients come to get their access tokens diff --git a/src/idpyoidc/server/oidc/userinfo.py b/src/idpyoidc/server/oidc/userinfo.py index 27557047..54a8b93f 100755 --- a/src/idpyoidc/server/oidc/userinfo.py +++ b/src/idpyoidc/server/oidc/userinfo.py @@ -34,8 +34,8 @@ class UserInfo(Endpoint): "claim_types_supported": ["normal", "aggregated", "distributed"], "encrypt_userinfo_supported": True, "userinfo_signing_alg_values_supported": alg_info.get_signing_algs(), - "userinfo_encryption_alg_values_supported": alg_info.get_encryption_algs(), - "userinfo_encryption_enc_values_supported": alg_info.get_encryption_encs(), + "userinfo_encryption_alg_values_supported": [], + "userinfo_encryption_enc_values_supported": [], } def __init__( diff --git a/src/idpyoidc/server/session/grant.py b/src/idpyoidc/server/session/grant.py index d7ee7c39..d615a9b2 100644 --- a/src/idpyoidc/server/session/grant.py +++ b/src/idpyoidc/server/session/grant.py @@ -381,7 +381,7 @@ def mint_token( item.value = token_handler( session_id=session_id, usage_rules=usage_rules, **token_payload ) - + logger.debug(f"Minted token value: {item.value}") if based_on: based_on.used += 1 else: diff --git a/src/idpyoidc/server/session/manager.py b/src/idpyoidc/server/session/manager.py index ddd5a9dc..d209be60 100644 --- a/src/idpyoidc/server/session/manager.py +++ b/src/idpyoidc/server/session/manager.py @@ -1,10 +1,10 @@ import hashlib import logging import os -import uuid from typing import Callable from typing import List from typing import Optional +import uuid from idpyoidc.encrypter import default_crypt_config from idpyoidc.message.oauth2 import AuthorizationRequest @@ -13,16 +13,15 @@ from idpyoidc.server.exception import ConfigurationError from idpyoidc.server.session.grant_manager import GrantManager from idpyoidc.util import rndstr - from .database import Database -from ..exception import InvalidBranchID from .grant import Grant from .grant import SessionToken from .info import ClientSessionInfo from .info import UserSessionInfo -from ..token import handler +from ..exception import InvalidBranchID from ..token import UnknownToken from ..token import WrongTokenClass +from ..token import handler from ..token.handler import TokenHandler logger = logging.getLogger(__name__) @@ -84,6 +83,7 @@ def ephemeral_id(*args, **kwargs): class SessionManager(GrantManager): parameter = Database.parameter.copy() + # parameter.update({"salt": ""}) init_args = ["token_handler_args", "upstream_get"] @@ -437,30 +437,6 @@ def revoke_grant(self, session_id: str): """ self._revoke_tree(self.get_grant(session_id)) - # def grants( - # self, - # session_id: Optional[str] = "", - # user_id: Optional[str] = "", - # client_id: Optional[str] = "", - # ) -> List[Grant]: - # """ - # Find all grant connected to a user session - # - # :param client_id: - # :param user_id: - # :param session_id: A session identifier - # :return: A list of grants - # """ - # if session_id: - # user_id, client_id, _ = self.decrypt_session_id(session_id) - # elif user_id and client_id: - # pass - # else: - # raise AttributeError("Must have session_id or user_id and client_id") - # - # _csi = self.get([user_id, client_id]) - # return [self.get([user_id, client_id, gid]) for gid in _csi.subordinate] - def get_session_info( self, session_id: str, @@ -488,7 +464,7 @@ def get_session_info( # Log the exception if needed logging.error(f"InvalidBranchID error: {str(e)}") raise - + if authentication_event: res["authentication_event"] = res["grant"].authentication_event @@ -551,5 +527,18 @@ def encrypted_session_id(self, *args): def unpack_session_key(self, key): return self.unpack_branch_key(key) -# def create_session_manager(upstream_get, token_handler_args, sub_func=None, conf=None): -# return SessionManager(token_handler_args, sub_func=sub_func, conf=conf, upstream_get=upstream_get) + # def create_session_manager(upstream_get, token_handler_args, sub_func=None, conf=None): + # return SessionManager(token_handler_args, sub_func=sub_func, conf=conf, + # upstream_get=upstream_get) + def get_client_id_from_token(self, token_value: str, handler_key: Optional[str] = ""): + if handler_key: + _token_info = self.token_handler.handler[handler_key].info(token_value) + else: + _token_info = self.token_handler.info(token_value) + + sid = _token_info.get("sid") + _path = self.decrypt_branch_id(sid) + if len(_path) == 3: + return _path[1] + else: + return _path[-1] diff --git a/src/idpyoidc/server/token/__init__.py b/src/idpyoidc/server/token/__init__.py index 8c92e562..42041e3d 100755 --- a/src/idpyoidc/server/token/__init__.py +++ b/src/idpyoidc/server/token/__init__.py @@ -105,23 +105,40 @@ def __call__( else: token_class = "authorization_code" + logger.debug(f"Mint {token_class}") + logger.debug(f"crypt.key: {self.crypt.key}") + _jwks = self.crypt_config.get('jwks', None) + logger.debug(f"crypt.jwks: {_jwks}") + if self.lifetime >= 0: exp = str(utc_time_sans_frac() + self.lifetime) else: - exp = "-1" # Live for ever + exp = "-1" # Live forever tmp = "" rnd = "" while rnd == tmp: # Don't use the same random value again rnd = rndstr(32) # Ultimate length multiple of 16 - return base64.b64encode( + _args = { + "rnd": rnd, + "token_class": token_class, + "session_id": session_id, + "exp": exp + } + logger.debug(f"Encrypt arguments: {_args}") + _value = base64.urlsafe_b64encode( self.crypt.encrypt(lv_pack(rnd, token_class, session_id, exp).encode()) ).decode("utf-8") + logger.debug(f"Token: {_value}") + return _value + def split_token(self, token): + logger.debug(f"split_token: {token}") + logger.debug(f"crypt key: {self.crypt.key}") try: - plain = self.crypt.decrypt(base64.b64decode(token)) + plain = self.crypt.decrypt(base64.urlsafe_b64decode(token)) except Exception as err: raise UnknownToken(err) # order: rnd, type, sid diff --git a/src/idpyoidc/server/token/handler.py b/src/idpyoidc/server/token/handler.py index 8fa90631..06f4bd3a 100755 --- a/src/idpyoidc/server/token/handler.py +++ b/src/idpyoidc/server/token/handler.py @@ -137,6 +137,9 @@ def default_token(spec): else: return False +def key_types(keys): + return [k["kid"] for k in keys] + JWKS_FILE = "private/token_jwks.json" @@ -192,10 +195,19 @@ def factory( ("token", token, "access_token"), ("refresh", refresh, "refresh_token"), ]: - if cnf is not None: - if default_token(cnf): - if kj: - _add_passwd(kj, cnf, cls) + if cnf is not None: # else just default + try: + _key_types = key_types( + cnf["kwargs"]["crypt_conf"]["kwargs"]["keys"]["key_defs"]) + except KeyError: # will fail on keys if it fails + pass + else: + if "key" in _key_types and "password" in _key_types: + raise ValueError("You have to chose one of key or password") + if "password" not in _key_types and "key" not in _key_types: + if kj: + _add_passwd(kj, cnf, cls) + logger.debug(f"init_token_handler: {cls}") args[attr] = init_token_handler(upstream_get, cnf, token_class_map[cls]) if id_token is not None: diff --git a/src/idpyoidc/server/token/jwt_token.py b/src/idpyoidc/server/token/jwt_token.py index abbfa97f..23a64dac 100644 --- a/src/idpyoidc/server/token/jwt_token.py +++ b/src/idpyoidc/server/token/jwt_token.py @@ -1,3 +1,4 @@ +import logging from typing import Callable from typing import Optional from typing import Union @@ -16,6 +17,7 @@ from .exception import UnknownToken from .exception import WrongTokenClass +logger = logging.getLogger(__name__) class JWTToken(Token): def __init__( @@ -89,8 +91,12 @@ def __call__( lifetime = usage_rules.get("expires_in") else: lifetime = self.lifetime + + _keyjar = self.upstream_get("attribute", "keyjar") + logger.info(f"Key owners in the keyjar: {_keyjar.owners()}") + signer = JWT( - key_jar=self.upstream_get("attribute", "keyjar"), + key_jar=_keyjar, iss=self.issuer, lifetime=lifetime, sign_alg=self.alg, diff --git a/src/idpyoidc/storage/abfile.py b/src/idpyoidc/storage/abfile.py index 6257fe21..d1c088f8 100644 --- a/src/idpyoidc/storage/abfile.py +++ b/src/idpyoidc/storage/abfile.py @@ -191,7 +191,7 @@ def is_changed(self, item): else: return False else: - logger.error("Could not access {}".format(fname)) + logger.error(f"Not a file '{fname}'") raise KeyError(item) def _read_info(self, fname): @@ -239,6 +239,14 @@ def synch(self): else: self.fmtime[f] = mtime + _keys = self.storage.keys() + for f in _keys: + fname = os.path.join(self.fdir, f) + if os.path.isfile(fname): + pass + else: + del self.storage[f] + def items(self): """ Implements the dict.items() method diff --git a/src/idpyoidc/storage/abfile_no_cache.py b/src/idpyoidc/storage/abfile_no_cache.py new file mode 100644 index 00000000..7d1d5c3b --- /dev/null +++ b/src/idpyoidc/storage/abfile_no_cache.py @@ -0,0 +1,210 @@ +import logging +import os +import time +from typing import Optional + +from cryptojwt.utils import importer +from filelock import FileLock + +from idpyoidc.storage import DictType +from idpyoidc.util import PassThru +from idpyoidc.util import QPKey + +logger = logging.getLogger(__name__) + + +class AbstractFileSystemNoCache(DictType): + """ + FileSystem implements a simple file based database. + It has a dictionary like interface. + Each key maps one-to-one to a file on disc, where the content of the + file is the value. + ONLY goes one level deep. + Not directories in directories. + """ + + def __init__( + self, + fdir: Optional[str] = "", + key_conv: Optional[str] = "", + value_conv: Optional[str] = "", + read_only: Optional[bool] = False, + **kwargs + ): + """ + items = FileSystem( + { + 'fdir': fdir, + 'key_conv':{'to': quote_plus, 'from': unquote_plus}, + 'value_conv':{'to': keyjar_to_jwks, 'from': jwks_to_keyjar} + }) + + :param fdir: The root of the directory + :param key_conv: Converts to/from the key displayed by this class to + users of it to something that can be used as a file name. + The value of key_conv is a class that has the methods 'serialize'/'deserialize'. + :param value_conv: As with key_conv you can convert/translate + the value bound to a key in the database to something that can easily + be stored in a file. Like with key_conv the value of this parameter + is a class that has the methods 'serialize'/'deserialize'. + """ + super(AbstractFileSystemNoCache, self).__init__( + fdir=fdir, key_conv=key_conv, value_conv=value_conv + ) + + self.fdir = fdir + self.read_only = read_only + + if key_conv: + self.key_conv = importer(key_conv)() + else: + self.key_conv = QPKey() + + if value_conv: + self.value_conv = importer(value_conv)() + else: + self.value_conv = PassThru() + + if not os.path.isdir(self.fdir): + os.makedirs(self.fdir) + + def get(self, item, default=None): + try: + return self[item] + except KeyError: + return default + + def __getitem__(self, item): + """ + Return the value bound to an identifier. + + :param item: The identifier. + :return: + """ + _file_name = self.key_conv.serialize(item) + logger.debug(f'Read from "{_file_name}"') + return self._read_info(_file_name) + + def __setitem__(self, key, value): + """ + Binds a value to a specific key. If the file that the key maps to + does not exist it will be created. The content of the file will be + set to the value given. + + :param key: Identifier + :param value: Value that should be bound to the identifier. + :return: + """ + + if self.read_only: + return + + if not os.path.isdir(self.fdir): + os.makedirs(self.fdir, exist_ok=True) + + try: + _file_name = self.key_conv.serialize(key) + except KeyError: + _file_name = key + + fname = os.path.join(self.fdir, _file_name) + lock = FileLock(f"{fname}.lock") + with lock: + with open(fname, "w") as fp: + fp.write(self.value_conv.serialize(value)) + + logger.debug(f'Wrote to "{_file_name}"') + + def __delitem__(self, key): + if self.read_only: + return + + fname = os.path.join(self.fdir, key) + if fname.endswith(".lock"): + if os.path.isfile(fname): + os.unlink(fname) + else: + if os.path.isfile(fname): + lock = FileLock(f"{fname}.lock") + with lock: + os.unlink(fname) + os.unlink(f"{fname}.lock") + def _keys(self): + """ + Implements the dict.keys() method + """ + keys = [] + for f in os.listdir(self.fdir): + fname = os.path.join(self.fdir, f) + + if not os.path.isfile(fname): + continue + if fname.endswith(".lock"): + continue + + keys.append(f) + + return keys + + def keys(self): + return [self.key_conv.deserialize(k) for k in self._keys()] + + def _read_info(self, key): + file_name = os.path.join(self.fdir, key) + if os.path.isfile(file_name): + try: + lock = FileLock(f"{file_name}.lock") + with lock: + info = open(file_name, "r").read().strip() + lock.release() + return self.value_conv.deserialize(info) + except Exception as err: + logger.error(err) + raise + else: + _msg = f"No such file: '{file_name}'" + logger.error(_msg) + return None + + def items(self): + """ + Implements the dict.items() method + """ + for k in self._keys(): + v = self._read_info(k) + yield self.key_conv.deserialize(k), v + + def clear(self): + """ + Completely resets the database. This means that all information in + the local cache and on disc will be erased. + """ + if self.read_only: + return + + if not os.path.isdir(self.fdir): + os.makedirs(self.fdir, exist_ok=True) + return + + for f in os.listdir(self.fdir): + del self[f] + + def __contains__(self, item): + file_name = os.path.join(self.fdir, self.key_conv.serialize(item)) + if os.path.isfile(file_name): + return True + else: + return False + + def __iter__(self): + for k in self._keys(): + yield self.key_conv.deserialize(k) + + def __call__(self, *args, **kwargs): + return [self.key_conv.deserialize(k) for k in self._keys()] + + def __len__(self): + if not os.path.isdir(self.fdir): + return 0 + + return len(self._keys()) \ No newline at end of file diff --git a/src/idpyoidc/storage/listfile.py b/src/idpyoidc/storage/listfile.py index 77520de3..b2c2895f 100644 --- a/src/idpyoidc/storage/listfile.py +++ b/src/idpyoidc/storage/listfile.py @@ -111,6 +111,49 @@ def __getitem__(self, item): else: return None + def __len__(self): + _lst = self._read_info(self.file_name) + + if _lst is None or _lst == []: + return 0 + return len(set(_lst)) + + def _read_info(self, fname): + if os.path.isfile(fname): + try: + lock = FileLock(f"{fname}.lock") + + with lock: + fp = open(fname, "r") + info = [x.strip() for x in fp.readlines()] + lock.release() + return list(set(info)) + except Exception as err: + logger.error(err) + raise + else: + _msg = f"No such file: '{fname}'" + logger.error(_msg) + return None + + def __call__(self): + return self._read_info(self.file_name) + + def list(self): + return self._read_info(self.file_name) + +class ReadWriteListFile(object): + def __init__(self, file_name): + self.file_name = file_name + + if not os.path.exists(file_name): + fp = open(file_name, "x") + fp.close() + + def __contains__(self, item): + _lst = self._read_info(self.file_name) + return item in _lst + def __len__(self): _lst = self._read_info(self.file_name) if _lst is None or _lst == []: diff --git a/src/idpyoidc/transform.py b/src/idpyoidc/transform.py index 3834006c..7a9b5593 100644 --- a/src/idpyoidc/transform.py +++ b/src/idpyoidc/transform.py @@ -1,6 +1,7 @@ import logging from typing import Optional +from idpyoidc.message import Message from idpyoidc.message.oidc import RegistrationRequest from idpyoidc.message.oidc import RegistrationResponse @@ -204,9 +205,10 @@ def preferred_to_registered( return registered -def create_registration_request(prefers: dict, supported: dict) -> dict: +def create_registration_request(prefers: dict, supported: dict, + registration_class: Optional[Message] = RegistrationRequest) -> dict: _request = {} - for key, spec in RegistrationRequest.c_param.items(): + for key, spec in registration_class.c_param.items(): _pref_key = REGISTER2PREFERRED.get(key, key) if _pref_key in prefers: value = prefers[_pref_key] @@ -221,7 +223,7 @@ def create_registration_request(prefers: dict, supported: dict) -> dict: _request[key] = array_or_singleton(spec, value) for key, val in prefers.items(): - if key not in RegistrationRequest.c_param.keys(): + if key not in registration_class.c_param.keys(): if key not in REGISTER2PREFERRED.values(): _request[key] = val diff --git a/tests/private/token_jwks.json b/tests/private/token_jwks.json index d3e0f070..9a4ae9d6 100644 --- a/tests/private/token_jwks.json +++ b/tests/private/token_jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "vrjoMrmgK8SmJJPc318zTxqG_tvBqF5l"}]} \ No newline at end of file +{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "ZcwEWWiviH92lCBx0NCAtZIHbK22je6S"}]} \ No newline at end of file diff --git a/tests/test_08_transform.py b/tests/test_08_transform.py index 71c83d9b..a3c2193f 100644 --- a/tests/test_08_transform.py +++ b/tests/test_08_transform.py @@ -4,9 +4,9 @@ from cryptojwt.utils import importer from idpyoidc.client.claims.oidc import Claims as OIDC_Claims -from idpyoidc.client.claims.transform import create_registration_request -from idpyoidc.client.claims.transform import preferred_to_registered -from idpyoidc.client.claims.transform import supported_to_preferred +from idpyoidc.transform import create_registration_request +from idpyoidc.transform import preferred_to_registered +from idpyoidc.transform import supported_to_preferred from idpyoidc.message.oidc import APPLICATION_TYPE_WEB from idpyoidc.message.oidc import ProviderConfigurationResponse from idpyoidc.message.oidc import RegistrationRequest @@ -248,8 +248,8 @@ def test_provider_info(self): assert set(claims.prefer.keys()) == { "application_type", "default_max_age", - "encrypt_request_object_supported", - "encrypt_userinfo_supported", + # "encrypt_request_object_supported", + # "encrypt_userinfo_supported", "id_token_encryption_alg_values_supported", "id_token_encryption_enc_values_supported", "id_token_signing_alg_values_supported", @@ -362,6 +362,8 @@ def test_registration_response(self): "client_name", "contacts", "default_max_age", + "encrypt_request_object_supported", + "encrypt_userinfo_supported", "id_token_signed_response_alg", "logo_uri", "redirect_uris", diff --git a/tests/test_09_work_condition.py b/tests/test_09_work_condition.py index 957d8570..abbd77d2 100644 --- a/tests/test_09_work_condition.py +++ b/tests/test_09_work_condition.py @@ -4,9 +4,9 @@ from cryptojwt.utils import importer from idpyoidc.client.claims.oidc import Claims -from idpyoidc.client.claims.transform import create_registration_request -from idpyoidc.client.claims.transform import preferred_to_registered -from idpyoidc.client.claims.transform import supported_to_preferred +from idpyoidc.transform import create_registration_request +from idpyoidc.transform import preferred_to_registered +from idpyoidc.transform import supported_to_preferred from idpyoidc.message.oidc import APPLICATION_TYPE_WEB KEYSPEC = [ @@ -179,9 +179,13 @@ def test_registration_response(self): assert set(registration_request.keys()) == { "application_type", + "client_id", "client_name", + "client_secret", "contacts", "default_max_age", + "encrypt_request_object_supported", + "encrypt_userinfo_supported", "id_token_signed_response_alg", "jwks", "logo_uri", @@ -318,8 +322,13 @@ def test_registration_response_consistence(self): assert set(registration_request.keys()) == { "application_type", "client_name", + "client_id", + "client_name", + "client_secret", "contacts", "default_max_age", + "encrypt_request_object_supported", + "encrypt_userinfo_supported", "id_token_signed_response_alg", "jwks", "logo_uri", diff --git a/tests/test_14_read_only_list_file.py b/tests/test_14_read_only_list_file.py index 2abdf9e9..141501ef 100644 --- a/tests/test_14_read_only_list_file.py +++ b/tests/test_14_read_only_list_file.py @@ -25,5 +25,4 @@ def test_read_only_list_file(): # sleep(2) # assert _read_only.is_changed(FILE_NAME) is True - assert set(_read_only) == {"one", "two", "three"} - assert _read_only[-1] == "three" \ No newline at end of file + assert set(_read_only.list()) == {"one", "two", "three"} diff --git a/tests/test_client_05_util.py b/tests/test_client_05_util.py index 3a22416a..d3088be6 100644 --- a/tests/test_client_05_util.py +++ b/tests/test_client_05_util.py @@ -141,7 +141,7 @@ def test_get_deserialization_method_json(): resp = FakeResponse("application/json") assert get_deserialization_method(resp) == "json" - resp = FakeResponse("application/json; charset=utf-8") + resp = FakeResponse("application/json") assert get_deserialization_method(resp) == "json" resp.headers["content-type"] = "application/jrd+json" diff --git a/tests/test_client_16_util.py b/tests/test_client_16_util.py index a09d65a5..0fdeede0 100644 --- a/tests/test_client_16_util.py +++ b/tests/test_client_16_util.py @@ -147,7 +147,7 @@ def test_get_deserialization_method_json(): resp = FakeResponse("application/json") assert get_deserialization_method(resp) == "json" - resp = FakeResponse("application/json; charset=utf-8") + resp = FakeResponse("application/json") assert get_deserialization_method(resp) == "json" resp.headers["content-type"] = "application/jrd+json" diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index f0d83006..6f242dff 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -905,16 +905,12 @@ def test_construct(self): assert isinstance(_req, RegistrationRequest) assert set(_req.keys()) == { "application_type", - 'callback_uris', "default_max_age", - 'encrypt_request_object_supported', - 'encrypt_userinfo_supported', "grant_types", "id_token_signed_response_alg", "jwks", "redirect_uris", "request_object_signing_alg", - 'requests_dir', "response_modes", "response_types", "subject_type", @@ -932,17 +928,13 @@ def test_config_with_post_logout(self): assert isinstance(_req, RegistrationRequest) assert set(_req.keys()) == { "application_type", - 'callback_uris', "default_max_age", - 'encrypt_request_object_supported', - 'encrypt_userinfo_supported', "grant_types", "id_token_signed_response_alg", "jwks", "post_logout_redirect_uri", "redirect_uris", "request_object_signing_alg", - 'requests_dir', "response_modes", "response_types", "subject_type", @@ -979,20 +971,13 @@ def test_config_with_required_request_uri(): _req = reg_service.construct() assert isinstance(_req, RegistrationRequest) assert set(_req.keys()) == {'application_type', - 'callback_uris', - 'client_id', - 'client_secret', 'default_max_age', - 'encrypt_request_object_supported', - 'encrypt_userinfo_supported', 'grant_types', 'id_token_signed_response_alg', 'jwks', 'redirect_uris', 'request_object_signing_alg', - 'request_parameter', 'request_uris', - 'requests_dir', 'response_modes', 'response_types', 'subject_type', @@ -1039,20 +1024,13 @@ def test_config_logout_uri(): _req = reg_service.construct() assert isinstance(_req, RegistrationRequest) assert set(_req.keys()) == {'application_type', - 'callback_uris', - 'client_id', - 'client_secret', 'default_max_age', - 'encrypt_request_object_supported', - 'encrypt_userinfo_supported', 'grant_types', 'id_token_signed_response_alg', 'jwks', 'redirect_uris', 'request_object_signing_alg', - 'request_parameter', 'request_uris', - 'requests_dir', 'response_modes', 'response_types', 'subject_type', diff --git a/tests/test_client_24_oic_utils.py b/tests/test_client_24_oic_utils.py index 4e799803..d6c42425 100644 --- a/tests/test_client_24_oic_utils.py +++ b/tests/test_client_24_oic_utils.py @@ -1,9 +1,9 @@ from cryptojwt.jwe.jwe import factory from cryptojwt.key_jar import build_keyjar -from idpyoidc.client.oidc.utils import construct_request_uri -from idpyoidc.client.oidc.utils import request_object_encryption +from idpyoidc.client.request_object import request_object_encryption from idpyoidc.client.service_context import ServiceContext +from idpyoidc.client.util import construct_request_uri from idpyoidc.message.oidc import AuthorizationRequest KEYSPEC = [ diff --git a/tests/test_client_27_conversation.py b/tests/test_client_27_conversation.py index 99264876..effe9553 100644 --- a/tests/test_client_27_conversation.py +++ b/tests/test_client_27_conversation.py @@ -401,11 +401,8 @@ def test_conversation(): "application_type", "backchannel_logout_session_required", "backchannel_logout_uri", - 'callback_uris', "contacts", "default_max_age", - 'encrypt_request_object_supported', - 'encrypt_userinfo_supported', "grant_types", "id_token_signed_response_alg", "jwks", diff --git a/tests/test_client_41_rp_handler_persistent.py b/tests/test_client_41_rp_handler_persistent.py index 7cfb38e3..0a6eaf1a 100644 --- a/tests/test_client_41_rp_handler_persistent.py +++ b/tests/test_client_41_rp_handler_persistent.py @@ -292,7 +292,7 @@ def test_begin(self): assert query["client_id"] == ["eeeeeeeee"] assert query["redirect_uri"] == ["https://example.com/rp/authz_cb/github"] assert query["response_type"] == ["code"] - assert query["scope"] == ["user public_repo openid"] + assert query["scope"] == ["openid"] def test_get_session_information(self): rph_1 = RPHandler( diff --git a/tests/test_server_16_endpoint.py b/tests/test_server_16_endpoint.py index 5a3b59de..7c023993 100755 --- a/tests/test_server_16_endpoint.py +++ b/tests/test_server_16_endpoint.py @@ -209,7 +209,7 @@ def test_do_response_response_msg_1(self): def test_do_response_placement_body(self): self.endpoint.response_placement = "body" info = self.endpoint.do_response(EXAMPLE_MSG) - assert ("Content-type", "application/json; charset=utf-8") in info["http_headers"] + assert ("Content-type", "application/json") in info["http_headers"] assert ( info["response"] == '{"name": "Doe, Jane", "given_name": "Jane", "family_name": ' '"Doe"}' @@ -217,6 +217,7 @@ def test_do_response_placement_body(self): def test_do_response_placement_url(self): self.endpoint.response_placement = "url" + self.endpoint.response_format = "urlencoded" info = self.endpoint.do_response(EXAMPLE_MSG, return_uri="https://example.org/cb") assert ("Content-type", "application/x-www-form-urlencoded") in info["http_headers"] assert ( diff --git a/tests/test_server_16_endpoint_context.py b/tests/test_server_16_endpoint_context.py index f96b676c..4cb390a3 100644 --- a/tests/test_server_16_endpoint_context.py +++ b/tests/test_server_16_endpoint_context.py @@ -5,17 +5,13 @@ from cryptojwt.key_jar import build_keyjar from idpyoidc import alg_info -from idpyoidc import metadata from idpyoidc.server import OPConfiguration from idpyoidc.server import Server from idpyoidc.server.endpoint import Endpoint -from idpyoidc.server.exception import OidcEndpointError from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD -from idpyoidc.server.util import allow_refresh_token - from . import CRYPT_CONFIG -from . import SESSION_PARAMS from . import full_path +from . import SESSION_PARAMS KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, @@ -83,6 +79,7 @@ class Endpoint_1(Endpoint): class TestEndpointContext: + @pytest.fixture(autouse=True) def create_endpoint_context(self): server = Server(conf) diff --git a/tests/test_server_22_oidc_provider_config_endpoint.py b/tests/test_server_22_oidc_provider_config_endpoint.py index 7000d724..723a4a49 100755 --- a/tests/test_server_22_oidc_provider_config_endpoint.py +++ b/tests/test_server_22_oidc_provider_config_endpoint.py @@ -92,7 +92,7 @@ def test_do_response(self): assert _msg["token_endpoint"] == "https://example.com/token" assert _msg["jwks_uri"] == "https://example.com/static/jwks.json" assert "claims_supported" not in _msg # No default for this - assert ("Content-type", "application/json; charset=utf-8") in msg["http_headers"] + assert ("Content-type", "application/json") in msg["http_headers"] def test_scopes_supported(self, conf): scopes_supported = ["openid", "random", "profile"] diff --git a/tests/test_server_31_oauth2_introspection.py b/tests/test_server_31_oauth2_introspection.py index bf283476..1ad30e1c 100644 --- a/tests/test_server_31_oauth2_introspection.py +++ b/tests/test_server_31_oauth2_introspection.py @@ -325,7 +325,7 @@ def test_do_response(self): assert isinstance(msg_info, dict) assert set(msg_info.keys()) == {"response", "http_headers"} assert msg_info["http_headers"] == [ - ("Content-type", "application/json; charset=utf-8"), + ("Content-type", "application/json"), ("Pragma", "no-cache"), ("Cache-Control", "no-store"), ] diff --git a/tests/test_server_32_oidc_read_registration.py b/tests/test_server_32_oidc_read_registration.py index e09bc5cd..2dea6413 100644 --- a/tests/test_server_32_oidc_read_registration.py +++ b/tests/test_server_32_oidc_read_registration.py @@ -160,4 +160,4 @@ def test_do_response(self): _endp_response = self.registration_api_endpoint.do_response(_info) assert set(_endp_response.keys()) == {"response", "http_headers"} - assert ("Content-type", "application/json; charset=utf-8") in _endp_response["http_headers"] + assert ("Content-type", "application/json") in _endp_response["http_headers"] diff --git a/tests/test_server_40_oauth2_pushed_authorization.py b/tests/test_server_40_oauth2_pushed_authorization.py index 4d7ea6da..7cf6bb38 100644 --- a/tests/test_server_40_oauth2_pushed_authorization.py +++ b/tests/test_server_40_oauth2_pushed_authorization.py @@ -251,7 +251,7 @@ def test_pushed_auth_urlencoded_process(self): # And now for the authorization request with the OP provided request_uri - _msg["request_uri"] = _resp["http_response"]["request_uri"] + _msg["request_uri"] = _resp["response_args"]["request_uri"] for parameter in ["code_challenge", "code_challenge_method"]: del _msg[parameter] diff --git a/tests/test_server_60_dpop.py b/tests/test_server_60_dpop.py index e13b8a35..44a88588 100644 --- a/tests/test_server_60_dpop.py +++ b/tests/test_server_60_dpop.py @@ -115,7 +115,10 @@ def create_endpoint(self): "add_on": { "dpop": { "function": "idpyoidc.server.oauth2.add_on.dpop.add_support", - "kwargs": {"dpop_signing_alg_values_supported": ["ES256"]}, + "kwargs": { + "dpop_signing_alg_values_supported": ["ES256"], + "dpop_endpoints": ["token"] + }, }, }, "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, diff --git a/tests/test_tandem_oauth2_add_on.py b/tests/test_tandem_oauth2_add_on.py index a3776fc8..58abb2d2 100644 --- a/tests/test_tandem_oauth2_add_on.py +++ b/tests/test_tandem_oauth2_add_on.py @@ -3,6 +3,7 @@ from typing import List from cryptojwt.key_jar import build_keyjar +from idpyoidc.key_import import store_under_other_id from idpyoidc.client.oauth2 import Client from idpyoidc.message.oauth2 import is_error_message @@ -324,10 +325,13 @@ def test_jar(): }, } + _keyjar = build_keyjar(KEYDEFS) + _keyjar = store_under_other_id(_keyjar, "", client_config["client_id"], True) + client = Client( client_type="oauth2", config=client_config, - keyjar=build_keyjar(KEYDEFS), + keyjar=_keyjar, services=_OAUTH2_SERVICES, ) diff --git a/tests/test_tandem_oauth2_par_service.py b/tests/test_tandem_oauth2_par_service.py new file mode 100644 index 00000000..9630a55a --- /dev/null +++ b/tests/test_tandem_oauth2_par_service.py @@ -0,0 +1,285 @@ +import json +import os + +import pytest +from cryptojwt.key_jar import build_keyjar + +from idpyoidc.client.oauth2 import Client +from idpyoidc.key_import import import_jwks +from idpyoidc.message.oauth2 import is_error_message +from idpyoidc.message.oidc import AccessTokenRequest +from idpyoidc.message.oidc import AuthorizationRequest +from idpyoidc.message.oidc import RefreshAccessTokenRequest +from idpyoidc.server import Server +from idpyoidc.server.authz import AuthzHandling +from idpyoidc.server.client_authn import verify_client +from idpyoidc.server.configure import ASConfiguration +from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +from idpyoidc.server.user_info import UserInfo +from idpyoidc.util import rndstr +from tests import CRYPT_CONFIG +from tests import SESSION_PARAMS + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +CLIENT_KEYJAR = build_keyjar(KEYDEFS) + +COOKIE_KEYDEFS = [ + {"type": "oct", "kid": "sig", "use": ["sig"]}, + {"type": "oct", "kid": "enc", "use": ["enc"]}, +] + +AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid"], + state="STATE", + response_type="code", +) + +TOKEN_REQ = AccessTokenRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", +) + +REFRESH_TOKEN_REQ = RefreshAccessTokenRequest( + grant_type="refresh_token", client_id="https://example.com/", client_secret="hemligt" +) + +TOKEN_REQ_DICT = TOKEN_REQ.to_dict() + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +USERINFO = UserInfo(json.loads(open(full_path("users.json")).read())) + +_OAUTH2_SERVICES = { + "metadata": {"class": "idpyoidc.client.oauth2.server_metadata.ServerMetadata"}, + "authorization": {"class": "idpyoidc.client.oauth2.authorization.Authorization"}, + "pushed_authorization": {"class": "idpyoidc.client.oauth2.pushed_authorization.PushedAuthorization"}, + "access_token": {"class": "idpyoidc.client.oauth2.access_token.AccessToken"}, + "resource": {"class": "idpyoidc.client.oauth2.resource.Resource"}, +} + + +class TestFlow(object): + + @pytest.fixture(autouse=True) + def create_entities(self): + server_conf = { + "issuer": "https://example.com/", + "httpc_params": {"verify": False, "timeout": 1}, + "subject_types_supported": ["public", "pairwise", "ephemeral"], + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "endpoint": { + "metadata": { + "path": ".well-known/oauth-authorization-server", + "class": "idpyoidc.server.oauth2.server_metadata.ServerMetadata", + "kwargs": {}, + }, + "authorization": { + "path": "authorization", + "class": "idpyoidc.server.oauth2.authorization.Authorization", + "kwargs": {}, + }, + "pushed_authorization": { + "path": "par", + "class": "idpyoidc.server.oauth2.pushed_authorization.PushedAuthorization", + "kwargs": {}, + }, + "token": { + "path": "token", + "class": "idpyoidc.server.oauth2.token.Token", + "kwargs": {}, + }, + }, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "idpyoidc.server.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "userinfo": {"class": UserInfo, "kwargs": {"db": {}}}, + "client_authn": verify_client, + "authz": { + "class": AuthzHandling, + "kwargs": { + "grant_config": { + "usage_rules": { + "authorization_code": { + "supports_minting": ["access_token", "refresh_token"], + "max_usage": 1, + }, + "access_token": { + "supports_minting": ["access_token", "refresh_token"], + "expires_in": 600, + }, + "refresh_token": { + "supports_minting": ["access_token"], + "audience": ["https://example.com", "https://example2.com"], + "expires_in": 43200, + }, + }, + "expires_in": 43200, + } + }, + }, + "token_handler_args": { + "jwks_file": "private/token_jwks.json", + "code": {"lifetime": 600, "kwargs": {"crypt_conf": CRYPT_CONFIG}}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + }, + }, + "refresh": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + }, + }, + }, + "session_params": SESSION_PARAMS, + } + self.server = Server(ASConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + + client_1_config = { + "issuer": server_conf["issuer"], + "client_secret": "hemligtlösenord", + "client_id": "client_1", + "redirect_uris": ["https://example.com/cb"], + "client_salt": "salted_peanuts_cooking", + "token_endpoint_auth_methods_supported": ["client_secret_post"], + "response_types_supported": ["code"], + } + client_services = _OAUTH2_SERVICES + self.client = Client( + client_type="oauth2", + config=client_1_config, + keyjar=build_keyjar(KEYDEFS), + services=_OAUTH2_SERVICES, + ) + + self.context = self.server.context + self.context.cdb["client_1"] = client_1_config + self.context.keyjar = import_jwks(self.context.keyjar, self.client.keyjar.export_jwks(), "client_1") + + self.context.set_provider_info() + self.session_manager = self.context.session_manager + self.user_id = "diana" + + def do_query(self, service_type, endpoint_type, request_args, state): + _client_service = self.client.get_service(service_type) + req_info = _client_service.get_request_parameters(request_args=request_args, state=state) + + areq = req_info.get("request") + headers = req_info.get("headers") + + _server_endpoint = self.server.get_endpoint(endpoint_type) + if areq: + if headers: + argv = {"http_info": {"headers": headers}} + else: + argv = {} + areq.lax = True + _req = areq.serialize(_server_endpoint.request_format) + _pr_resp = _server_endpoint.parse_request(_req, **argv) + else: + _pr_resp = _server_endpoint.parse_request(areq) + + if is_error_message(_pr_resp): + return areq, _pr_resp + + _resp = _server_endpoint.process_request(_pr_resp) + if is_error_message(_resp): + return areq, _resp + + _response = _server_endpoint.do_response(**_resp) + + resp = _client_service.parse_response(_response["response"]) + _client_service.update_service_context(_resp["response_args"], key=state) + return areq, resp + + def process_setup(self, token=None, scope=None): + # ***** Discovery ********* + + _req, _resp = self.do_query("server_metadata", "server_metadata", {}, "") + + # ***** Pushed Authorization Request ********** + _nonce = (rndstr(24),) + _context = self.client.get_service_context() + # Need a new state for a new authorization request + _state = _context.cstate.create_state(iss=_context.get("issuer")) + _context.cstate.bind_key(_nonce, _state) + + req_args = {"response_type": ["code"], "nonce": _nonce, "state": _state} + + if scope: + _scope = scope + else: + _scope = ["openid"] + + if token and list(token.keys())[0] == "refresh_token": + _scope = ["openid", "offline_access"] + + req_args["scope"] = _scope + + areq, auth_response = self.do_query("pushed_authorization", + "pushed_authorization", + req_args, + _state) + + # ***** Authorization Request ********** + _context = self.client.get_service_context() + + req_args = {"request_uri": auth_response["request_uri"], "response_type": ["code"]} + + areq, auth_response = self.do_query("authorization", "authorization", req_args, _state) + + # ***** Token Request ********** + + req_args = { + "code": auth_response["code"], + "state": auth_response["state"], + "redirect_uri": areq["redirect_uri"], + "grant_type": "authorization_code", + "client_id": self.client.get_client_id(), + "client_secret": _context.get_usage("client_secret"), + } + + _token_request, resp = self.do_query("accesstoken", "token", req_args, _state) + + return resp, _state, _scope + + def test_flow(self): + """ + Test that token exchange requests work correctly + """ + + resp, _state, _scope = self.process_setup(token="access_token", scope=["foobar"]) + + # Construct the resource request + + _client_service = self.client.get_service("resource") + req_info = _client_service.get_request_parameters( + authn_method="bearer_header", state=_state, endpoint="https://resource.example.com" + ) + + assert req_info["url"] == "https://resource.example.com" + assert "Authorization" in req_info["headers"] + assert req_info["headers"]["Authorization"].startswith("Bearer")