import asyncio import logging from datetime import timedelta import dateutil.parser from botocore.compat import total_seconds from botocore.exceptions import ClientError, TokenRetrievalError from botocore.tokens import ( DeferredRefreshableToken, FrozenAuthToken, SSOTokenProvider, TokenProviderChain, _utc_now, ) logger = logging.getLogger(__name__) def create_token_resolver(session): providers = [ AioSSOTokenProvider(session), ] return TokenProviderChain(providers=providers) class AioDeferredRefreshableToken(DeferredRefreshableToken): def __init__( self, method, refresh_using, time_fetcher=_utc_now ): # noqa: E501, lgtm [py/missing-call-to-init] self._time_fetcher = time_fetcher self._refresh_using = refresh_using self.method = method # The frozen token is protected by this lock self._refresh_lock = asyncio.Lock() self._frozen_token = None self._next_refresh = None async def get_frozen_token(self): await self._refresh() return self._frozen_token async def _refresh(self): # If we don't need to refresh just return refresh_type = self._should_refresh() if not refresh_type: return None # Block for refresh if we're in the mandatory refresh window block_for_refresh = refresh_type == "mandatory" if block_for_refresh or not self._refresh_lock.locked(): async with self._refresh_lock: await self._protected_refresh() async def _protected_refresh(self): # This should only be called after acquiring the refresh lock # Another task may have already refreshed, double check refresh refresh_type = self._should_refresh() if not refresh_type: return None try: now = self._time_fetcher() self._next_refresh = now + timedelta(seconds=self._attempt_timeout) self._frozen_token = await self._refresh_using() except Exception: logger.warning( "Refreshing token failed during the %s refresh period.", refresh_type, exc_info=True, ) if refresh_type == "mandatory": # This refresh was mandatory, error must be propagated back raise if self._is_expired(): # Fresh credentials should never be expired raise TokenRetrievalError( provider=self.method, error_msg="Token has expired and refresh failed", ) class AioSSOTokenProvider(SSOTokenProvider): async def _attempt_create_token(self, token): async with self._client as client: response = await client.create_token( grantType=self._GRANT_TYPE, clientId=token["clientId"], clientSecret=token["clientSecret"], refreshToken=token["refreshToken"], ) expires_in = timedelta(seconds=response["expiresIn"]) new_token = { "startUrl": self._sso_config["sso_start_url"], "region": self._sso_config["sso_region"], "accessToken": response["accessToken"], "expiresAt": self._now() + expires_in, # Cache the registration alongside the token "clientId": token["clientId"], "clientSecret": token["clientSecret"], "registrationExpiresAt": token["registrationExpiresAt"], } if "refreshToken" in response: new_token["refreshToken"] = response["refreshToken"] logger.info("SSO Token refresh succeeded") return new_token async def _refresh_access_token(self, token): keys = ( "refreshToken", "clientId", "clientSecret", "registrationExpiresAt", ) missing_keys = [k for k in keys if k not in token] if missing_keys: msg = f"Unable to refresh SSO token: missing keys: {missing_keys}" logger.info(msg) return None expiry = dateutil.parser.parse(token["registrationExpiresAt"]) if total_seconds(expiry - self._now()) <= 0: logger.info(f"SSO token registration expired at {expiry}") return None try: return await self._attempt_create_token(token) except ClientError: logger.warning("SSO token refresh attempt failed", exc_info=True) return None async def _refresher(self): start_url = self._sso_config["sso_start_url"] session_name = self._sso_config["session_name"] logger.info(f"Loading cached SSO token for {session_name}") token_dict = self._token_loader(start_url, session_name=session_name) expiration = dateutil.parser.parse(token_dict["expiresAt"]) logger.debug(f"Cached SSO token expires at {expiration}") remaining = total_seconds(expiration - self._now()) if remaining < self._REFRESH_WINDOW: new_token_dict = await self._refresh_access_token(token_dict) if new_token_dict is not None: token_dict = new_token_dict expiration = token_dict["expiresAt"] self._token_loader.save_token( start_url, token_dict, session_name=session_name ) return FrozenAuthToken( token_dict["accessToken"], expiration=expiration ) def load_token(self): if self._sso_config is None: return None return AioDeferredRefreshableToken( self.METHOD, self._refresher, time_fetcher=self._now )