| # Copyright 2016 Google Inc. All rights reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| import logging |
| |
| import httplib2 |
| import six |
| from six.moves import http_client |
| |
| from oauth2client._helpers import _to_bytes |
| |
| |
# Module-level logger used by the request wrappers below.
_LOGGER = logging.getLogger(__name__)
# Attributes an object must expose to be treated as a file-like
# stream / buffer. The auth wrapper checks for these on the request body so
# it can record the stream position and seek back to it before a retry.
_STREAM_PROPERTIES = ('read', 'seek', 'tell')

# Response status codes that trigger a credentials refresh followed by a
# retry of the request.
# Google Data client libraries may need to set this to [401, 403].
REFRESH_STATUS_CODES = (http_client.UNAUTHORIZED,)
| |
| |
class MemoryCache(object):
    """In-process cache exposing the get / set / delete cache interface.

    Entries are kept in a plain dictionary on the instance, so nothing
    outlives the object itself.
    """

    def __init__(self):
        self.cache = dict()

    def get(self, key):
        """Return the value stored under ``key``, or None when absent."""
        store = self.cache
        return store[key] if key in store else None

    def set(self, key, value):
        """Store ``value`` under ``key``, replacing any previous entry."""
        store = self.cache
        store[key] = value

    def delete(self, key):
        """Drop ``key`` from the cache; a missing key is silently ignored."""
        try:
            del self.cache[key]
        except KeyError:
            pass
| |
| |
def get_cached_http():
    """Return the shared, caching HTTP object.

    Every call hands back the same module-level instance. It is intended
    for callers such as oauth2client.client.verify_id_token(), which
    repeatedly fetch certs from the same URI.

    Returns:
        httplib2.Http, an HTTP object with a MemoryCache
    """
    shared_http = _CACHED_HTTP
    return shared_http
| |
| |
def get_http_object():
    """Build a fresh HTTP object.

    Returns:
        httplib2.Http, a newly constructed HTTP object created with
        default arguments.
    """
    new_http = httplib2.Http()
    return new_http
| |
| |
| def _initialize_headers(headers): |
| """Creates a copy of the headers. |
| |
| Args: |
| headers: dict, request headers to copy. |
| |
| Returns: |
| dict, the copied headers or a new dictionary if the headers |
| were None. |
| """ |
| return {} if headers is None else dict(headers) |
| |
| |
| def _apply_user_agent(headers, user_agent): |
| """Adds a user-agent to the headers. |
| |
| Args: |
| headers: dict, request headers to add / modify user |
| agent within. |
| user_agent: str, the user agent to add. |
| |
| Returns: |
| dict, the original headers passed in, but modified if the |
| user agent is not None. |
| """ |
| if user_agent is not None: |
| if 'user-agent' in headers: |
| headers['user-agent'] = (user_agent + ' ' + headers['user-agent']) |
| else: |
| headers['user-agent'] = user_agent |
| |
| return headers |
| |
| |
def clean_headers(headers):
    """Forces header keys and values to be strings, i.e. not unicode.

    The httplib module just concatenates the header keys and values in a
    way that may make the message header a unicode string, which, if it is
    then concatenated to a binary request body, may result in a unicode
    decode error.

    Args:
        headers: dict, A dictionary of headers.

    Returns:
        dict, a new dictionary holding the same headers with every key
        and value passed through ``_to_bytes``.

    Raises:
        NonAsciiHeaderError: if a UnicodeEncodeError occurs while
            converting a key or a value.
    """
    sanitized = {}
    try:
        for name, value in six.iteritems(headers):
            # Anything that is not already a byte string is first coerced
            # with str() so that non-string objects can be handled too.
            name = name if isinstance(name, six.binary_type) else str(name)
            value = (value if isinstance(value, six.binary_type)
                     else str(value))
            # Convert the value before the key, matching the evaluation
            # order of a subscript assignment.
            encoded_value = _to_bytes(value)
            sanitized[_to_bytes(name)] = encoded_value
    except UnicodeEncodeError:
        # Imported lazily, at the point of failure, rather than at module
        # import time.
        from oauth2client.client import NonAsciiHeaderError
        raise NonAsciiHeaderError(name, ': ', value)
    return sanitized
| |
| |
def wrap_http_for_auth(credentials, http):
    """Prepares an HTTP object's request method for auth.

    Replaces ``http.request`` with a closure that adds the credentials
    to the outgoing headers and watches the response status. When the
    status is one of REFRESH_STATUS_CODES (typically a 401), the token is
    refreshed and the original request is sent again, at most twice.

    Args:
        credentials: Credentials, the credentials used to identify
                     the authenticated user.
        http: httplib2.Http, an http object to be used to make
              auth requests.
    """
    unwrapped_request = http.request

    # The closure that will replace 'httplib2.Http.request'.
    def new_request(uri, method='GET', body=None, headers=None,
                    redirections=httplib2.DEFAULT_MAX_REDIRECTS,
                    connection_type=None):
        if not credentials.access_token:
            _LOGGER.info('Attempting refresh to obtain initial access_token')
            credentials._refresh(unwrapped_request)

        # Work on a copy of the caller's headers and add the
        # Authorization header plus the user agent to it.
        headers = _initialize_headers(headers)
        credentials.apply(headers)
        _apply_user_agent(headers, credentials.user_agent)

        # For a file-like body, remember where it currently stands so it
        # can be rewound before the request is re-sent.
        body_is_stream = all(getattr(body, attribute, None)
                             for attribute in _STREAM_PROPERTIES)
        rewind_position = body.tell() if body_is_stream else None

        def send():
            # Headers are cleaned on every send because a refresh
            # re-applies the credentials to them.
            return unwrapped_request(uri, method, body,
                                     clean_headers(headers),
                                     redirections, connection_type)

        resp, content = send()

        # A stored token may expire between the time it is retrieved and
        # the time the request is made, so we may need to try twice.
        max_refresh_attempts = 2
        attempts_made = 0
        while (attempts_made < max_refresh_attempts and
                resp.status in REFRESH_STATUS_CODES):
            attempts_made += 1
            _LOGGER.info('Refreshing due to a %s (attempt %s/%s)',
                         resp.status, attempts_made,
                         max_refresh_attempts)
            credentials._refresh(unwrapped_request)
            credentials.apply(headers)
            if rewind_position is not None:
                body.seek(rewind_position)

            resp, content = send()

        return resp, content

    # Replace the request method with our own closure.
    http.request = new_request

    # Set credentials as a property of the request method.
    http.request.credentials = credentials
| |
| |
def wrap_http_for_jwt_access(credentials, http):
    """Prepares an HTTP object's request method for JWT access.

    Replaces ``http.request`` with a closure that behaves in one of two
    ways, depending on whether the credentials carry an 'aud' (audience)
    claim in ``credentials._kwargs``:

    * With an 'aud' claim, the token is refreshed preemptively when it is
      missing or expired and the request goes through the method installed
      by ``wrap_http_for_auth`` (which also retries after auth failures).
    * Without an 'aud' claim, a one-time token is created whose audience
      is the request URI stripped of its query string, and the request is
      sent with the original, unwrapped request method.

    As with ``wrap_http_for_auth``, the credentials are exposed as the
    ``credentials`` attribute of the installed request method.

    Args:
        credentials: _JWTAccessCredentials, the credentials used to identify
                     a service account that uses JWT access tokens.
        http: httplib2.Http, an http object to be used to make
              auth requests.
    """
    orig_request_method = http.request
    wrap_http_for_auth(credentials, http)
    # The new value of ``http.request`` set by ``wrap_http_for_auth``.
    authenticated_request_method = http.request

    # The closure that will replace 'httplib2.Http.request'.
    def new_request(uri, method='GET', body=None, headers=None,
                    redirections=httplib2.DEFAULT_MAX_REDIRECTS,
                    connection_type=None):
        if 'aud' in credentials._kwargs:
            # Preemptively refresh token, this is not done for OAuth2
            if (credentials.access_token is None or
                    credentials.access_token_expired):
                credentials.refresh(None)
            return authenticated_request_method(uri, method, body,
                                                headers, redirections,
                                                connection_type)
        else:
            # If we don't have an 'aud' (audience) claim,
            # create a 1-time token with the uri root as the audience
            headers = _initialize_headers(headers)
            _apply_user_agent(headers, credentials.user_agent)
            uri_root = uri.split('?', 1)[0]
            token, unused_expiry = credentials._create_token({'aud': uri_root})

            headers['Authorization'] = 'Bearer ' + token
            return orig_request_method(uri, method, body,
                                       clean_headers(headers),
                                       redirections, connection_type)

    # Replace the request method with our own closure.
    http.request = new_request

    # Set credentials as a property of the request method. The closure
    # installed above replaced the one from ``wrap_http_for_auth``, so the
    # attribute that function attached has to be attached again here.
    http.request.credentials = credentials
| |
| |
# Single module-level HTTP object backed by an in-memory MemoryCache; this is
# the instance handed out by get_cached_http().
_CACHED_HTTP = httplib2.Http(MemoryCache())