diff --git a/google/auth/_default.py b/google/auth/_default.py index 1234fb25d..cf0cdd772 100644 --- a/google/auth/_default.py +++ b/google/auth/_default.py @@ -484,42 +484,8 @@ def _get_impersonated_service_account_credentials(filename, info, scopes): from google.auth import impersonated_credentials try: - source_credentials_info = info.get("source_credentials") - source_credentials_type = source_credentials_info.get("type") - if source_credentials_type == _AUTHORIZED_USER_TYPE: - source_credentials, _ = _get_authorized_user_credentials( - filename, source_credentials_info - ) - elif source_credentials_type == _SERVICE_ACCOUNT_TYPE: - source_credentials, _ = _get_service_account_credentials( - filename, source_credentials_info - ) - elif source_credentials_type == _EXTERNAL_ACCOUNT_AUTHORIZED_USER_TYPE: - source_credentials, _ = _get_external_account_authorized_user_credentials( - filename, source_credentials_info - ) - else: - raise exceptions.InvalidType( - "source credential of type {} is not supported.".format( - source_credentials_type - ) - ) - impersonation_url = info.get("service_account_impersonation_url") - start_index = impersonation_url.rfind("/") - end_index = impersonation_url.find(":generateAccessToken") - if start_index == -1 or end_index == -1 or start_index > end_index: - raise exceptions.InvalidValue( - "Cannot extract target principal from {}".format(impersonation_url) - ) - target_principal = impersonation_url[start_index + 1 : end_index] - delegates = info.get("delegates") - quota_project_id = info.get("quota_project_id") - credentials = impersonated_credentials.Credentials( - source_credentials, - target_principal, - scopes, - delegates, - quota_project_id=quota_project_id, + credentials = impersonated_credentials.Credentials.from_impersonated_service_account_info( + info, scopes=scopes ) except ValueError as caught_exc: msg = "Failed to load impersonated service account credentials from {}".format( diff --git a/google/auth/impersonated_credentials.py b/google/auth/impersonated_credentials.py index ed7e3f00b..d49998cfb 100644 --- a/google/auth/impersonated_credentials.py +++ b/google/auth/impersonated_credentials.py @@ -47,6 +47,12 @@ _GOOGLE_OAUTH2_TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token" +_SOURCE_CREDENTIAL_AUTHORIZED_USER_TYPE = "authorized_user" +_SOURCE_CREDENTIAL_SERVICE_ACCOUNT_TYPE = "service_account" +_SOURCE_CREDENTIAL_EXTERNAL_ACCOUNT_AUTHORIZED_USER_TYPE = ( + "external_account_authorized_user" +) + def _make_iam_token_request( request, @@ -410,6 +416,75 @@ def with_scopes(self, scopes, default_scopes=None): cred._target_scopes = scopes or default_scopes return cred + @classmethod + def from_impersonated_service_account_info(cls, info, scopes=None): + """Creates a Credentials instance from parsed impersonated service account credentials info. + + Args: + info (Mapping[str, str]): The impersonated service account credentials info in Google + format. + scopes (Sequence[str]): Optional list of scopes to include in the + credentials. + + Returns: + google.oauth2.credentials.Credentials: The constructed + credentials. + + Raises: + InvalidType: If the info["source_credentials"] are not a supported impersonation type + InvalidValue: If the info["service_account_impersonation_url"] is not in the expected format. + ValueError: If the info is not in the expected format. + """ + + source_credentials_info = info.get("source_credentials") + source_credentials_type = source_credentials_info.get("type") + if source_credentials_type == _SOURCE_CREDENTIAL_AUTHORIZED_USER_TYPE: + from google.oauth2 import credentials + + source_credentials = credentials.Credentials.from_authorized_user_info( + source_credentials_info + ) + elif source_credentials_type == _SOURCE_CREDENTIAL_SERVICE_ACCOUNT_TYPE: + from google.oauth2 import service_account + + source_credentials = service_account.Credentials.from_service_account_info( + source_credentials_info + ) + elif ( + source_credentials_type + == _SOURCE_CREDENTIAL_EXTERNAL_ACCOUNT_AUTHORIZED_USER_TYPE + ): + from google.auth import external_account_authorized_user + + source_credentials = external_account_authorized_user.Credentials.from_info( + source_credentials_info + ) + else: + raise exceptions.InvalidType( + "source credential of type {} is not supported.".format( + source_credentials_type + ) + ) + + impersonation_url = info.get("service_account_impersonation_url") + start_index = impersonation_url.rfind("/") + end_index = impersonation_url.find(":generateAccessToken") + if start_index == -1 or end_index == -1 or start_index > end_index: + raise exceptions.InvalidValue( + "Cannot extract target principal from {}".format(impersonation_url) + ) + target_principal = impersonation_url[start_index + 1 : end_index] + delegates = info.get("delegates") + quota_project_id = info.get("quota_project_id") + + return cls( + source_credentials, + target_principal, + scopes, + delegates, + quota_project_id=quota_project_id, + ) + class IDTokenCredentials(credentials.CredentialsWithQuotaProject): """Open ID Connect ID Token-based service account credentials. diff --git a/google/oauth2/id_token.py b/google/oauth2/id_token.py index b68ab6b30..a6c51ce63 100644 --- a/google/oauth2/id_token.py +++ b/google/oauth2/id_token.py @@ -284,6 +284,18 @@ def fetch_id_token_credentials(audience, request=None): return service_account.IDTokenCredentials.from_service_account_info( info, target_audience=audience ) + elif info.get("type") == "impersonated_service_account": + from google.auth import impersonated_credentials + + target_credentials = impersonated_credentials.Credentials.from_impersonated_service_account_info( + info + ) + + return impersonated_credentials.IDTokenCredentials( + target_credentials=target_credentials, + target_audience=audience, + include_email=True, + ) except ValueError as caught_exc: new_exc = exceptions.DefaultCredentialsError( "GOOGLE_APPLICATION_CREDENTIALS is not valid service account credentials.", diff --git a/tests/oauth2/test_id_token.py b/tests/oauth2/test_id_token.py index 7d6a22481..ff3d4b6d8 100644 --- a/tests/oauth2/test_id_token.py +++ b/tests/oauth2/test_id_token.py @@ -20,6 +20,7 @@ from google.auth import environment_vars from google.auth import exceptions +from google.auth import impersonated_credentials from google.auth import transport from google.oauth2 import id_token from google.oauth2 import service_account @@ -27,6 +28,12 @@ SERVICE_ACCOUNT_FILE = os.path.join( os.path.dirname(__file__), "../data/service_account.json" ) + +IMPERSONATED_SERVICE_ACCOUNT_FILE = os.path.join( + os.path.dirname(__file__), + "../data/impersonated_service_account_authorized_user_source.json", +) + ID_TOKEN_AUDIENCE = "https://pubsub.googleapis.com" @@ -262,6 +269,14 @@ def test_fetch_id_token_credentials_from_explicit_cred_json_file(monkeypatch): assert cred._target_audience == ID_TOKEN_AUDIENCE +def test_fetch_id_token_credentials_from_impersonated_cred_json_file(monkeypatch): + monkeypatch.setenv(environment_vars.CREDENTIALS, IMPERSONATED_SERVICE_ACCOUNT_FILE) + + cred = id_token.fetch_id_token_credentials(ID_TOKEN_AUDIENCE) + assert isinstance(cred, impersonated_credentials.IDTokenCredentials) + assert cred._target_audience == ID_TOKEN_AUDIENCE + + def test_fetch_id_token_credentials_no_cred_exists(monkeypatch): monkeypatch.delenv(environment_vars.CREDENTIALS, raising=False) diff --git a/tests/test_impersonated_credentials.py b/tests/test_impersonated_credentials.py index 8f6b22670..4aa357e3e 100644 --- a/tests/test_impersonated_credentials.py +++ b/tests/test_impersonated_credentials.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import datetime import http.client as http_client import json @@ -35,6 +36,9 @@ PRIVATE_KEY_BYTES = fh.read() SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json") +IMPERSONATED_SERVICE_ACCOUNT_AUTHORIZED_USER_SOURCE_FILE = os.path.join( + DATA_DIR, "impersonated_service_account_authorized_user_source.json" +) ID_TOKEN_DATA = ( "eyJhbGciOiJSUzI1NiIsImtpZCI6ImRmMzc1ODkwOGI3OTIyOTNhZDk3N2Ew" @@ -49,6 +53,9 @@ with open(SERVICE_ACCOUNT_JSON_FILE, "rb") as fh: SERVICE_ACCOUNT_INFO = json.load(fh) +with open(IMPERSONATED_SERVICE_ACCOUNT_AUTHORIZED_USER_SOURCE_FILE, "rb") as fh: + IMPERSONATED_SERVICE_ACCOUNT_AUTHORIZED_USER_SOURCE_INFO = json.load(fh) + SIGNER = crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, "1") TOKEN_URI = "https://example.com/oauth2/token" @@ -148,6 +155,38 @@ def make_credentials( iam_endpoint_override=iam_endpoint_override, ) + def test_from_impersonated_service_account_info(self): + credentials = impersonated_credentials.Credentials.from_impersonated_service_account_info( + IMPERSONATED_SERVICE_ACCOUNT_AUTHORIZED_USER_SOURCE_INFO + ) + assert isinstance(credentials, impersonated_credentials.Credentials) + + def test_from_impersonated_service_account_info_with_invalid_source_credentials_type( + self + ): + info = copy.deepcopy(IMPERSONATED_SERVICE_ACCOUNT_AUTHORIZED_USER_SOURCE_INFO) + assert "source_credentials" in info + # Set the source_credentials to an invalid type + info["source_credentials"]["type"] = "invalid_type" + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + impersonated_credentials.Credentials.from_impersonated_service_account_info( + info + ) + assert excinfo.match( + "source credential of type {} is not supported".format("invalid_type") + ) + + def test_from_impersonated_service_account_info_with_invalid_impersonation_url( + self + ): + info = copy.deepcopy(IMPERSONATED_SERVICE_ACCOUNT_AUTHORIZED_USER_SOURCE_INFO) + info["service_account_impersonation_url"] = "invalid_url" + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + impersonated_credentials.Credentials.from_impersonated_service_account_info( + info + ) + assert excinfo.match(r"Cannot extract target principal from") + def test_get_cred_info(self): credentials = self.make_credentials() assert not credentials.get_cred_info()