diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 3ccb781a0..000000000 Binary files a/.DS_Store and /dev/null differ diff --git a/google/auth/external_account.py b/google/auth/external_account.py index 6b1764a0a..0a6c3cab2 100644 --- a/google/auth/external_account.py +++ b/google/auth/external_account.py @@ -33,6 +33,8 @@ from google.auth import _helpers from google.auth import credentials +from google.auth import exceptions +from google.auth import impersonated_credentials from google.oauth2 import sts from google.oauth2 import utils @@ -58,6 +60,7 @@ def __init__( subject_token_type, token_url, credential_source, + service_account_impersonation_url=None, client_id=None, client_secret=None, quota_project_id=None, @@ -70,17 +73,23 @@ def __init__( subject_token_type (str): The subject token type. token_url (str): The STS endpoint URL. credential_source (Mapping): The credential source dictionary. + service_account_impersonation_url (Optional[str]): The optional service account + impersonation generateAccessToken URL. client_id (Optional[str]): The optional client ID. client_secret (Optional[str]): The optional client secret. quota_project_id (Optional[str]): The optional quota project ID. scopes (Optional[Sequence[str]]): Optional scopes to request during the authorization grant. + Raises: + google.auth.exceptions.RefreshError: If the generateAccessToken + endpoint returned an error. """ super(Credentials, self).__init__() self._audience = audience self._subject_token_type = subject_token_type self._token_url = token_url self._credential_source = credential_source + self._service_account_impersonation_url = service_account_impersonation_url self._client_id = client_id self._client_secret = client_secret self._quota_project_id = quota_project_id @@ -94,6 +103,11 @@ def __init__( self._client_auth = None self._sts_client = sts.Client(self._token_url, self._client_auth) + if self._service_account_impersonation_url: + self._impersonated_credentials = self._initialize_impersonated_credentials() + else: + self._impersonated_credentials = None + @property def requires_scopes(self): """Checks if the credentials requires scopes. @@ -132,20 +146,24 @@ def retrieve_subject_token(self, request): @_helpers.copy_docstring(credentials.Credentials) def refresh(self, request): - now = _helpers.utcnow() - response_data = self._sts_client.exchange_token( - request=request, - grant_type=_STS_GRANT_TYPE, - subject_token=self.retrieve_subject_token(request), - subject_token_type=self._subject_token_type, - audience=self._audience, - scopes=self._scopes, - requested_token_type=_STS_REQUESTED_TOKEN_TYPE, - ) - - self.token = response_data.get("access_token") - lifetime = datetime.timedelta(seconds=response_data.get("expires_in")) - self.expiry = now + lifetime + if self._impersonated_credentials: + self._impersonated_credentials.refresh(request) + self.token = self._impersonated_credentials.token + self.expiry = self._impersonated_credentials.expiry + else: + now = _helpers.utcnow() + response_data = self._sts_client.exchange_token( + request=request, + grant_type=_STS_GRANT_TYPE, + subject_token=self.retrieve_subject_token(request), + subject_token_type=self._subject_token_type, + audience=self._audience, + scopes=self._scopes, + requested_token_type=_STS_REQUESTED_TOKEN_TYPE, + ) + self.token = response_data.get("access_token") + lifetime = datetime.timedelta(seconds=response_data.get("expires_in")) + self.expiry = now + lifetime @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) def with_quota_project(self, quota_project_id): @@ -155,8 +173,59 @@ def with_quota_project(self, quota_project_id): subject_token_type=self._subject_token_type, token_url=self._token_url, credential_source=self._credential_source, + service_account_impersonation_url=self._service_account_impersonation_url, client_id=self._client_id, client_secret=self._client_secret, quota_project_id=quota_project_id, scopes=self._scopes, ) + + def _initialize_impersonated_credentials(self): + """Generates an impersonated credentials. + + For more details, see `projects.serviceAccounts.generateAccessToken`_. + + .. _projects.serviceAccounts.generateAccessToken: https://cloud.google.com/iam/docs/reference/credentials/rest/v1/projects.serviceAccounts/generateAccessToken + + Returns: + impersonated_credentials.Credential: The impersonated credentials + object. + + Raises: + google.auth.exceptions.RefreshError: If the generateAccessToken + endpoint returned an error. + """ + # Return copy of instance with no service account impersonation. + source_credentials = self.__class__( + audience=self._audience, + subject_token_type=self._subject_token_type, + token_url=self._token_url, + credential_source=self._credential_source, + service_account_impersonation_url=None, + client_id=self._client_id, + client_secret=self._client_secret, + quota_project_id=self._quota_project_id, + scopes=self._scopes, + ) + + # Determine target_principal. + start_index = self._service_account_impersonation_url.rfind("/") + end_index = self._service_account_impersonation_url.find(":generateAccessToken") + if start_index != -1 and end_index != -1 and start_index < end_index: + start_index = start_index + 1 + target_principal = self._service_account_impersonation_url[ + start_index:end_index + ] + else: + raise exceptions.RefreshError( + "Unable to determine target principal from service account impersonation URL." + ) + + # Initialize and return impersonated credentials. + return impersonated_credentials.Credentials( + source_credentials=source_credentials, + target_principal=target_principal, + target_scopes=self._scopes, + quota_project_id=self._quota_project_id, + iam_endpoint_override=self._service_account_impersonation_url, + ) diff --git a/google/auth/impersonated_credentials.py b/google/auth/impersonated_credentials.py index d2c5ded1c..e2d0b3a82 100644 --- a/google/auth/impersonated_credentials.py +++ b/google/auth/impersonated_credentials.py @@ -65,7 +65,9 @@ _DEFAULT_TOKEN_URI = "https://oauth2.googleapis.com/token" -def _make_iam_token_request(request, principal, headers, body): +def _make_iam_token_request( + request, principal, headers, body, iam_endpoint_override=None +): """Makes a request to the Google Cloud IAM service for an access token. Args: request (Request): The Request object to use. @@ -73,6 +75,9 @@ def _make_iam_token_request(request, principal, headers, body): headers (Mapping[str, str]): Map of headers to transmit. body (Mapping[str, str]): JSON Payload body for the iamcredentials API call. + iam_endpoint_override (Optiona[str]): The full IAM endpoint override + with the target_principal embedded. This is useful when supporting + impersonation with regional endpoints. Raises: google.auth.exceptions.TransportError: Raised if there is an underlying @@ -82,7 +87,7 @@ def _make_iam_token_request(request, principal, headers, body): `iamcredentials.googleapis.com` is not enabled or the `Service Account Token Creator` is not assigned """ - iam_endpoint = _IAM_ENDPOINT.format(principal) + iam_endpoint = iam_endpoint_override or _IAM_ENDPOINT.format(principal) body = json.dumps(body).encode("utf-8") @@ -185,6 +190,7 @@ def __init__( delegates=None, lifetime=_DEFAULT_TOKEN_LIFETIME_SECS, quota_project_id=None, + iam_endpoint_override=None, ): """ Args: @@ -209,6 +215,9 @@ def __init__( quota_project_id (Optional[str]): The project ID used for quota and billing. This project may be different from the project used to create the credentials. + iam_endpoint_override (Optiona[str]): The full IAM endpoint override + with the target_principal embedded. This is useful when supporting + impersonation with regional endpoints. """ super(Credentials, self).__init__() @@ -226,6 +235,7 @@ def __init__( self.token = None self.expiry = _helpers.utcnow() self._quota_project_id = quota_project_id + self._iam_endpoint_override = iam_endpoint_override @_helpers.copy_docstring(credentials.Credentials) def refresh(self, request): @@ -260,6 +270,7 @@ def _update_token(self, request): principal=self._target_principal, headers=headers, body=body, + iam_endpoint_override=self._iam_endpoint_override, ) def sign_bytes(self, message): @@ -302,6 +313,7 @@ def with_quota_project(self, quota_project_id): delegates=self._delegates, lifetime=self._lifetime, quota_project_id=quota_project_id, + iam_endpoint_override=self._iam_endpoint_override, ) diff --git a/tests/test_external_account.py b/tests/test_external_account.py index 0c8dcccc0..c33933979 100644 --- a/tests/test_external_account.py +++ b/tests/test_external_account.py @@ -30,6 +30,7 @@ CLIENT_SECRET = "password" # Base64 encoding of "username:password" BASIC_AUTH_ENCODING = "dXNlcm5hbWU6cGFzc3dvcmQ=" +SERVICE_ACCOUNT_EMAIL = "service-1234@service-name.iam.gserviceaccount.com" class CredentialsImpl(external_account.Credentials): @@ -39,6 +40,7 @@ def __init__( subject_token_type, token_url, credential_source, + service_account_impersonation_url=None, client_id=None, client_secret=None, quota_project_id=None, @@ -49,6 +51,7 @@ def __init__( subject_token_type, token_url, credential_source, + service_account_impersonation_url, client_id, client_secret, quota_project_id, @@ -87,15 +90,33 @@ class TestCredentials(object): "error_uri": "https://tools.ietf.org/html/rfc6749", } QUOTA_PROJECT_ID = "QUOTA_PROJECT_ID" + SERVICE_ACCOUNT_IMPERSONATION_URL = ( + "https://us-east1-iamcredentials.googleapis.com/v1/projects/-" + + "/serviceAccounts/{}:generateAccessToken".format(SERVICE_ACCOUNT_EMAIL) + ) + SCOPES = ["scope1", "scope2"] + IMPERSONATION_ERROR_RESPONSE = { + "error": { + "code": 400, + "message": "Request contains an invalid argument", + "status": "INVALID_ARGUMENT", + } + } @classmethod def make_credentials( - cls, client_id=None, client_secret=None, quota_project_id=None, scopes=None + cls, + client_id=None, + client_secret=None, + quota_project_id=None, + scopes=None, + service_account_impersonation_url=None, ): return CredentialsImpl( audience=cls.AUDIENCE, subject_token_type=cls.SUBJECT_TOKEN_TYPE, token_url=cls.TOKEN_URL, + service_account_impersonation_url=service_account_impersonation_url, credential_source=cls.CREDENTIAL_SOURCE, client_id=client_id, client_secret=client_secret, @@ -104,18 +125,35 @@ def make_credentials( ) @classmethod - def make_mock_request(cls, data, status=http_client.OK): - response = mock.create_autospec(transport.Response, instance=True) - response.status = status - response.data = json.dumps(data).encode("utf-8") + def make_mock_request( + cls, + data, + status=http_client.OK, + impersonation_data=None, + impersonation_status=None, + ): + # STS token exchange request. + token_response = mock.create_autospec(transport.Response, instance=True) + token_response.status = status + token_response.data = json.dumps(data).encode("utf-8") + responses = [token_response] + + # If service account impersonation is requested, mock the expected response. + if impersonation_status and impersonation_status: + impersonation_response = mock.create_autospec( + transport.Response, instance=True + ) + impersonation_response.status = impersonation_status + impersonation_response.data = json.dumps(impersonation_data).encode("utf-8") + responses.append(impersonation_response) request = mock.create_autospec(transport.Request) - request.return_value = response + request.side_effect = responses return request @classmethod - def assert_request_kwargs(cls, request_kwargs, headers, request_data): + def assert_token_request_kwargs(cls, request_kwargs, headers, request_data): assert request_kwargs["url"] == cls.TOKEN_URL assert request_kwargs["method"] == "POST" assert request_kwargs["headers"] == headers @@ -125,6 +163,15 @@ def assert_request_kwargs(cls, request_kwargs, headers, request_data): assert v.decode("utf-8") == request_data[k.decode("utf-8")] assert len(body_tuples) == len(request_data.keys()) + @classmethod + def assert_impersonation_request_kwargs(cls, request_kwargs, headers, request_data): + assert request_kwargs["url"] == cls.SERVICE_ACCOUNT_IMPERSONATION_URL + assert request_kwargs["method"] == "POST" + assert request_kwargs["headers"] == headers + assert request_kwargs["body"] is not None + body_json = json.loads(request_kwargs["body"].decode("utf-8")) + assert body_json == request_data + def test_default_state(self): credentials = self.make_credentials() @@ -160,6 +207,16 @@ def test_with_quota_project(self): assert quota_project_creds.quota_project_id == "project-foo" + def test_with_invalid_impersonation_target_principal(self): + invalid_url = "https://iamcredentials.googleapis.com/v1/invalid" + + with pytest.raises(exceptions.RefreshError) as excinfo: + self.make_credentials(service_account_impersonation_url=invalid_url) + + assert excinfo.match( + r"Unable to determine target principal from service account impersonation URL." + ) + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) def test_refresh_without_client_auth_success(self, unused_utcnow): response = self.SUCCESS_RESPONSE.copy() @@ -181,12 +238,78 @@ def test_refresh_without_client_auth_success(self, unused_utcnow): credentials.refresh(request) - self.assert_request_kwargs(request.call_args.kwargs, headers, request_data) + self.assert_token_request_kwargs( + request.call_args.kwargs, headers, request_data + ) assert credentials.valid assert credentials.expiry == expected_expiry assert not credentials.expired assert credentials.token == response["access_token"] + def test_refresh_impersonation_without_client_auth_success(self): + # Simulate service account access token expires in 2800 seconds. + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=2800) + ).isoformat("T") + "Z" + expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ") + # STS token exchange request/response. + token_response = self.SUCCESS_RESPONSE.copy() + token_headers = {"Content-Type": "application/x-www-form-urlencoded"} + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + "scope": "https://www.googleapis.com/auth/iam", + } + # Service account impersonation request/response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + impersonation_headers = { + "Content-Type": "application/json", + "authorization": "Bearer {}".format(token_response["access_token"]), + } + impersonation_request_data = { + "delegates": None, + "scope": self.SCOPES, + "lifetime": "3600s", + } + # Initialize mock request to handle token exchange and service account + # impersonation request. + request = self.make_mock_request( + status=http_client.OK, + data=token_response, + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + # Initialize credentials with service account impersonation. + credentials = self.make_credentials( + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=self.SCOPES, + ) + + credentials.refresh(request) + + # Only 2 requests should be processed. + assert len(request.call_args_list) == 2 + # Verify token exchange request parameters. + self.assert_token_request_kwargs( + request.call_args_list[0].kwargs, token_headers, token_request_data + ) + # Verify service account impersonation request parameters. + self.assert_impersonation_request_kwargs( + request.call_args_list[1].kwargs, + impersonation_headers, + impersonation_request_data, + ) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == impersonation_response["accessToken"] + def test_refresh_without_client_auth_success_explicit_scopes(self): headers = {"Content-Type": "application/x-www-form-urlencoded"} request_data = { @@ -204,7 +327,9 @@ def test_refresh_without_client_auth_success_explicit_scopes(self): credentials.refresh(request) - self.assert_request_kwargs(request.call_args.kwargs, headers, request_data) + self.assert_token_request_kwargs( + request.call_args.kwargs, headers, request_data + ) assert credentials.valid assert not credentials.expired assert credentials.token == self.SUCCESS_RESPONSE["access_token"] @@ -225,6 +350,25 @@ def test_refresh_without_client_auth_error(self): assert not credentials.expired assert credentials.token is None + def test_refresh_impersonation_without_client_auth_error(self): + request = self.make_mock_request( + status=http_client.OK, + data=self.SUCCESS_RESPONSE, + impersonation_status=http_client.BAD_REQUEST, + impersonation_data=self.IMPERSONATION_ERROR_RESPONSE, + ) + credentials = self.make_credentials( + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=self.SCOPES, + ) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(request) + + assert excinfo.match(r"Unable to acquire impersonated credentials") + assert not credentials.expired + assert credentials.token is None + def test_refresh_with_client_auth_success(self): headers = { "Content-Type": "application/x-www-form-urlencoded", @@ -246,11 +390,82 @@ def test_refresh_with_client_auth_success(self): credentials.refresh(request) - self.assert_request_kwargs(request.call_args.kwargs, headers, request_data) + self.assert_token_request_kwargs( + request.call_args.kwargs, headers, request_data + ) assert credentials.valid assert not credentials.expired assert credentials.token == self.SUCCESS_RESPONSE["access_token"] + def test_refresh_impersonation_with_client_auth_success(self): + # Simulate service account access token expires in 2800 seconds. + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=2800) + ).isoformat("T") + "Z" + expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ") + # STS token exchange request/response. + token_response = self.SUCCESS_RESPONSE.copy() + token_headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING), + } + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + "scope": "https://www.googleapis.com/auth/iam", + } + # Service account impersonation request/response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + impersonation_headers = { + "Content-Type": "application/json", + "authorization": "Bearer {}".format(token_response["access_token"]), + } + impersonation_request_data = { + "delegates": None, + "scope": self.SCOPES, + "lifetime": "3600s", + } + # Initialize mock request to handle token exchange and service account + # impersonation request. + request = self.make_mock_request( + status=http_client.OK, + data=token_response, + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + # Initialize credentials with service account impersonation and basic auth. + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=self.SCOPES, + ) + + credentials.refresh(request) + + # Only 2 requests should be processed. + assert len(request.call_args_list) == 2 + # Verify token exchange request parameters. + self.assert_token_request_kwargs( + request.call_args_list[0].kwargs, token_headers, token_request_data + ) + # Verify service account impersonation request parameters. + self.assert_impersonation_request_kwargs( + request.call_args_list[1].kwargs, + impersonation_headers, + impersonation_request_data, + ) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == impersonation_response["accessToken"] + def test_apply_without_quota_project_id(self): headers = {} request = self.make_mock_request( @@ -265,6 +480,37 @@ def test_apply_without_quota_project_id(self): "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]) } + def test_apply_impersonation_without_quota_project_id(self): + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) + ).isoformat("T") + "Z" + # Service account impersonation response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + # Initialize mock request to handle token exchange and service account + # impersonation request. + request = self.make_mock_request( + status=http_client.OK, + data=self.SUCCESS_RESPONSE.copy(), + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + # Initialize credentials with service account impersonation. + credentials = self.make_credentials( + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=self.SCOPES, + ) + headers = {} + + credentials.refresh(request) + credentials.apply(headers) + + assert headers == { + "authorization": "Bearer {}".format(impersonation_response["accessToken"]) + } + def test_apply_with_quota_project_id(self): headers = {"other": "header-value"} request = self.make_mock_request( @@ -281,6 +527,40 @@ def test_apply_with_quota_project_id(self): "x-goog-user-project": self.QUOTA_PROJECT_ID, } + def test_apply_impersonation_with_quota_project_id(self): + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) + ).isoformat("T") + "Z" + # Service account impersonation response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + # Initialize mock request to handle token exchange and service account + # impersonation request. + request = self.make_mock_request( + status=http_client.OK, + data=self.SUCCESS_RESPONSE.copy(), + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + # Initialize credentials with service account impersonation. + credentials = self.make_credentials( + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=self.SCOPES, + quota_project_id=self.QUOTA_PROJECT_ID, + ) + headers = {"other": "header-value"} + + credentials.refresh(request) + credentials.apply(headers) + + assert headers == { + "other": "header-value", + "authorization": "Bearer {}".format(impersonation_response["accessToken"]), + "x-goog-user-project": self.QUOTA_PROJECT_ID, + } + def test_before_request(self): headers = {"other": "header-value"} request = self.make_mock_request( @@ -304,6 +584,44 @@ def test_before_request(self): "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]), } + def test_before_request_impersonation(self): + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) + ).isoformat("T") + "Z" + # Service account impersonation response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + # Initialize mock request to handle token exchange and service account + # impersonation request. + request = self.make_mock_request( + status=http_client.OK, + data=self.SUCCESS_RESPONSE.copy(), + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + headers = {"other": "header-value"} + credentials = self.make_credentials( + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL + ) + + # First call should call refresh, setting the token. + credentials.before_request(request, "POST", "https://example.com/api", headers) + + assert headers == { + "other": "header-value", + "authorization": "Bearer {}".format(impersonation_response["accessToken"]), + } + + # Second call shouldn't call refresh. + credentials.before_request(request, "POST", "https://example.com/api", headers) + + assert headers == { + "other": "header-value", + "authorization": "Bearer {}".format(impersonation_response["accessToken"]), + } + @mock.patch("google.auth._helpers.utcnow") def test_before_request_expired(self, utcnow): headers = {} @@ -339,3 +657,55 @@ def test_before_request_expired(self, utcnow): assert headers == { "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]) } + + @mock.patch("google.auth._helpers.utcnow") + def test_before_request_impersonation_expired(self, utcnow): + headers = {} + expire_time = ( + datetime.datetime.min + datetime.timedelta(seconds=3601) + ).isoformat("T") + "Z" + # Service account impersonation response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + # Initialize mock request to handle token exchange and service account + # impersonation request. + request = self.make_mock_request( + status=http_client.OK, + data=self.SUCCESS_RESPONSE.copy(), + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + credentials = self.make_credentials( + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL + ) + credentials.token = "token" + utcnow.return_value = datetime.datetime.min + # Set the expiration to one second more than now plus the clock skew + # accomodation. These credentials should be valid. + credentials.expiry = ( + datetime.datetime.min + _helpers.CLOCK_SKEW + datetime.timedelta(seconds=1) + ) + + assert credentials.valid + assert not credentials.expired + + credentials.before_request(request, "POST", "https://example.com/api", headers) + + # Cached token should be used. + assert headers == {"authorization": "Bearer token"} + + # Next call should simulate 1 second passed. This will trigger the expiration + # threshold. + utcnow.return_value = datetime.datetime.min + datetime.timedelta(seconds=1) + + assert not credentials.valid + assert credentials.expired + + credentials.before_request(request, "POST", "https://example.com/api", headers) + + # New token should be retrieved. + assert headers == { + "authorization": "Bearer {}".format(impersonation_response["accessToken"]) + } diff --git a/tests/test_impersonated_credentials.py b/tests/test_impersonated_credentials.py index 46850a0d9..10b6c55c0 100644 --- a/tests/test_impersonated_credentials.py +++ b/tests/test_impersonated_credentials.py @@ -104,12 +104,17 @@ class TestImpersonatedCredentials(object): SIGNER, SERVICE_ACCOUNT_EMAIL, TOKEN_URI ) USER_SOURCE_CREDENTIALS = credentials.Credentials(token="ABCDE") + IAM_ENDPOINT_OVERRIDE = ( + "https://us-east1-iamcredentials.googleapis.com/v1/projects/-" + + "/serviceAccounts/{}:generateAccessToken".format(SERVICE_ACCOUNT_EMAIL) + ) def make_credentials( self, source_credentials=SOURCE_CREDENTIALS, lifetime=LIFETIME, target_principal=TARGET_PRINCIPAL, + iam_endpoint_override=None, ): return Credentials( @@ -118,6 +123,7 @@ def make_credentials( target_scopes=self.TARGET_SCOPES, delegates=self.DELEGATES, lifetime=lifetime, + iam_endpoint_override=iam_endpoint_override, ) def test_make_from_user_credentials(self): @@ -172,6 +178,34 @@ def test_refresh_success(self, use_data_bytes, mock_donor_credentials): assert credentials.valid assert not credentials.expired + @pytest.mark.parametrize("use_data_bytes", [True, False]) + def test_refresh_success_iam_endpoint_override( + self, use_data_bytes, mock_donor_credentials + ): + credentials = self.make_credentials( + lifetime=None, iam_endpoint_override=self.IAM_ENDPOINT_OVERRIDE + ) + token = "token" + + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500) + ).isoformat("T") + "Z" + response_body = {"accessToken": token, "expireTime": expire_time} + + request = self.make_request( + data=json.dumps(response_body), + status=http_client.OK, + use_data_bytes=use_data_bytes, + ) + + credentials.refresh(request) + + assert credentials.valid + assert not credentials.expired + # Confirm override endpoint used. + request_kwargs = request.call_args.kwargs + assert request_kwargs["url"] == self.IAM_ENDPOINT_OVERRIDE + @pytest.mark.parametrize("time_skew", [100, -100]) def test_refresh_source_credentials(self, time_skew): credentials = self.make_credentials(lifetime=None) @@ -317,6 +351,36 @@ def test_with_quota_project(self): quota_project_creds = credentials.with_quota_project("project-foo") assert quota_project_creds._quota_project_id == "project-foo" + @pytest.mark.parametrize("use_data_bytes", [True, False]) + def test_with_quota_project_iam_endpoint_override( + self, use_data_bytes, mock_donor_credentials + ): + credentials = self.make_credentials( + lifetime=None, iam_endpoint_override=self.IAM_ENDPOINT_OVERRIDE + ) + token = "token" + # iam_endpoint_override should be copied to created credentials. + quota_project_creds = credentials.with_quota_project("project-foo") + + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500) + ).isoformat("T") + "Z" + response_body = {"accessToken": token, "expireTime": expire_time} + + request = self.make_request( + data=json.dumps(response_body), + status=http_client.OK, + use_data_bytes=use_data_bytes, + ) + + quota_project_creds.refresh(request) + + assert quota_project_creds.valid + assert not quota_project_creds.expired + # Confirm override endpoint used. + request_kwargs = request.call_args.kwargs + assert request_kwargs["url"] == self.IAM_ENDPOINT_OVERRIDE + def test_id_token_success( self, mock_donor_credentials, mock_authorizedsession_idtoken ):