import io

from botocore.httpchecksum import (
    _CHECKSUM_CLS,
    AwsChunkedWrapper,
    FlexibleChecksumError,
    _apply_request_header_checksum,
    _handle_streaming_response,
    base64,
    conditionally_calculate_md5,
    determine_content_length,
    logger,
)

from aiobotocore._helpers import resolve_awaitable


class AioAwsChunkedWrapper(AwsChunkedWrapper):
    """Async variant of botocore's ``AwsChunkedWrapper``.

    Encodes the wrapped body as an ``aws-chunked`` stream, one chunk per
    ``_make_chunk`` call. When a checksum is configured, the final
    zero-length chunk carries a ``<name>:<base64 digest>`` trailer. The
    wrapped object's ``read`` may return either bytes or an awaitable; the
    result is resolved with ``resolve_awaitable``.

    NOTE(review): only ``_make_chunk`` and async iteration are overridden
    here. If the inherited ``AwsChunkedWrapper`` exposes a synchronous
    ``read``/``__iter__`` that calls ``_make_chunk``, those would now receive
    a coroutine instead of bytes; presumably callers consume this wrapper
    only via ``async for`` -- confirm.
    """

    async def _make_chunk(self):
        """Read the next raw chunk and return it aws-chunked encoded.

        A data chunk is encoded as ``<hex length> CRLF <data> CRLF``. An
        empty read marks the underlying body as exhausted
        (``self._complete``); with a checksum configured, that terminating
        chunk is ``0 CRLF <name>:<digest> CRLF CRLF``.
        """
        # NOTE: Chunk size is not deterministic as read could return less. This
        # means we cannot know the content length of the encoded aws-chunked
        # stream ahead of time without ensuring a consistent chunk size
        raw_chunk = await resolve_awaitable(self._raw.read(self._chunk_size))
        hex_len = hex(len(raw_chunk))[2:].encode("ascii")
        # An empty read signals the end of the underlying body.
        self._complete = not raw_chunk

        if self._checksum:
            self._checksum.update(raw_chunk)

        if self._checksum and self._complete:
            # Terminating zero-length chunk followed by the checksum trailer.
            name = self._checksum_name.encode("ascii")
            checksum = self._checksum.b64digest().encode("ascii")
            return b"0\r\n%s:%s\r\n\r\n" % (name, checksum)

        return b"%s\r\n%s\r\n" % (hex_len, raw_chunk)

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Each step yields exactly one encoded chunk. The chunk produced by
        # the final empty read (the terminator) is still returned; iteration
        # stops on the following call.
        if not self._complete:
            return await self._make_chunk()
        raise StopAsyncIteration()


async def handle_checksum_body(
    http_response, response, context, operation_model
):
    """Set up response checksum validation for the first usable algorithm.

    Walks ``context["checksum"]["response_algorithms"]`` in order and picks
    the first algorithm whose ``x-amz-checksum-<algorithm>`` header is
    present and contains no ``-``. For streaming output the body is wrapped
    via ``_handle_streaming_response``. Otherwise the full body is read and
    validated immediately by ``_handle_bytes_response``. The chosen
    algorithm is recorded in
    ``response["context"]["checksum"]["response_algorithm"]``.

    If no algorithm is usable, an info message is logged and ``response``
    is left unchanged.

    Raises:
        FlexibleChecksumError: propagated from ``_handle_bytes_response``
            when a non-streaming body fails validation.
    """
    headers = response["headers"]
    checksum_context = context.get("checksum", {})
    algorithms = checksum_context.get("response_algorithms")
    if not algorithms:
        return

    for algorithm in algorithms:
        header_name = "x-amz-checksum-%s" % algorithm
        # If the header is not found, check the next algorithm
        if header_name not in headers:
            continue

        # If a - is in the checksum this is not valid Base64. S3 returns
        # checksums that include a -# suffix to indicate a checksum derived
        # from the hash of all part checksums. We cannot wrap this response
        if "-" in headers[header_name]:
            continue

        if operation_model.has_streaming_output:
            # NOTE(review): botocore's _handle_streaming_response is called
            # synchronously here (not awaited); verify that the body wrapper
            # it returns can read from the async raw stream this client uses.
            response["body"] = _handle_streaming_response(
                http_response, response, algorithm
            )
        else:
            response["body"] = await _handle_bytes_response(
                http_response, response, algorithm
            )

        # Expose metadata that the checksum check actually occurred
        response_checksum_context = response["context"].get("checksum", {})
        response_checksum_context["response_algorithm"] = algorithm
        response["context"]["checksum"] = response_checksum_context
        return

    logger.info(
        'Skipping checksum validation. Response did not contain one of the '
        'following algorithms: %s.',
        algorithms,
    )


async def _handle_bytes_response(http_response, response, algorithm):
    """Read the full response body and validate it against its checksum.

    Returns:
        The body as awaited from ``http_response.content``.

    Raises:
        FlexibleChecksumError: if the computed digest differs from the
            base64-decoded ``x-amz-checksum-<algorithm>`` response header.
    """
    body = await http_response.content
    header_name = "x-amz-checksum-%s" % algorithm
    checksum_cls = _CHECKSUM_CLS.get(algorithm)
    checksum = checksum_cls()
    checksum.update(body)
    expected = response["headers"][header_name]
    if checksum.digest() != base64.b64decode(expected):
        error_msg = (
            "Expected checksum %s did not match calculated checksum: %s"
            % (
                expected,
                checksum.b64digest(),
            )
        )
        raise FlexibleChecksumError(error_msg=error_msg)
    return body


def apply_request_checksum(request):
    """Apply the request checksum selected in ``request["context"]``.

    ``"conditional-md5"`` delegates to ``conditionally_calculate_md5``.
    Otherwise the algorithm is a mapping whose ``"in"`` key selects header
    placement (``_apply_request_header_checksum``) or trailer placement
    (``_apply_request_trailer_checksum``). Does nothing when no request
    algorithm is set.

    Raises:
        FlexibleChecksumError: if ``algorithm["in"]`` is neither
            ``"header"`` nor ``"trailer"``.
    """
    checksum_context = request.get("context", {}).get("checksum", {})
    algorithm = checksum_context.get("request_algorithm")
    if not algorithm:
        return

    if algorithm == "conditional-md5":
        # Special case to handle the http checksum required trait
        conditionally_calculate_md5(request)
    elif algorithm["in"] == "header":
        _apply_request_header_checksum(request)
    elif algorithm["in"] == "trailer":
        _apply_request_trailer_checksum(request)
    else:
        raise FlexibleChecksumError(
            error_msg="Unknown checksum variant: %s" % algorithm["in"]
        )


def _apply_request_trailer_checksum(request):
    """Wrap the request body so its checksum is sent as an aws-chunked trailer.

    Mutates ``request`` in place:

    - sets ``Transfer-Encoding``;
    - appends ``aws-chunked`` to ``Content-Encoding``;
    - sets ``X-Amz-Trailer``;
    - sets ``X-Amz-Decoded-Content-Length`` when the length is known;
    - replaces ``request["body"]`` with an ``AioAwsChunkedWrapper``.

    Does nothing if the caller already supplied the checksum header.
    """
    checksum_context = request.get("context", {}).get("checksum", {})
    algorithm = checksum_context.get("request_algorithm")
    location_name = algorithm["name"]
    checksum_cls = _CHECKSUM_CLS.get(algorithm["algorithm"])

    headers = request["headers"]
    body = request["body"]

    if location_name in headers:
        # If the header is already set by the customer, skip calculation
        return

    # NOTE(review): a previous comment here claimed this header "cannot be
    # set as aiohttp complains", yet it is set. Presumably aiohttp only
    # objects when its own chunked=True option is also passed -- verify
    # against the aiohttp version in use.
    headers["Transfer-Encoding"] = "chunked"
    if "Content-Encoding" in headers:
        # We need to preserve the existing content encoding and add
        # aws-chunked as a new content encoding.
        headers["Content-Encoding"] += ",aws-chunked"
    else:
        headers["Content-Encoding"] = "aws-chunked"
    headers["X-Amz-Trailer"] = location_name

    content_length = determine_content_length(body)
    if content_length is not None:
        # Send the decoded content length if we can determine it. Some
        # services such as S3 may require the decoded content length
        headers["X-Amz-Decoded-Content-Length"] = str(content_length)

    # The chunked wrapper reads from a file-like object, so raw bytes are
    # wrapped in an in-memory stream first.
    if isinstance(body, (bytes, bytearray)):
        body = io.BytesIO(body)

    request["body"] = AioAwsChunkedWrapper(
        body,
        checksum_cls=checksum_cls,
        checksum_name=location_name,
    )