From ec9f5d89971b823fa01071c96b7de90ea5e5303a Mon Sep 17 00:00:00 2001 From: Wilton_ Date: Mon, 25 Oct 2021 20:35:50 -0700 Subject: [PATCH] feat: Accelerate (#3386) Co-authored-by: Jacob Fuss <32497805+jfuss@users.noreply.github.com> Co-authored-by: Mehmet Nuri Deveci <5735811+mndeveci@users.noreply.github.com> Co-authored-by: mingkun2020 <68391979+mingkun2020@users.noreply.github.com> Co-authored-by: Qingchuan Ma <69653965+qingchm@users.noreply.github.com> Co-authored-by: aws-sam-cli-bot <46753707+aws-sam-cli-bot@users.noreply.github.com> --- .gitignore | 5 + appveyor.yml | 2 +- requirements/base.txt | 9 +- requirements/reproducible-linux.txt | 17 +- samcli/cli/command.py | 2 + samcli/cli/context.py | 9 + samcli/cli/global_config.py | 444 +++++++++---- samcli/cli/main.py | 8 +- samcli/commands/_utils/experimental.py | 229 +++++++ samcli/commands/_utils/options.py | 379 ++++++++++- samcli/commands/_utils/template.py | 8 +- samcli/commands/build/build_context.py | 218 ++++++- samcli/commands/build/command.py | 168 +---- samcli/commands/delete/delete_context.py | 2 +- samcli/commands/deploy/command.py | 123 +--- samcli/commands/deploy/deploy_context.py | 82 ++- samcli/commands/deploy/utils.py | 5 +- samcli/commands/init/init_templates.py | 4 +- samcli/commands/logs/command.py | 140 ++-- samcli/commands/logs/console_consumers.py | 14 +- samcli/commands/logs/logs_context.py | 320 +++------ samcli/commands/logs/puller_factory.py | 178 +++++ samcli/commands/package/command.py | 94 +-- samcli/commands/package/package_context.py | 2 +- .../pipeline/init/interactive_init_flow.py | 4 +- samcli/commands/sync/__init__.py | 4 + samcli/commands/sync/command.py | 415 ++++++++++++ samcli/commands/traces/__init__.py | 6 + samcli/commands/traces/command.py | 78 +++ .../traces/trace_console_consumers.py | 18 + .../commands/traces/traces_puller_factory.py | 112 ++++ .../validate/lib/sam_template_validator.py | 2 +- .../companion_stack_builder.py | 41 +- samcli/lib/bootstrap/nested_stack/__init__.py | 0 .../nested_stack/nested_stack_builder.py | 78 +++ .../nested_stack/nested_stack_manager.py | 194 ++++++ samcli/lib/bootstrap/stack_builder.py | 58 ++ samcli/lib/build/app_builder.py | 101 ++- samcli/lib/build/build_graph.py | 192 +++++- samcli/lib/build/build_strategy.py | 261 +++++++- samcli/lib/build/dependency_hash_generator.py | 86 +++ samcli/lib/build/exceptions.py | 5 + samcli/lib/deploy/deployer.py | 126 +++- samcli/lib/iac/cfn/cfn_iac.py | 14 +- .../cw_logs/cw_log_formatters.py | 29 + .../cw_logs/cw_log_group_provider.py | 128 +++- .../observability/cw_logs/cw_log_puller.py | 50 +- .../observability_info_puller.py | 79 ++- .../lib/observability/xray_traces/__init__.py | 0 .../xray_traces/xray_event_mappers.py | 166 +++++ .../xray_traces/xray_event_puller.py | 152 +++++ .../observability/xray_traces/xray_events.py | 160 +++++ .../xray_service_graph_event_puller.py | 72 ++ samcli/lib/package/artifact_exporter.py | 2 +- samcli/lib/package/packageable_resources.py | 2 +- samcli/lib/providers/cfn_api_provider.py | 53 +- samcli/lib/providers/exceptions.py | 39 ++ samcli/lib/providers/provider.py | 168 ++++- samcli/lib/providers/sam_api_provider.py | 14 +- samcli/lib/providers/sam_base_provider.py | 25 +- samcli/lib/providers/sam_function_provider.py | 18 +- samcli/lib/providers/sam_layer_provider.py | 3 +- samcli/lib/providers/sam_stack_provider.py | 5 +- samcli/lib/sync/__init__.py | 0 .../lib/sync/continuous_sync_flow_executor.py | 130 ++++ samcli/lib/sync/exceptions.py | 173 +++++ 
samcli/lib/sync/flows/__init__.py | 0 .../lib/sync/flows/alias_version_sync_flow.py | 89 +++ .../flows/auto_dependency_layer_sync_flow.py | 137 ++++ samcli/lib/sync/flows/function_sync_flow.py | 104 +++ .../lib/sync/flows/generic_api_sync_flow.py | 89 +++ samcli/lib/sync/flows/http_api_sync_flow.py | 64 ++ .../sync/flows/image_function_sync_flow.py | 110 ++++ samcli/lib/sync/flows/layer_sync_flow.py | 348 ++++++++++ samcli/lib/sync/flows/rest_api_sync_flow.py | 64 ++ .../lib/sync/flows/stepfunctions_sync_flow.py | 113 ++++ .../lib/sync/flows/zip_function_sync_flow.py | 155 +++++ samcli/lib/sync/sync_flow.py | 295 +++++++++ samcli/lib/sync/sync_flow_executor.py | 342 ++++++++++ samcli/lib/sync/sync_flow_factory.py | 181 ++++++ samcli/lib/sync/watch_manager.py | 241 +++++++ samcli/lib/telemetry/metric.py | 10 +- samcli/lib/utils/architecture.py | 2 +- samcli/lib/utils/boto_utils.py | 96 +++ samcli/lib/utils/botoconfig.py | 17 - samcli/lib/utils/cloudformation.py | 127 ++++ samcli/lib/utils/code_trigger_factory.py | 112 ++++ samcli/lib/utils/colors.py | 10 + samcli/lib/utils/definition_validator.py | 60 ++ samcli/lib/utils/hash.py | 33 +- samcli/lib/utils/lock_distributor.py | 142 ++++ samcli/lib/utils/osutils.py | 4 + samcli/lib/utils/path_observer.py | 161 +++++ samcli/lib/utils/resource_trigger.py | 339 ++++++++++ .../lib/utils/resource_type_based_factory.py | 69 ++ .../_utils => lib/utils}/resources.py | 37 +- samcli/lib/utils/version_checker.py | 13 +- setup.py | 3 +- .../commands/cli/test_global_config.py | 103 +-- tests/integration/buildcmd/test_build_cmd.py | 30 + .../integration/pipeline/test_init_command.py | 4 +- tests/integration/telemetry/integ_base.py | 7 +- .../telemetry/test_experimental_metric.py | 153 +++++ tests/unit/cli/test_global_config.py | 394 +++++++---- tests/unit/cli/test_main.py | 6 +- .../unit/commands/_utils/test_experimental.py | 110 ++++ tests/unit/commands/_utils/test_options.py | 31 + tests/unit/commands/_utils/test_template.py | 2 +- .../commands/buildcmd/test_build_context.py | 354 +++++++++- tests/unit/commands/buildcmd/test_command.py | 179 +---- tests/unit/commands/deploy/test_command.py | 8 + .../commands/deploy/test_deploy_context.py | 61 ++ .../unit/commands/local/lib/test_provider.py | 267 +++++++- .../commands/local/lib/test_stack_provider.py | 2 +- tests/unit/commands/logs/test_command.py | 130 +++- .../commands/logs/test_console_consumers.py | 24 +- tests/unit/commands/logs/test_logs_context.py | 283 +++----- .../unit/commands/logs/test_puller_factory.py | 258 ++++++++ .../unit/commands/samconfig/test_samconfig.py | 164 ++++- tests/unit/commands/sync/__init__.py | 0 tests/unit/commands/sync/test_command.py | 614 ++++++++++++++++++ tests/unit/commands/traces/test_command.py | 69 ++ .../traces/test_trace_console_consumers.py | 14 + .../traces/test_traces_puller_factory.py | 87 +++ .../nested_stack/test_nested_stack_builder.py | 78 +++ .../nested_stack/test_nested_stack_manager.py | 207 ++++++ .../unit/lib/build_module/test_app_builder.py | 95 ++- .../unit/lib/build_module/test_build_graph.py | 281 ++++++-- .../lib/build_module/test_build_strategy.py | 183 +++++- .../test_dependency_hash_generator.py | 86 +++ tests/unit/lib/deploy/test_deployer.py | 105 +++ .../cw_logs/test_cw_log_formatters.py | 46 ++ .../cw_logs/test_cw_log_group_provider.py | 67 +- .../cw_logs/test_cw_log_puller.py | 59 ++ .../test_observability_info_puller.py | 109 +++- .../xray_traces/test_xray_event_mappers.py | 203 ++++++ .../xray_traces/test_xray_event_puller.py | 166 +++++ 
.../xray_traces/test_xray_events.py | 198 ++++++ .../test_xray_service_grpah_event_puller.py | 146 +++++ tests/unit/lib/sync/__init__.py | 0 tests/unit/lib/sync/flows/__init__.py | 0 .../flows/test_alias_version_sync_flow.py | 61 ++ .../test_auto_dependency_layer_sync_flow.py | 176 +++++ .../lib/sync/flows/test_function_sync_flow.py | 50 ++ .../lib/sync/flows/test_http_api_sync_flow.py | 109 ++++ .../flows/test_image_function_sync_flow.py | 124 ++++ .../lib/sync/flows/test_layer_sync_flow.py | 430 ++++++++++++ .../lib/sync/flows/test_rest_api_sync_flow.py | 100 +++ .../flows/test_stepfunctions_sync_flow.py | 125 ++++ .../sync/flows/test_zip_function_sync_flow.py | 180 +++++ .../test_continuous_sync_flow_executor.py | 144 ++++ tests/unit/lib/sync/test_exceptions.py | 50 ++ tests/unit/lib/sync/test_sync_flow.py | 119 ++++ .../unit/lib/sync/test_sync_flow_executor.py | 262 ++++++++ tests/unit/lib/sync/test_sync_flow_factory.py | 138 ++++ tests/unit/lib/sync/test_watch_manager.py | 239 +++++++ tests/unit/lib/telemetry/test_metric.py | 4 + tests/unit/lib/utils/test_boto_utils.py | 85 +++ tests/unit/lib/utils/test_cloudformation.py | 119 ++++ .../lib/utils/test_code_trigger_factory.py | 72 ++ .../lib/utils/test_definition_validator.py | 61 ++ tests/unit/lib/utils/test_handler_observer.py | 145 +++++ tests/unit/lib/utils/test_hash.py | 11 + tests/unit/lib/utils/test_lock_distributor.py | 103 +++ tests/unit/lib/utils/test_resource_trigger.py | 258 ++++++++ .../utils/test_resource_type_based_factory.py | 48 ++ tests/unit/lib/utils/test_version_checker.py | 29 +- 167 files changed, 16189 insertions(+), 1728 deletions(-) create mode 100644 samcli/commands/_utils/experimental.py create mode 100644 samcli/commands/logs/puller_factory.py create mode 100644 samcli/commands/sync/__init__.py create mode 100644 samcli/commands/sync/command.py create mode 100644 samcli/commands/traces/__init__.py create mode 100644 samcli/commands/traces/command.py create mode 100644 samcli/commands/traces/trace_console_consumers.py create mode 100644 samcli/commands/traces/traces_puller_factory.py create mode 100644 samcli/lib/bootstrap/nested_stack/__init__.py create mode 100644 samcli/lib/bootstrap/nested_stack/nested_stack_builder.py create mode 100644 samcli/lib/bootstrap/nested_stack/nested_stack_manager.py create mode 100644 samcli/lib/bootstrap/stack_builder.py create mode 100644 samcli/lib/build/dependency_hash_generator.py create mode 100644 samcli/lib/observability/xray_traces/__init__.py create mode 100644 samcli/lib/observability/xray_traces/xray_event_mappers.py create mode 100644 samcli/lib/observability/xray_traces/xray_event_puller.py create mode 100644 samcli/lib/observability/xray_traces/xray_events.py create mode 100644 samcli/lib/observability/xray_traces/xray_service_graph_event_puller.py create mode 100644 samcli/lib/sync/__init__.py create mode 100644 samcli/lib/sync/continuous_sync_flow_executor.py create mode 100644 samcli/lib/sync/exceptions.py create mode 100644 samcli/lib/sync/flows/__init__.py create mode 100644 samcli/lib/sync/flows/alias_version_sync_flow.py create mode 100644 samcli/lib/sync/flows/auto_dependency_layer_sync_flow.py create mode 100644 samcli/lib/sync/flows/function_sync_flow.py create mode 100644 samcli/lib/sync/flows/generic_api_sync_flow.py create mode 100644 samcli/lib/sync/flows/http_api_sync_flow.py create mode 100644 samcli/lib/sync/flows/image_function_sync_flow.py create mode 100644 samcli/lib/sync/flows/layer_sync_flow.py create mode 100644 
samcli/lib/sync/flows/rest_api_sync_flow.py create mode 100644 samcli/lib/sync/flows/stepfunctions_sync_flow.py create mode 100644 samcli/lib/sync/flows/zip_function_sync_flow.py create mode 100644 samcli/lib/sync/sync_flow.py create mode 100644 samcli/lib/sync/sync_flow_executor.py create mode 100644 samcli/lib/sync/sync_flow_factory.py create mode 100644 samcli/lib/sync/watch_manager.py create mode 100644 samcli/lib/utils/boto_utils.py delete mode 100644 samcli/lib/utils/botoconfig.py create mode 100644 samcli/lib/utils/cloudformation.py create mode 100644 samcli/lib/utils/code_trigger_factory.py create mode 100644 samcli/lib/utils/definition_validator.py create mode 100644 samcli/lib/utils/lock_distributor.py create mode 100644 samcli/lib/utils/path_observer.py create mode 100644 samcli/lib/utils/resource_trigger.py create mode 100644 samcli/lib/utils/resource_type_based_factory.py rename samcli/{commands/_utils => lib/utils}/resources.py (85%) create mode 100644 tests/integration/telemetry/test_experimental_metric.py create mode 100644 tests/unit/commands/_utils/test_experimental.py create mode 100644 tests/unit/commands/logs/test_puller_factory.py create mode 100644 tests/unit/commands/sync/__init__.py create mode 100644 tests/unit/commands/sync/test_command.py create mode 100644 tests/unit/commands/traces/test_command.py create mode 100644 tests/unit/commands/traces/test_trace_console_consumers.py create mode 100644 tests/unit/commands/traces/test_traces_puller_factory.py create mode 100644 tests/unit/lib/bootstrap/nested_stack/test_nested_stack_builder.py create mode 100644 tests/unit/lib/bootstrap/nested_stack/test_nested_stack_manager.py create mode 100644 tests/unit/lib/build_module/test_dependency_hash_generator.py create mode 100644 tests/unit/lib/observability/xray_traces/test_xray_event_mappers.py create mode 100644 tests/unit/lib/observability/xray_traces/test_xray_event_puller.py create mode 100644 tests/unit/lib/observability/xray_traces/test_xray_events.py create mode 100644 tests/unit/lib/observability/xray_traces/test_xray_service_grpah_event_puller.py create mode 100644 tests/unit/lib/sync/__init__.py create mode 100644 tests/unit/lib/sync/flows/__init__.py create mode 100644 tests/unit/lib/sync/flows/test_alias_version_sync_flow.py create mode 100644 tests/unit/lib/sync/flows/test_auto_dependency_layer_sync_flow.py create mode 100644 tests/unit/lib/sync/flows/test_function_sync_flow.py create mode 100644 tests/unit/lib/sync/flows/test_http_api_sync_flow.py create mode 100644 tests/unit/lib/sync/flows/test_image_function_sync_flow.py create mode 100644 tests/unit/lib/sync/flows/test_layer_sync_flow.py create mode 100644 tests/unit/lib/sync/flows/test_rest_api_sync_flow.py create mode 100644 tests/unit/lib/sync/flows/test_stepfunctions_sync_flow.py create mode 100644 tests/unit/lib/sync/flows/test_zip_function_sync_flow.py create mode 100644 tests/unit/lib/sync/test_continuous_sync_flow_executor.py create mode 100644 tests/unit/lib/sync/test_exceptions.py create mode 100644 tests/unit/lib/sync/test_sync_flow.py create mode 100644 tests/unit/lib/sync/test_sync_flow_executor.py create mode 100644 tests/unit/lib/sync/test_sync_flow_factory.py create mode 100644 tests/unit/lib/sync/test_watch_manager.py create mode 100644 tests/unit/lib/utils/test_boto_utils.py create mode 100644 tests/unit/lib/utils/test_cloudformation.py create mode 100644 tests/unit/lib/utils/test_code_trigger_factory.py create mode 100644 tests/unit/lib/utils/test_definition_validator.py create mode 
100644 tests/unit/lib/utils/test_handler_observer.py create mode 100644 tests/unit/lib/utils/test_lock_distributor.py create mode 100644 tests/unit/lib/utils/test_resource_trigger.py create mode 100644 tests/unit/lib/utils/test_resource_type_based_factory.py diff --git a/.gitignore b/.gitignore index 32aa6289e7..2f21084ff2 100644 --- a/.gitignore +++ b/.gitignore @@ -307,6 +307,11 @@ env.bak/ venv.bak/ venv-update-reproducible-requirements/ +env.*/ +venv.*/ +.env.*/ +.venv.*/ + # Spyder project settings .spyderproject .spyproject diff --git a/appveyor.yml b/appveyor.yml index 74d614643d..b2100c6b40 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -197,7 +197,7 @@ for: # Runs only in Linux, logging Public ECR when running canary and cred is available - sh: " if [[ -n $BY_CANARY ]]; - then echo Logging in Public ECR; aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws; + then echo Logging in Public ECR; aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws; fi" - sh: "pytest -vv tests/integration" diff --git a/requirements/base.txt b/requirements/base.txt index 9a410ba208..33494b8557 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -12,10 +12,15 @@ docker~=4.2.0 dateparser~=1.0 requests==2.25.1 serverlessrepo==0.1.10 -aws_lambda_builders==1.8.1 +aws_lambda_builders==1.9.0 tomlkit==0.7.2 watchdog==2.1.2 + +# Needed for supporting Protocol in Python 3.6 +typing_extensions==3.10.0.0 +# Needed for supporting dataclasses decorator in Python3.6 +dataclasses==0.8; python_version < '3.7' # NOTE: regex is not a direct dependency of SAM CLI, but pin to 2021.9.30 due to 2021.10.8 not working on M1 Mac - https://bitbucket.org/mrabarnett/mrab-regex/issues/399/missing-wheel-for-macosx-and-the-new-m1 regex==2021.9.30 # NOTE: tzlocal is not a direct dependency of SAM CLI, but pin to 3.0 as 4.0 break appveyor jobs -tzlocal==3.0 \ No newline at end of file +tzlocal==3.0 diff --git a/requirements/reproducible-linux.txt b/requirements/reproducible-linux.txt index 01719e6966..da08d08a37 100644 --- a/requirements/reproducible-linux.txt +++ b/requirements/reproducible-linux.txt @@ -12,10 +12,10 @@ attrs==20.3.0 \ --hash=sha256:31b2eced602aa8423c2aea9c76a724617ed67cf9513173fd3a4f03e3a929c7e6 \ --hash=sha256:832aa3cde19744e49938b91fea06d69ecb9e649c93ba974535d08ad92164f700 # via jsonschema -aws-lambda-builders==1.8.1 \ - --hash=sha256:21397c8596415506f3c450bbded5a5d361ba40af782034f3b58c17f088cd1410 \ - --hash=sha256:7c2672cb9c0b4a5f5f24707ed25808e4920fc87c0ebec4d76dd6c50d02f3aa47 \ - --hash=sha256:8fd7be03216fe9eee81875bef5e8d73fd7f5ddbc832646bc0d625a21be75a835 +aws-lambda-builders==1.9.0 \ + --hash=sha256:1498fb374d3f2e289a78a4843cfcee86f777011cd8d3c29f38826c16239839f9 \ + --hash=sha256:335966e9dc19ab37ad9a600dbcf90ecb8d47f1fcffbf83ef6eebe61fcbdd6731 \ + --hash=sha256:742a67123dc0ea91983e7a6947dad0f23dff9c958a28da46162ee675ae257273 # via aws-sam-cli (setup.py) aws-sam-translator==1.39.0 \ --hash=sha256:410d11a14a71f5ecab9cd1c5c4415b3c2419b8a8ee4fb9e887f15bdcc6c7ac38 \ @@ -318,12 +318,13 @@ tomlkit==0.7.2 \ --hash=sha256:173ad840fa5d2aac140528ca1933c29791b79a374a0861a80347f42ec9328117 \ --hash=sha256:d7a454f319a7e9bd2e249f239168729327e4dd2d27b17dc68be264ad1ce36754 # via aws-sam-cli (setup.py) -typing-extensions==3.10.0.2 \ - --hash=sha256:49f75d16ff11f1cd258e1b988ccff82a3ca5570217d7ad8c5f48205dd99a677e \ - 
--hash=sha256:d8226d10bc02a29bcc81df19a26e56a9647f8b0a6d4a83924139f4a8b01f17b7 \ - --hash=sha256:f1d25edafde516b146ecd0613dabcc61409817af4766fbbcfb8d1ad4ec441a34 +typing-extensions==3.10.0.0 \ + --hash=sha256:0ac0f89795dd19de6b97debb0c6af1c70987fd80a2d62d1958f7e56fcc31b497 \ + --hash=sha256:50b6f157849174217d0656f99dc82fe932884fb250826c18350e159ec6cdf342 \ + --hash=sha256:779383f6086d90c99ae41cf0ff39aac8a7937a9283ce0a414e5dd782f4c94a84 # via # arrow + # aws-sam-cli (setup.py) # importlib-metadata tzlocal==3.0 \ --hash=sha256:c736f2540713deb5938d789ca7c3fc25391e9a20803f05b60ec64987cf086559 \ diff --git a/samcli/cli/command.py b/samcli/cli/command.py index 0741cda84f..c0465db511 100644 --- a/samcli/cli/command.py +++ b/samcli/cli/command.py @@ -22,6 +22,8 @@ "samcli.commands.delete", "samcli.commands.logs", "samcli.commands.publish", + "samcli.commands.traces", + "samcli.commands.sync", "samcli.commands.pipeline.pipeline", # We intentionally do not expose the `bootstrap` command for now. We might open it up later # "samcli.commands.bootstrap", diff --git a/samcli/cli/context.py b/samcli/cli/context.py index 74c35155a1..53c2e93589 100644 --- a/samcli/cli/context.py +++ b/samcli/cli/context.py @@ -44,6 +44,7 @@ def __init__(self): self._aws_region = None self._aws_profile = None self._session_id = str(uuid.uuid4()) + self._experimental = False @property def debug(self): @@ -97,6 +98,14 @@ def session_id(self) -> str: """ return self._session_id + @property + def experimental(self): + return self._experimental + + @experimental.setter + def experimental(self, value): + self._experimental = value + @property def command_path(self): """ diff --git a/samcli/cli/global_config.py b/samcli/cli/global_config.py index 92d4d6a181..9ea764b394 100644 --- a/samcli/cli/global_config.py +++ b/samcli/cli/global_config.py @@ -1,55 +1,328 @@ """ Provides global configuration helpers. """ - import json import logging import uuid import os +import threading + from pathlib import Path -from typing import Optional, Dict, Any +from typing import List, Optional, Dict, Any, Type, TypeVar, cast, overload +from dataclasses import dataclass import click LOG = logging.getLogger(__name__) -CONFIG_FILENAME = "metadata.json" -INSTALLATION_ID_KEY = "installationId" -TELEMETRY_ENABLED_KEY = "telemetryEnabled" -LAST_VERSION_CHECK_KEY = "lastVersionCheck" +@dataclass(frozen=True, eq=True) +class ConfigEntry: + """Data class for storing configuration related keys""" + + config_key: Optional[str] + env_var_key: Optional[str] + persistent: bool = True + + +class DefaultEntry: + """Set of default configuration entries integrated with GlobalConfig""" + + INSTALLATION_ID = ConfigEntry("installationId", None) + LAST_VERSION_CHECK = ConfigEntry("lastVersionCheck", None) + TELEMETRY = ConfigEntry("telemetryEnabled", "SAM_CLI_TELEMETRY") + + +class Singleton(type): + def __init__(cls, *args, **kwargs): + super().__init__(*args, **kwargs) + cls.__instance = None -class GlobalConfig: + def __call__(cls, *args, **kwargs): + if cls.__instance is None: + cls.__instance = super().__call__(*args, **kwargs) + return cls.__instance + + +class GlobalConfig(metaclass=Singleton): """ - Contains helper methods for global configuration files and values. Handles - configuration file creation, updates, and fetching in a platform-neutral way. + A singleton for accessing configurations from environmental variables and + configuration file. Singleton is used to enforce immutability, access locking, + and rapid configuration modification. 
Generally uses '~/.aws-sam/' or 'C:\\Users\\<username>\\AppData\\Roaming\\AWS SAM' as the base directory, depending on platform. """ - def __init__(self, config_dir=None, installation_id=None, telemetry_enabled=None, last_version_check=None): - """ - Initializes the class, with options provided to assist with testing. + DEFAULT_CONFIG_FILENAME: str = "metadata.json" - :param config_dir: Optional, overrides the default config directory path. - :param installation_id: Optional, will use this installation id rather than checking config values. - :param telemetry_enabled: Optional, set whether telemetry is enabled or not. - :param last_version_check: Optional, will be used to check if there is a newer version of SAM CLI available - """ - self._config_dir = config_dir - self._installation_id = installation_id - self._telemetry_enabled = telemetry_enabled - self._last_version_check = last_version_check + # Env var for injecting dir in integration tests + _DIR_INJECTION_ENV_VAR: str = "__SAM_CLI_APP_DIR" + + # Static singleton instance + + _access_lock: threading.RLock + _config_dir: Optional[Path] + _config_filename: Optional[str] + # Dictionary storing config data mapped directly to the content of the config file + _config_data: Optional[Dict[str, Any]] + # config_keys that should be flushed to file + _persistent_fields: List[str] + + def __init__(self): + """__init__ should only be called once due to Singleton metaclass""" + self._access_lock = threading.RLock() + self._config_dir = None + self._config_filename = None + self._config_data = None + self._persistent_fields = list() @property def config_dir(self) -> Path: + """ + Returns + ------- + Path + Path object for the configuration directory. + """ if not self._config_dir: - # Internal Environment variable to customize SAM CLI App Dir. Currently used only by integ tests. - app_dir = os.getenv("__SAM_CLI_APP_DIR") - self._config_dir = Path(app_dir) if app_dir else Path(click.get_app_dir("AWS SAM", force_posix=True)) - return Path(self._config_dir) + if GlobalConfig._DIR_INJECTION_ENV_VAR in os.environ: + # Set dir to the one specified in the _DIR_INJECTION_ENV_VAR environment variable + # This is used for existing integration tests + env_var_path = os.environ.get(GlobalConfig._DIR_INJECTION_ENV_VAR) + self._config_dir = Path(cast(str, env_var_path)) + else: + self._config_dir = Path(click.get_app_dir("AWS SAM", force_posix=True)) + return self._config_dir + + @config_dir.setter + def config_dir(self, dir_path: Path) -> None: + """ + Parameters + ---------- + dir_path : Path + Directory path object for the configuration. + + Raises + ------ + ValueError + ValueError will be raised if the path is not a directory. + """ + if not dir_path.is_dir(): + raise ValueError("config_dir must be a directory.") + self._config_dir = dir_path + self._config_data = None + + @property + def config_filename(self) -> str: + """ + Returns + ------- + str + Filename for the configuration. + """ + if not self._config_filename: + self._config_filename = GlobalConfig.DEFAULT_CONFIG_FILENAME + return self._config_filename + + @config_filename.setter + def config_filename(self, filename: str) -> None: + self._config_filename = filename + self._config_data = None + + @property + def config_path(self) -> Path: + """ + Returns + ------- + Path + Path object for the configuration file (config_dir + config_filename). + """ + return Path(self.config_dir, self.config_filename) + + T = TypeVar("T") + # Overloads are only used for type hinting.
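A quick sketch (not part of the diff) of how the Singleton metaclass and the config_dir, config_filename, and config_path properties above compose, using only names this patch defines:

    import tempfile
    from pathlib import Path
    from samcli.cli.global_config import GlobalConfig

    gc = GlobalConfig()
    assert gc is GlobalConfig()  # Singleton metaclass: every call returns the same instance

    gc.config_dir = Path(tempfile.mkdtemp())  # the setter requires an existing directory
    # config_path is simply config_dir joined with config_filename ("metadata.json" by default)
    assert gc.config_path == Path(gc.config_dir, gc.config_filename)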
+ # Overload for case where is_flag is set + @overload + def get_value( + self, + config_entry: ConfigEntry, + default: bool, + value_type: Type[bool], + is_flag: bool, + reload_config: bool = False, + ) -> bool: + ... + + # Overload for case where type is specified + @overload + def get_value( + self, + config_entry: ConfigEntry, + default: Optional[T] = None, + value_type: Type[T] = T, + is_flag: bool = False, + reload_config: bool = False, + ) -> Optional[T]: + ... + + # Overload for case where type is not specified and defaults to object + @overload + def get_value( + self, + config_entry: ConfigEntry, + default: Any = None, + value_type: object = object, + is_flag: bool = False, + reload_config: bool = False, + ) -> Any: + ... + + def get_value( + self, + config_entry, + default=None, + value_type=object, + is_flag=False, + reload_config=False, + ) -> Any: + """Get the corresponding value of a configuration entry. + + Parameters + ---------- + config_entry : ConfigEntry + Configuration entry for which the value will be loaded. + default : value_type, optional + The default value to be returned if the configuration does not exist, + an error was encountered, or the value is of the incorrect type. + By default None + value_type : Type, optional + The type of the value that should be expected. + If the value is not this type, default will be returned. + By default object + is_flag : bool, optional + If is_flag is True, then the env var will be read as "1" or "0" instead of boolean values. + This is useful for backward compatibility with the old configuration format where + the configuration file and the env var have different values. + By default False + reload_config : bool, optional + Whether the configuration file should be reloaded before getting the value. + By default False + + Returns + ------- + [value_type] + Value in the type specified by value_type + """ + with self._access_lock: + return self._get_value(config_entry, default, value_type, is_flag, reload_config) + + def _get_value( + self, + config_entry: ConfigEntry, + default: Optional[T], + value_type: Type[T], + is_flag: bool, + reload_config: bool, + ) -> Optional[T]: + """get_value without locking. Non-thread safe.""" + value: Any = None + try: + if config_entry.env_var_key: + value = os.environ.get(config_entry.env_var_key) + if value is not None and is_flag: + value = value == "1" + + if value is None and config_entry.config_key: + if reload_config or self._config_data is None: + self._load_config() + value = cast(dict, self._config_data).get(config_entry.config_key) + + if value is None or not isinstance(value, value_type): + return default + except (ValueError, OSError) as ex: + LOG.debug( + "Error when retrieving config_key: %s env_var_key: %s", + config_entry.config_key, + config_entry.env_var_key, + exc_info=ex, + ) + return default + + return value + + def set_value(self, config_entry: ConfigEntry, value: Any, is_flag: bool = False, flush: bool = True) -> None: + """Set the value of a configuration. The associated env var will be updated as well. + + Parameters + ---------- + config_entry : ConfigEntry + Configuration entry to be set + value : Any + Value of the configuration + is_flag : bool, optional + If is_flag is True, then the env var will be set to "1" or "0" instead of boolean values. + This is useful for backward compatibility with the old configuration format where + the configuration file and the env var have different values.
+ By default False + flush : bool, optional + Should the value be written to the configuration file, by default True + """ + with self._access_lock: + self._set_value(config_entry, value, is_flag, flush) + + def _set_value(self, config_entry: ConfigEntry, value: Any, is_flag: bool, flush: bool) -> None: + """set_value without locking. Non-thread safe.""" + if config_entry.env_var_key: + if is_flag: + os.environ[config_entry.env_var_key] = "1" if value else "0" + else: + os.environ[config_entry.env_var_key] = value + + if config_entry.config_key: + if self._config_data is None: + self._load_config() + cast(dict, self._config_data)[config_entry.config_key] = value + + if config_entry.persistent: + self._persistent_fields.append(config_entry.config_key) + elif config_entry.config_key in self._persistent_fields: + self._persistent_fields.remove(config_entry.config_key) + + if flush: + self._write_config() + + def _load_config(self) -> None: + """Reload configurations from file and populate self._config_data""" + if not self.config_path.exists(): + self._config_data = {} + return + try: + body = self.config_path.read_text() + json_body = json.loads(body) + self._config_data = json_body + # Default existing fields to be persistent + # so that they will be kept when flushed back + for key in json_body: + self._persistent_fields.append(key) + except (OSError, ValueError) as ex: + LOG.debug( + "Error when loading global config file: %s", + self.config_path, + exc_info=ex, + ) + self._config_data = {} + + def _write_config(self) -> None: + """Write configurations in self._config_data to file""" + if not self._config_data: + return + config_data = {key: value for (key, value) in self._config_data.items() if key in self._persistent_fields} + json_str = json.dumps(config_data, indent=4) + if not self.config_dir.exists(): + self.config_dir.mkdir(mode=0o700, parents=True, exist_ok=True) + self.config_path.write_text(json_str) @property def installation_id(self): @@ -72,22 +345,20 @@ def installation_id(self): ------- A string containing the installation UUID, or None in case of an error. """ - if self._installation_id: - return self._installation_id - try: - self._installation_id = self._get_or_set_uuid(INSTALLATION_ID_KEY) - return self._installation_id - except (ValueError, IOError, OSError): - return None + value = self.get_value(DefaultEntry.INSTALLATION_ID, default=None, value_type=str, reload_config=True) + if not value: + value = str(uuid.uuid4()) + self.set_value(DefaultEntry.INSTALLATION_ID, value) + return value @property - def telemetry_enabled(self): + def telemetry_enabled(self) -> Optional[bool]: """ Check if telemetry is enabled for this installation. Default value of False. It first tries to get the value from the SAM_CLI_TELEMETRY environment variable. If it's not set, then it fetches the value from the config file. - To enable telemetry, set SAM_CLI_TELEMETRY environment variable equal to integer 1 or string '1'. + To enable telemetry, set SAM_CLI_TELEMETRY environment variable equal to the string '1'. All other values including words like 'True', 'true', 'false', 'False', 'abcd' etc. will disable Telemetry Examples @@ -102,23 +373,10 @@ def telemetry_enabled(self): Boolean flag value. True if telemetry is enabled for this installation, False otherwise. """ - if self._telemetry_enabled is not None: - return self._telemetry_enabled - - # If environment variable is set, its value takes precedence over the value from config file.
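A quick sketch (not part of the diff) of the flag round trip these methods implement, which replaces the hand-rolled SAM_CLI_TELEMETRY parsing removed below; with is_flag=True, the env var mirrors booleans as "1"/"0" while the JSON file keeps real booleans:

    import os
    from samcli.cli.global_config import DefaultEntry, GlobalConfig

    gc = GlobalConfig()
    gc.set_value(DefaultEntry.TELEMETRY, True, is_flag=True, flush=False)  # flush=False: skip writing metadata.json
    assert os.environ["SAM_CLI_TELEMETRY"] == "1"  # env var mirror uses "1"/"0"
    # get_value consults the env var first and converts "1" back to True
    assert gc.get_value(DefaultEntry.TELEMETRY, default=False, value_type=bool, is_flag=True) is True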
- env_name = "SAM_CLI_TELEMETRY" - if env_name in os.environ: - return os.getenv(env_name) in ("1", 1) - - try: - self._telemetry_enabled = self._get_value(TELEMETRY_ENABLED_KEY) - return self._telemetry_enabled - except (ValueError, IOError, OSError) as ex: - LOG.debug("Error when retrieving telemetry_enabled flag", exc_info=ex) - return False + return self.get_value(DefaultEntry.TELEMETRY, default=None, value_type=bool, is_flag=True) @telemetry_enabled.setter - def telemetry_enabled(self, value): + def telemetry_enabled(self, value: bool) -> None: """ Sets the telemetry_enabled flag to the provided boolean value. @@ -139,90 +397,12 @@ def telemetry_enabled(self, value): JSONDecodeError If the config file exists, and is not valid JSON. """ - self._set_value("telemetryEnabled", value) - self._telemetry_enabled = value + self.set_value(DefaultEntry.TELEMETRY, value, is_flag=True, flush=True) @property - def last_version_check(self): - if self._last_version_check is not None: - return self._last_version_check - - try: - self._last_version_check = self._get_value(LAST_VERSION_CHECK_KEY) - return self._last_version_check - except (ValueError, IOError, OSError) as ex: - LOG.debug("Error when retrieving _last_version_check flag", exc_info=ex) - return None + def last_version_check(self) -> Optional[float]: + return self.get_value(DefaultEntry.LAST_VERSION_CHECK, value_type=float) @last_version_check.setter - def last_version_check(self, value): - self._set_value(LAST_VERSION_CHECK_KEY, value) - self._last_version_check = value - - def _get_value(self, key: str) -> Optional[Any]: - cfg_path = self._get_config_file_path(CONFIG_FILENAME) - if not cfg_path.exists(): - return None - with open(str(cfg_path)) as fp: - body = fp.read() - json_body = json.loads(body) - return json_body.get(key) - - def _set_value(self, key: str, value: Any) -> Any: - cfg_path = self._get_config_file_path(CONFIG_FILENAME) - if not cfg_path.exists(): - return self._set_json_cfg(cfg_path, key, value) - with open(str(cfg_path)) as fp: - body = fp.read() - try: - json_body = json.loads(body) - except ValueError as ex: - LOG.debug("Failed to decode JSON in {cfg_path}", exc_info=ex) - raise ex - return self._set_json_cfg(cfg_path, key, value, json_body) - - def _create_dir(self): - """ - Creates configuration directory if it does not already exist, otherwise does nothing. - May raise an OSError if we do not have permissions to create the directory. - """ - self.config_dir.mkdir(mode=0o700, parents=True, exist_ok=True) - - def _get_config_file_path(self, filename): - self._create_dir() - filepath = self.config_dir.joinpath(filename) - return filepath - - def _get_or_set_uuid(self, key): - """ - Special logic method for when we want a UUID to always be present, this - method behaves as a getter with side effects. Essentially, if the value - is not present, we will set it with a generated UUID. - - If we have multiple such values in the future, a possible refactor is - to just be _get_or_set_value, where we also take a default value as a - parameter. - """ - cfg_value = self._get_value(key) - if cfg_value is not None: - return cfg_value - return self._set_value(key, str(uuid.uuid4())) - - @staticmethod - def _set_json_cfg(filepath: Path, key: str, value: Any, json_body: Optional[Dict] = None) -> Any: - """ - Special logic method to add a value to a JSON configuration file. 
This - method will write a new version of the file in question, so it will - either write a new file with only the first config value, or if a JSON - body is provided, it will upsert starting from that JSON body. - """ - json_body = json_body or {} - json_body[key] = value - file_body = json.dumps(json_body, indent=4) + "\n" - try: - with open(str(filepath), "w") as f: - f.write(file_body) - except IOError as ex: - LOG.debug("Error writing to {filepath}", exc_info=ex) - raise ex - return value + def last_version_check(self, value: float): + self.set_value(DefaultEntry.LAST_VERSION_CHECK, value) diff --git a/samcli/cli/main.py b/samcli/cli/main.py index 14a0f6c003..5810bdac58 100644 --- a/samcli/cli/main.py +++ b/samcli/cli/main.py @@ -27,9 +27,6 @@ pass_context = click.make_pass_decorator(Context) -global_cfg = GlobalConfig() - - def common_options(f): """ Common CLI options used by all commands. Ex: --debug @@ -118,11 +115,12 @@ def cli(ctx): You can find more in-depth guide about the SAM specification here: https://github.com/awslabs/serverless-application-model. """ - if global_cfg.telemetry_enabled is None: + gc = GlobalConfig() + if gc.telemetry_enabled is None: enabled = True try: - global_cfg.telemetry_enabled = enabled + gc.telemetry_enabled = enabled if enabled: click.secho(TELEMETRY_PROMPT, fg="yellow", err=True) diff --git a/samcli/commands/_utils/experimental.py b/samcli/commands/_utils/experimental.py new file mode 100644 index 0000000000..d6b86f8512 --- /dev/null +++ b/samcli/commands/_utils/experimental.py @@ -0,0 +1,229 @@ +"""Experimental flag""" +import sys +import logging + +from dataclasses import dataclass +from functools import wraps +from typing import List, Dict, Optional + +import click + +from samcli.cli.context import Context + +from samcli.cli.global_config import ConfigEntry, GlobalConfig +from samcli.commands._utils.options import parameterized_option +from samcli.lib.utils.colors import Colored + +LOG = logging.getLogger(__name__) + +EXPERIMENTAL_PROMPT = """ +This feature is currently in beta. Visit the docs page to learn more about the AWS Beta terms https://aws.amazon.com/service-terms/. +Enter Y to proceed with the command, or enter N to cancel: +""" + +EXPERIMENTAL_WARNING = """ +Experimental features are enabled for this session. +Visit the docs page to learn more about the AWS Beta terms https://aws.amazon.com/service-terms/. +""" + +EXPERIMENTAL_ENV_VAR_PREFIX = "SAM_CLI_BETA_" + + +@dataclass(frozen=True, eq=True) +class ExperimentalEntry(ConfigEntry): + """Child data class of ConfigEntry that enforces + config_key and env_var_key to be not None""" + + config_key: str + env_var_key: str + persistent: bool = False + + +class ExperimentalFlag: + """Class for storing all experimental related ConfigEntries""" + + All = ExperimentalEntry("experimentalAll", EXPERIMENTAL_ENV_VAR_PREFIX + "FEATURES") + Accelerate = ExperimentalEntry("experimentalAccelerate", EXPERIMENTAL_ENV_VAR_PREFIX + "ACCELERATE") + + +def is_experimental_enabled(config_entry: ExperimentalEntry) -> bool: + """Whether a given experimental flag is enabled or not. + If experimentalAll is set to True, then it will always return True. + + Parameters + ---------- + config_entry : ExperimentalEntry + Experimental flag ExperimentalEntry + + Returns + ------- + bool + Whether the experimental flag is enabled or not. 
+ """ + gc = GlobalConfig() + enabled = gc.get_value(config_entry, default=False, value_type=bool, is_flag=True) + if not enabled: + enabled = gc.get_value(ExperimentalFlag.All, default=False, value_type=bool, is_flag=True) + return enabled + + + def set_experimental(config_entry: ExperimentalEntry = ExperimentalFlag.All, enabled: bool = True) -> None: + """Set the experimental flag to enabled or disabled. + + Parameters + ---------- + config_entry : ExperimentalEntry, optional + Flag to be set, by default ExperimentalFlag.All + enabled : bool, optional + Enabled or disabled, by default True + """ + gc = GlobalConfig() + gc.set_value(config_entry, enabled, is_flag=True, flush=False) + + + def get_all_experimental() -> List[ExperimentalEntry]: + """ + Returns + ------- + List[ExperimentalEntry] + List of all experimental flags in the ExperimentalFlag class. + """ + return [getattr(ExperimentalFlag, name) for name in dir(ExperimentalFlag) if not name.startswith("__")] + + + def get_all_experimental_statues() -> Dict[str, bool]: + """Get statuses of all experimental flags in a dictionary. + + Returns + ------- + Dict[str, bool] + Dictionary with the configuration key as key and the enabled/disabled status as value. + """ + return {entry.config_key: is_experimental_enabled(entry) for entry in get_all_experimental() if entry.config_key} + + + def disable_all_experimental(): + """Turn off all experimental flags in the ExperimentalFlag class.""" + for entry in get_all_experimental(): + set_experimental(entry, False) + + + def _update_experimental_context(show_warning=True): + """Set experimental for the current click context. + + Parameters + ---------- + show_warning : bool, optional + Whether the warning should be shown, by default True + """ + if not Context.get_current_context().experimental: + Context.get_current_context().experimental = True + if show_warning: + LOG.warning(Colored().yellow(EXPERIMENTAL_WARNING)) + + + def _experimental_option_callback(ctx, param, enabled: Optional[bool]): + """Click parameter callback for --beta-features or --no-beta-features. + If neither is specified, enabled will be None. + If --beta-features is set, enabled will be True, + and we should turn on all experimental flags. + If --no-beta-features is set, enabled will be False, + and we should turn off all experimental flags, overriding existing env vars. + """ + if enabled is None: + return + + if enabled: + set_experimental(ExperimentalFlag.All, True) + else: + disable_all_experimental() + + + def experimental_click_option(default: Optional[bool]): + return click.option( + "--beta-features/--no-beta-features", + default=default, + required=False, + is_flag=True, + expose_value=False, + callback=_experimental_option_callback, + help="Should beta features be enabled.", + ) + + + @parameterized_option + def experimental(f, default: Optional[bool] = None): + """Decorator for adding --beta-features and --no-beta-features click options to a command.""" + return experimental_click_option(default)(f) + + + @parameterized_option + def force_experimental( + f, config_entry: ExperimentalEntry = ExperimentalFlag.All, prompt=EXPERIMENTAL_PROMPT, default=None + ): + """Decorator for adding --beta-features and --no-beta-features click options to a command. + If the experimental flag env var or the --beta-features flag is not specified, this will + prompt the user for confirmation. + The program will exit if confirmation is denied.
+ """ + + def wrap(func): + @wraps(func) + def wrapped_func(*args, **kwargs): + if not prompt_experimental(config_entry=config_entry, prompt=prompt): + sys.exit(1) + _update_experimental_context() + return func(*args, **kwargs) + + return wrapped_func + + return experimental_click_option(default)(wrap(f)) + + + @parameterized_option + def force_experimental_option( + f, option: str, config_entry: ExperimentalEntry = ExperimentalFlag.All, prompt=EXPERIMENTAL_PROMPT + ): + """Decorator for making a specific option experimental. + A prompt will be shown if experimental is not enabled and the option is specified. + """ + + def wrap(func): + @wraps(func) + def wrapped_func(*args, **kwargs): + if kwargs[option]: + if not prompt_experimental(config_entry=config_entry, prompt=prompt): + sys.exit(1) + _update_experimental_context() + return func(*args, **kwargs) + + return wrapped_func + + return wrap(f) + + + def prompt_experimental( + config_entry: ExperimentalEntry = ExperimentalFlag.All, prompt: str = EXPERIMENTAL_PROMPT + ) -> bool: + """Prompt the user for experimental features. + If the corresponding experimental flag is already specified, the prompt will be skipped. + If confirmation is granted, the corresponding experimental flag env var will be set. + + Parameters + ---------- + config_entry : ExperimentalEntry, optional + Which experimental flag should be set, by default ExperimentalFlag.All + prompt : str, optional + Text to be shown in the prompt, by default EXPERIMENTAL_PROMPT + + Returns + ------- + bool + Whether the user has accepted the experimental feature. + """ + if is_experimental_enabled(config_entry): + return True + confirmed = click.confirm(prompt, default=False) + if confirmed: + set_experimental(config_entry=config_entry, enabled=True) + return confirmed diff --git a/samcli/commands/_utils/options.py b/samcli/commands/_utils/options.py index 70ca59657c..23904e2b93 100644 --- a/samcli/commands/_utils/options.py +++ b/samcli/commands/_utils/options.py @@ -5,21 +5,70 @@ import os import logging from functools import partial +import types import click from click.types import FuncParamType from samcli.commands._utils.template import get_template_data, TemplateNotFoundException -from samcli.cli.types import CfnParameterOverridesType, CfnMetadataType, CfnTags, SigningProfilesOptionType +from samcli.cli.types import ( + CfnParameterOverridesType, + CfnMetadataType, + CfnTags, + SigningProfilesOptionType, + ImageRepositoryType, + ImageRepositoriesType, +) from samcli.commands._utils.custom_options.option_nargs import OptionNargs from samcli.commands._utils.template import get_template_artifacts_format +from samcli.lib.utils.packagetype import ZIP, IMAGE _TEMPLATE_OPTION_DEFAULT_VALUE = "template.[yaml|yml|json]" DEFAULT_STACK_NAME = "sam-app" +DEFAULT_BUILD_DIR = os.path.join(".aws-sam", "build") +DEFAULT_BUILD_DIR_WITH_AUTO_DEPENDENCY_LAYER = os.path.join(".aws-sam", "auto-dependency-layer") +DEFAULT_CACHE_DIR = os.path.join(".aws-sam", "cache") LOG = logging.getLogger(__name__) +def parameterized_option(option): + """Meta decorator for option decorators. + This adds the ability to specify optional parameters for option decorators. + + Usage: + @parameterized_option + def some_option(f, required=False) + ... + + @some_option + def command(...) + + or + + @some_option(required=True) + def command(...)
+ """ + + def parameter_wrapper(*args, **kwargs): + if len(args) == 1 and isinstance(args[0], types.FunctionType): + # Case when option decorator does not have parameter + # @stack_name_option + # def command(...) + return option(args[0]) + + # Case when option decorator does have parameter + # @stack_name_option("a", "b") + # def command(...) + + def option_wrapper(f): + return option(f, *args, **kwargs) + + return option_wrapper + + return parameter_wrapper + + def get_or_default_template_file_name(ctx, param, provided_value, include_build): """ Default value for the template file name option is more complex than what Click can handle. @@ -296,6 +345,47 @@ def signing_profiles_option(f): return signing_profiles_click_option()(f) +def common_observability_click_options(): + return [ + click.option( + "--start-time", + "-s", + default="10m ago", + help="Fetch events starting at this time. Time can be relative values like '5mins ago', 'yesterday' or " + "formatted timestamp like '2018-01-01 10:10:10'. Defaults to '10mins ago'.", + ), + click.option( + "--end-time", + "-e", + default=None, + help="Fetch events up to this time. Time can be relative values like '5mins ago', 'tomorrow' or " + "formatted timestamp like '2018-01-01 10:10:10'", + ), + click.option( + "--tail", + "-t", + is_flag=True, + help="Tail events. This will ignore the end time argument and continue to fetch events as they " + "become available. [Beta Feature] If in beta, --tail without a --name will pull from all possible resources", + ), + click.option( + "--unformatted", + "-u", + is_flag=True, + help="[Beta Feature] " + "Print events without any text formatting in JSON. This option might be useful if you are reading " + "output into another tool.", + ), + ] + + +def common_observability_options(f): + for option in common_observability_click_options(): + option(f) + + return f + + def metadata_click_option(): return click.option( "--metadata", @@ -304,31 +394,33 @@ def metadata_click_option(): ) -def metadata_override_option(f): +def metadata_option(f): return metadata_click_option()(f) -def capabilities_click_option(): +def capabilities_click_option(default): return click.option( "--capabilities", cls=OptionNargs, required=False, + default=default, type=FuncParamType(func=_space_separated_list_func_type), - help="A list of capabilities that you must specify" - "before AWS Cloudformation can create certain stacks. Some stack tem-" - "plates might include resources that can affect permissions in your AWS" - "account, for example, by creating new AWS Identity and Access Manage-" - "ment (IAM) users. For those stacks, you must explicitly acknowledge" - "their capabilities by specifying this parameter. The only valid values" - "are CAPABILITY_IAM and CAPABILITY_NAMED_IAM. If you have IAM resources," - "you can specify either capability. If you have IAM resources with cus-" - "tom names, you must specify CAPABILITY_NAMED_IAM. If you don't specify" - "this parameter, this action returns an InsufficientCapabilities error.", + help="A list of capabilities that you must specify " + "before AWS CloudFormation can create certain stacks. Some stack templates " + "might include resources that can affect permissions in your AWS " + "account, for example, by creating new AWS Identity and Access Management " + "(IAM) users. For those stacks, you must explicitly acknowledge " + "their capabilities by specifying this parameter. The only valid values " + "are CAPABILITY_IAM and CAPABILITY_NAMED_IAM.
If you have IAM resources, " + "you can specify either capability. If you have IAM resources with custom " + "names, you must specify CAPABILITY_NAMED_IAM. If you don't specify " + "this parameter, this action returns an InsufficientCapabilities error.", ) -def capabilities_override_option(f): - return capabilities_click_option()(f) +@parameterized_option +def capabilities_option(f, default=None): + return capabilities_click_option(default)(f) def tags_click_option(): @@ -343,7 +435,7 @@ def tags_click_option(): ) -def tags_override_option(f): +def tags_option(f): return tags_click_option()(f) @@ -359,10 +451,263 @@ def notification_arns_click_option(): ) -def notification_arns_override_option(f): +def notification_arns_option(f): return notification_arns_click_option()(f) +def stack_name_click_option(required): + return click.option( + "--stack-name", + required=required, + help="The name of the AWS CloudFormation stack you're deploying to. " + "If you specify an existing stack, the command updates the stack. " + "If you specify a new stack, the command creates it.", + ) + + +@parameterized_option +def stack_name_option(f, required=False): + return stack_name_click_option(required)(f) + + +def s3_bucket_click_option(guided): + callback = None if guided else partial(artifact_callback, artifact=ZIP) + return click.option( + "--s3-bucket", + required=False, + callback=callback, + help="The name of the S3 bucket where this command uploads the artifacts that are referenced in your template.", + ) + + +@parameterized_option +def s3_bucket_option(f, guided=False): + return s3_bucket_click_option(guided)(f) + + +def build_dir_click_option(): + return click.option( + "--build-dir", + "-b", + default=DEFAULT_BUILD_DIR, + type=click.Path(file_okay=False, dir_okay=True, writable=True), # Must be a directory + help="Path to a folder where the built artifacts will be stored. " + "This directory will first be removed before starting a build.", + ) + + +def build_dir_option(f): + return build_dir_click_option()(f) + + +def cache_dir_click_option(): + return click.option( + "--cache-dir", + "-cd", + default=DEFAULT_CACHE_DIR, + type=click.Path(file_okay=False, dir_okay=True, writable=True), # Must be a directory + help="The folder where the cache artifacts will be stored when --cached is specified. " + "The default cache directory is .aws-sam/cache", + ) + + +def cache_dir_option(f): + return cache_dir_click_option()(f) + + +def base_dir_click_option(): + return click.option( + "--base-dir", + "-s", + default=None, + type=click.Path(dir_okay=True, file_okay=False), # Must be a directory + help="Resolve relative paths to the function's source code with respect to this folder. Use this if " + "the SAM template and your source code are not in the same enclosing folder. By default, relative paths " + "are resolved with respect to the SAM template's location", + ) + + +def base_dir_option(f): + return base_dir_click_option()(f) + + +def manifest_click_option(): + return click.option( + "--manifest", + "-m", + default=None, + type=click.Path(), + help="Path to a custom dependency manifest (e.g., package.json) to use instead of the default one", + ) + + +def manifest_option(f): + return manifest_click_option()(f) + + +def cached_click_option(): + return click.option( + "--cached", + "-c", + is_flag=True, + help="Enable cached builds. Use this flag to reuse build artifacts that have not changed from previous builds. " + "AWS SAM evaluates whether you have made any changes to files in your project directory. 
\n\n" + "Note: AWS SAM does not evaluate whether changes have been made to third party modules " + "that your project depends on, where you have not provided a specific version. " + "For example, if your Python function includes a requirements.txt file with the following entry " + "requests=1.x and the latest request module version changes from 1.1 to 1.2, " + "SAM will not pull the latest version until you run a non-cached build.", + ) + + +def cached_option(f): + return cached_click_option()(f) + + +def image_repository_click_option(): + return click.option( + "--image-repository", + callback=partial(artifact_callback, artifact=IMAGE), + type=ImageRepositoryType(), + required=False, + help="ECR repo uri where this command uploads the image artifacts that are referenced in your template.", + ) + + +def image_repository_option(f): + return image_repository_click_option()(f) + + +def image_repositories_click_option(): + return click.option( + "--image-repositories", + multiple=True, + callback=image_repositories_callback, + type=ImageRepositoriesType(), + required=False, + help="Specify mapping of Function Logical ID to ECR Repo uri, of the form Function_Logical_ID=ECR_Repo_Uri." + "This option can be specified multiple times.", + ) + + +def image_repositories_option(f): + return image_repositories_click_option()(f) + + +def s3_prefix_click_option(): + return click.option( + "--s3-prefix", + required=False, + help="A prefix name that the command adds to the artifacts " + "name when it uploads them to the S3 bucket. The prefix name is a " + "path name (folder name) for the S3 bucket.", + ) + + +def s3_prefix_option(f): + return s3_prefix_click_option()(f) + + +def kms_key_id_click_option(): + return click.option( + "--kms-key-id", + required=False, + help="The ID of an AWS KMS key that the command uses to encrypt artifacts that are at rest in the S3 bucket.", + ) + + +def kms_key_id_option(f): + return kms_key_id_click_option()(f) + + +def use_json_click_option(): + return click.option( + "--use-json", + required=False, + is_flag=True, + help="Indicates whether to use JSON as the format for " + "the output AWS CloudFormation template. YAML is used by default.", + ) + + +def use_json_option(f): + return use_json_click_option()(f) + + +def force_upload_click_option(): + return click.option( + "--force-upload", + required=False, + is_flag=True, + help="Indicates whether to override existing files " + "in the S3 bucket. Specify this flag to upload artifacts even if they " + "match existing artifacts in the S3 bucket.", + ) + + +def force_upload_option(f): + return force_upload_click_option()(f) + + +def resolve_s3_click_option(guided): + from samcli.commands.package.exceptions import PackageResolveS3AndS3SetError, PackageResolveS3AndS3NotSetError + + callback = ( + None + if guided + else partial( + resolve_s3_callback, + artifact=ZIP, + exc_set=PackageResolveS3AndS3SetError, + exc_not_set=PackageResolveS3AndS3NotSetError, + ) + ) + return click.option( + "--resolve-s3", + required=False, + is_flag=True, + callback=callback, + help="Automatically resolve s3 bucket for non-guided deployments. " + "Enabling this option will also create a managed default s3 bucket for you. " + "If you do not provide a --s3-bucket value, the managed bucket will be used. 
" + "Do not use --s3-guided parameter with this option.", + ) + + +@parameterized_option +def resolve_s3_option(f, guided=False): + return resolve_s3_click_option(guided)(f) + + +def role_arn_click_option(): + return click.option( + "--role-arn", + required=False, + help="The Amazon Resource Name (ARN) of an AWS Identity " + "and Access Management (IAM) role that AWS CloudFormation assumes when " + "executing the change set.", + ) + + +def role_arn_option(f): + return role_arn_click_option()(f) + + +def resolve_image_repos_click_option(): + return click.option( + "--resolve-image-repos", + required=False, + is_flag=True, + help="Automatically create and delete ECR repositories for image-based functions in non-guided deployments. " + "A companion stack containing ECR repos for each function will be deployed along with the template stack. " + "Automatically created image repositories will be deleted if the corresponding functions are removed.", + ) + + +def resolve_image_repos_option(f): + return resolve_image_repos_click_option()(f) + + def _space_separated_list_func_type(value): if isinstance(value, str): return value.split(" ") diff --git a/samcli/commands/_utils/template.py b/samcli/commands/_utils/template.py index 08c02836da..61b1c497e9 100644 --- a/samcli/commands/_utils/template.py +++ b/samcli/commands/_utils/template.py @@ -9,16 +9,16 @@ import yaml from botocore.utils import set_value_from_jmespath -from samcli.commands._utils.resources import ( +from samcli.commands.exceptions import UserException +from samcli.lib.utils.packagetype import ZIP +from samcli.yamlhelper import yaml_parse, yaml_dump +from samcli.lib.utils.resources import ( METADATA_WITH_LOCAL_PATHS, RESOURCES_WITH_LOCAL_PATHS, AWS_SERVERLESS_FUNCTION, AWS_LAMBDA_FUNCTION, get_packageable_resource_paths, ) -from samcli.commands.exceptions import UserException -from samcli.lib.utils.packagetype import ZIP -from samcli.yamlhelper import yaml_parse, yaml_dump class TemplateNotFoundException(UserException): diff --git a/samcli/commands/build/build_context.py b/samcli/commands/build/build_context.py index fbaa02ca35..cd9757b1da 100644 --- a/samcli/commands/build/build_context.py +++ b/samcli/commands/build/build_context.py @@ -8,23 +8,39 @@ import shutil from typing import Dict, Optional, List +import click + from samcli.commands.build.exceptions import InvalidBuildDirException, MissingBuildMethodException +from samcli.lib.bootstrap.nested_stack.nested_stack_manager import NestedStackManager +from samcli.lib.build.build_graph import DEFAULT_DEPENDENCIES_DIR from samcli.lib.intrinsic_resolver.intrinsics_symbol_table import IntrinsicsSymbolTable from samcli.lib.providers.provider import ResourcesToBuildCollector, Stack, Function, LayerVersion from samcli.lib.providers.sam_function_provider import SamFunctionProvider from samcli.lib.providers.sam_layer_provider import SamLayerProvider from samcli.lib.providers.sam_stack_provider import SamLocalStackProvider +from samcli.lib.utils.osutils import BUILD_DIR_PERMISSIONS from samcli.local.docker.manager import ContainerManager from samcli.local.lambdafn.exceptions import ResourceNotFound +from samcli.lib.build.exceptions import BuildInsideContainerError + +from samcli.commands.exceptions import UserException + +from samcli.lib.build.app_builder import ( + ApplicationBuilder, + BuildError, + UnsupportedBuilderLibraryVersionError, + ContainerBuildNotSupported, +) +from samcli.commands._utils.options import DEFAULT_BUILD_DIR +from samcli.lib.build.workflow_config import 
UnsupportedRuntimeException +from samcli.local.lambdafn.exceptions import FunctionNotFound +from samcli.commands._utils.template import move_template +from samcli.lib.build.exceptions import InvalidBuildGraphException LOG = logging.getLogger(__name__) class BuildContext: - # Build directories need not be world writable. - # This is usually a optimal permission for directories - _BUILD_DIR_PERMISSIONS = 0o755 - def __init__( self, resource_identifier: Optional[str], @@ -33,6 +49,7 @@ def __init__( build_dir: str, cache_dir: str, cached: bool, + parallel: bool, mode: Optional[str], manifest_path: Optional[str] = None, clean: bool = False, @@ -47,6 +64,8 @@ def __init__( container_env_var_file: Optional[str] = None, build_images: Optional[dict] = None, aws_region: Optional[str] = None, + create_auto_dependency_layer: bool = False, + stack_name: Optional[str] = None, ) -> None: self._resource_identifier = resource_identifier @@ -58,6 +77,7 @@ def __init__( self._build_dir = build_dir self._cache_dir = cache_dir + self._parallel = parallel self._manifest_path = manifest_path self._clean = clean self._use_container = use_container @@ -73,6 +93,8 @@ def __init__( self._container_env_var = container_env_var self._container_env_var_file = container_env_var_file self._build_images = build_images + self._create_auto_dependency_layer = create_auto_dependency_layer + self._stack_name = stack_name self._function_provider: Optional[SamFunctionProvider] = None self._layer_provider: Optional[SamLayerProvider] = None @@ -80,7 +102,12 @@ def __init__( self._stacks: List[Stack] = [] def __enter__(self) -> "BuildContext": + self.set_up() + return self + def set_up(self) -> None: + """Set up class members used for building + This should be called each time before run() if stacks are changed.""" self._stacks, remote_stack_full_paths = SamLocalStackProvider.get_stacks( self._template_file, parameter_overrides=self._parameter_overrides, @@ -108,19 +135,130 @@ def __enter__(self) -> "BuildContext": if self._cached: cache_path = pathlib.Path(self._cache_dir) - cache_path.mkdir(mode=self._BUILD_DIR_PERMISSIONS, parents=True, exist_ok=True) + cache_path.mkdir(mode=BUILD_DIR_PERMISSIONS, parents=True, exist_ok=True) self._cache_dir = str(cache_path.resolve()) + dependencies_path = pathlib.Path(DEFAULT_DEPENDENCIES_DIR) + dependencies_path.mkdir(mode=BUILD_DIR_PERMISSIONS, parents=True, exist_ok=True) if self._use_container: self._container_manager = ContainerManager( docker_network_id=self._docker_network, skip_pull_image=self._skip_pull_image ) - return self - def __exit__(self, *args): pass + def get_resources_to_build(self): + return self.resources_to_build + + def run(self): + """Runs the building process by creating an ApplicationBuilder.""" + try: + builder = ApplicationBuilder( + self.get_resources_to_build(), + self.build_dir, + self.base_dir, + self.cache_dir, + self.cached, + self.is_building_specific_resource, + manifest_path_override=self.manifest_path_override, + container_manager=self.container_manager, + mode=self.mode, + parallel=self._parallel, + container_env_var=self._container_env_var, + container_env_var_file=self._container_env_var_file, + build_images=self._build_images, + combine_dependencies=not self._create_auto_dependency_layer, + ) + except FunctionNotFound as ex: + raise UserException(str(ex), wrapped_from=ex.__class__.__name__) from ex + + try: + build_result = builder.build() + artifacts = build_result.artifacts + + stack_output_template_path_by_stack_path = { + stack.stack_path: 
stack.get_output_template_path(self.build_dir) for stack in self.stacks + } + for stack in self.stacks: + modified_template = builder.update_template( + stack, + artifacts, + stack_output_template_path_by_stack_path, + ) + output_template_path = stack.get_output_template_path(self.build_dir) + + if self._create_auto_dependency_layer: + LOG.debug("Auto creating dependency layer for each function resource into a nested stack") + nested_stack_manager = NestedStackManager( + self._stack_name, self.build_dir, stack.location, modified_template, build_result + ) + modified_template = nested_stack_manager.generate_auto_dependency_layer_stack() + move_template(stack.location, output_template_path, modified_template) + + click.secho("\nBuild Succeeded", fg="green") + + # try to use relpath so the command is easier to understand, however, + # under Windows, when SAM and (build_dir or output_template_path) are + # on different drive, relpath() fails. + root_stack = SamLocalStackProvider.find_root_stack(self.stacks) + out_template_path = root_stack.get_output_template_path(self.build_dir) + try: + build_dir_in_success_message = os.path.relpath(self.build_dir) + output_template_path_in_success_message = os.path.relpath(out_template_path) + except ValueError: + LOG.debug("Failed to retrieve relpath - using the specified path as-is instead") + build_dir_in_success_message = self.build_dir + output_template_path_in_success_message = out_template_path + + msg = self.gen_success_msg( + build_dir_in_success_message, + output_template_path_in_success_message, + os.path.abspath(self.build_dir) == os.path.abspath(DEFAULT_BUILD_DIR), + ) + + click.secho(msg, fg="yellow") + + except ( + UnsupportedRuntimeException, + BuildError, + BuildInsideContainerError, + UnsupportedBuilderLibraryVersionError, + ContainerBuildNotSupported, + InvalidBuildGraphException, + ) as ex: + click.secho("\nBuild Failed", fg="red") + + # Some Exceptions have a deeper wrapped exception that needs to be surfaced + # from deeper than just one level down. + deep_wrap = getattr(ex, "wrapped_from", None) + wrapped_from = deep_wrap if deep_wrap else ex.__class__.__name__ + raise UserException(str(ex), wrapped_from=wrapped_from) from ex + + @staticmethod + def gen_success_msg(artifacts_dir: str, output_template_path: str, is_default_build_dir: bool) -> str: + + invoke_cmd = "sam local invoke" + if not is_default_build_dir: + invoke_cmd += " -t {}".format(output_template_path) + + deploy_cmd = "sam deploy --guided" + if not is_default_build_dir: + deploy_cmd += " --template-file {}".format(output_template_path) + + msg = """\nBuilt Artifacts : {artifacts_dir} +Built Template : {template} + +Commands you can use next +========================= +[*] Invoke Function: {invokecmd} +[*] Deploy: {deploycmd} + """.format( + invokecmd=invoke_cmd, deploycmd=deploy_cmd, artifacts_dir=artifacts_dir, template=output_template_path + ) + + return msg + @staticmethod def _setup_build_dir(build_dir: str, clean: bool) -> str: build_path = pathlib.Path(build_dir) @@ -138,7 +276,7 @@ def _setup_build_dir(build_dir: str, clean: bool) -> str: # build folder contains something inside. Clear everything. 
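# note: rmtree removes the build directory itself as well; it is re-created immediately below with BUILD_DIR_PERMISSIONS (0o755), keeping the build directory non-world-writable 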
shutil.rmtree(build_dir) - build_path.mkdir(mode=BuildContext._BUILD_DIR_PERMISSIONS, parents=True, exist_ok=True) + build_path.mkdir(mode=BUILD_DIR_PERMISSIONS, parents=True, exist_ok=True) # ensure path resolving is done after creation: https://bugs.python.org/issue32434 return str(build_path.resolve()) @@ -205,21 +343,61 @@ def resources_to_build(self) -> ResourcesToBuildCollector: ------- ResourcesToBuildCollector """ + return ( + self.collect_build_resources(self._resource_identifier) + if self._resource_identifier + else self.collect_all_build_resources() + ) + + @property + def create_auto_dependency_layer(self) -> bool: + return self._create_auto_dependency_layer + + def collect_build_resources(self, resource_identifier: str) -> ResourcesToBuildCollector: + """Collect a single buildable resource and its dependencies. + For a Lambda function, its layers will be included. + + Parameters + ---------- + resource_identifier : str + Resource identifier for the resource to be built + + Returns + ------- + ResourcesToBuildCollector + ResourcesToBuildCollector containing the buildable resource and its dependencies + + Raises + ------ + ResourceNotFound + Raised if the specified resource cannot be found. + """ result = ResourcesToBuildCollector() - if self._resource_identifier: - self._collect_single_function_and_dependent_layers(self._resource_identifier, result) - self._collect_single_buildable_layer(self._resource_identifier, result) + # Get the function and its dependent layers. Skips the function if its code is inline. + self._collect_single_function_and_dependent_layers(resource_identifier, result) + self._collect_single_buildable_layer(resource_identifier, result) - if not result.functions and not result.layers: - all_resources = [f.name for f in self.function_provider.get_all() if not f.inlinecode] - all_resources.extend([l.name for l in self.layer_provider.get_all()]) + if not result.functions and not result.layers: + # List all non-inline functions and all layers as the available options + all_resources = [f.name for f in self.function_provider.get_all() if not f.inlinecode] + all_resources.extend([l.name for l in self.layer_provider.get_all()]) - available_resource_message = ( - f"{self._resource_identifier} not found. Possible options in your " f"template: {all_resources}" - ) - LOG.info(available_resource_message) - raise ResourceNotFound(f"Unable to find a function or layer with name '{self._resource_identifier}'") - return result + available_resource_message = ( + f"{resource_identifier} not found. Possible options in your " f"template: {all_resources}" + ) + LOG.info(available_resource_message) + raise ResourceNotFound(f"Unable to find a function or layer with name '{resource_identifier}'") + return result + + def collect_all_build_resources(self) -> ResourcesToBuildCollector: + """Collect all buildable resources, including Lambda functions and layers. + + Returns + ------- + ResourcesToBuildCollector + ResourcesToBuildCollector that contains all the buildable resources. 
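+        Only resources that pass the BuildContext._is_function_buildable and _is_layer_buildable + checks below are included (for example, functions whose code is inline are skipped). 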
+ """ + result = ResourcesToBuildCollector() result.add_functions([f for f in self.function_provider.get_all() if BuildContext._is_function_buildable(f)]) result.add_layers([l for l in self.layer_provider.get_all() if BuildContext._is_layer_buildable(l)]) return result diff --git a/samcli/commands/build/command.py b/samcli/commands/build/command.py index 5c450d6589..0153a03319 100644 --- a/samcli/commands/build/command.py +++ b/samcli/commands/build/command.py @@ -8,14 +8,18 @@ import click from samcli.cli.context import Context +from samcli.commands._utils.experimental import experimental from samcli.commands._utils.options import ( template_option_without_build, docker_common_options, parameter_override_option, + build_dir_option, + cache_dir_option, + base_dir_option, + manifest_option, + cached_option, ) from samcli.cli.main import pass_context, common_options as cli_framework_options, aws_creds_options, print_cmdline_args -from samcli.lib.build.exceptions import BuildInsideContainerError -from samcli.lib.providers.sam_stack_provider import SamLocalStackProvider from samcli.lib.telemetry.metric import track_command from samcli.cli.cli_config_file import configuration_option, TomlProvider from samcli.lib.utils.version_checker import check_newer_version @@ -24,8 +28,6 @@ LOG = logging.getLogger(__name__) -DEFAULT_BUILD_DIR = os.path.join(".aws-sam", "build") -DEFAULT_CACHE_DIR = os.path.join(".aws-sam", "cache") HELP_TEXT = """ Use this command to build your AWS Lambda Functions source code to generate artifacts that target AWS Lambda's @@ -77,31 +79,6 @@ @click.command("build", help=HELP_TEXT, short_help="Build your Lambda function code") @configuration_option(provider=TomlProvider(section="parameters")) -@click.option( - "--build-dir", - "-b", - default=DEFAULT_BUILD_DIR, - type=click.Path(file_okay=False, dir_okay=True, writable=True), # Must be a directory - help="Path to a folder where the built artifacts will be stored. " - "This directory will be first removed before starting a build.", -) -@click.option( - "--cache-dir", - "-cd", - default=DEFAULT_CACHE_DIR, - type=click.Path(file_okay=False, dir_okay=True, writable=True), # Must be a directory - help="The folder where the cache artifacts will be stored when --cached is specified. " - "The default cache directory is .aws-sam/cache", -) -@click.option( - "--base-dir", - "-s", - default=None, - type=click.Path(dir_okay=True, file_okay=False), # Must be a directory - help="Resolve relative paths to function's source code with respect to this folder. Use this if " - "SAM template and your source code are not in same enclosing folder. By default, relative paths " - "are resolved with respect to the SAM template's location", -) @click.option( "--use-container", "-u", @@ -150,28 +127,15 @@ help="Enabled parallel builds. Use this flag to build your AWS SAM template's functions and layers in parallel. " "By default the functions and layers are built in sequence", ) -@click.option( - "--manifest", - "-m", - default=None, - type=click.Path(), - help="Path to a custom dependency manifest (e.g., package.json) to use instead of the default one", -) -@click.option( - "--cached", - "-c", - is_flag=True, - help="Enable cached builds. Use this flag to reuse build artifacts that have not changed from previous builds. " - "AWS SAM evaluates whether you have made any changes to files in your project directory. 
\n\n" - "Note: AWS SAM does not evaluate whether changes have been made to third party modules " - "that your project depends on, where you have not provided a specific version. " - "For example, if your Python function includes a requirements.txt file with the following entry " - "requests=1.x and the latest request module version changes from 1.1 to 1.2, " - "SAM will not pull the latest version until you run a non-cached build.", -) +@build_dir_option +@cache_dir_option +@base_dir_option +@manifest_option +@cached_option @template_option_without_build @parameter_override_option @docker_common_options +@experimental @cli_framework_options @aws_creds_options @click.argument("resource_logical_id", required=False) @@ -253,19 +217,7 @@ def do_cli( # pylint: disable=too-many-locals, too-many-statements Implementation of the ``cli`` method """ - from samcli.commands.exceptions import UserException - from samcli.commands.build.build_context import BuildContext - from samcli.lib.build.app_builder import ( - ApplicationBuilder, - BuildError, - UnsupportedBuilderLibraryVersionError, - ContainerBuildNotSupported, - ) - from samcli.lib.build.workflow_config import UnsupportedRuntimeException - from samcli.local.lambdafn.exceptions import FunctionNotFound - from samcli.commands._utils.template import move_template - from samcli.lib.build.build_graph import InvalidBuildGraphException LOG.debug("'build' command is called") if cached: @@ -283,6 +235,7 @@ def do_cli( # pylint: disable=too-many-locals, too-many-statements build_dir, cache_dir, cached, + parallel=parallel, clean=clean, manifest_path=manifest_path, use_container=use_container, @@ -295,100 +248,7 @@ def do_cli( # pylint: disable=too-many-locals, too-many-statements build_images=processed_build_images, aws_region=click_ctx.region, ) as ctx: - try: - builder = ApplicationBuilder( - ctx.resources_to_build, - ctx.build_dir, - ctx.base_dir, - ctx.cache_dir, - ctx.cached, - ctx.is_building_specific_resource, - manifest_path_override=ctx.manifest_path_override, - container_manager=ctx.container_manager, - mode=ctx.mode, - parallel=parallel, - container_env_var=processed_env_vars, - container_env_var_file=container_env_var_file, - build_images=processed_build_images, - ) - except FunctionNotFound as ex: - raise UserException(str(ex), wrapped_from=ex.__class__.__name__) from ex - - try: - artifacts = builder.build() - stack_output_template_path_by_stack_path = { - stack.stack_path: stack.get_output_template_path(ctx.build_dir) for stack in ctx.stacks - } - for stack in ctx.stacks: - modified_template = builder.update_template( - stack, - artifacts, - stack_output_template_path_by_stack_path, - ) - move_template(stack.location, stack.get_output_template_path(ctx.build_dir), modified_template) - - click.secho("\nBuild Succeeded", fg="green") - - # try to use relpath so the command is easier to understand, however, - # under Windows, when SAM and (build_dir or output_template_path) are - # on different drive, relpath() fails. 
- root_stack = SamLocalStackProvider.find_root_stack(ctx.stacks) - out_template_path = root_stack.get_output_template_path(ctx.build_dir) - try: - build_dir_in_success_message = os.path.relpath(ctx.build_dir) - output_template_path_in_success_message = os.path.relpath(out_template_path) - except ValueError: - LOG.debug("Failed to retrieve relpath - using the specified path as-is instead") - build_dir_in_success_message = ctx.build_dir - output_template_path_in_success_message = out_template_path - - msg = gen_success_msg( - build_dir_in_success_message, - output_template_path_in_success_message, - os.path.abspath(ctx.build_dir) == os.path.abspath(DEFAULT_BUILD_DIR), - ) - - click.secho(msg, fg="yellow") - - except ( - UnsupportedRuntimeException, - BuildError, - BuildInsideContainerError, - UnsupportedBuilderLibraryVersionError, - ContainerBuildNotSupported, - InvalidBuildGraphException, - ) as ex: - click.secho("\nBuild Failed", fg="red") - - # Some Exceptions have a deeper wrapped exception that needs to be surfaced - # from deeper than just one level down. - deep_wrap = getattr(ex, "wrapped_from", None) - wrapped_from = deep_wrap if deep_wrap else ex.__class__.__name__ - raise UserException(str(ex), wrapped_from=wrapped_from) from ex - - -def gen_success_msg(artifacts_dir: str, output_template_path: str, is_default_build_dir: bool) -> str: - - invoke_cmd = "sam local invoke" - if not is_default_build_dir: - invoke_cmd += " -t {}".format(output_template_path) - - deploy_cmd = "sam deploy --guided" - if not is_default_build_dir: - deploy_cmd += " --template-file {}".format(output_template_path) - - msg = """\nBuilt Artifacts : {artifacts_dir} -Built Template : {template} - -Commands you can use next -========================= -[*] Invoke Function: {invokecmd} -[*] Deploy: {deploycmd} - """.format( - invokecmd=invoke_cmd, deploycmd=deploy_cmd, artifacts_dir=artifacts_dir, template=output_template_path - ) - - return msg + ctx.run() def _get_mode_value_from_envvar(name: str, choices: List[str]) -> Optional[str]: diff --git a/samcli/commands/delete/delete_context.py b/samcli/commands/delete/delete_context.py index ad29ce9c04..f228580fd1 100644 --- a/samcli/commands/delete/delete_context.py +++ b/samcli/commands/delete/delete_context.py @@ -12,7 +12,7 @@ from click import prompt from samcli.cli.cli_config_file import TomlProvider -from samcli.lib.utils.botoconfig import get_boto_config_with_user_agent +from samcli.lib.utils.boto_utils import get_boto_config_with_user_agent from samcli.lib.delete.cfn_utils import CfnUtils from samcli.lib.package.s3_uploader import S3Uploader diff --git a/samcli/commands/deploy/command.py b/samcli/commands/deploy/command.py index 682ec87938..784b79b506 100644 --- a/samcli/commands/deploy/command.py +++ b/samcli/commands/deploy/command.py @@ -7,18 +7,26 @@ from samcli.cli.cli_config_file import TomlProvider, configuration_option from samcli.cli.main import aws_creds_options, common_options, pass_context, print_cmdline_args -from samcli.cli.types import ImageRepositoryType, ImageRepositoriesType from samcli.commands._utils.options import ( - capabilities_override_option, - guided_deploy_stack_name, - metadata_override_option, - notification_arns_override_option, + capabilities_option, + metadata_option, + notification_arns_option, parameter_override_option, no_progressbar_option, - tags_override_option, + tags_option, template_click_option, signing_profiles_option, - image_repositories_callback, + stack_name_option, + s3_bucket_option, + image_repository_option, 
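+    # note: these shared option decorators are defined once in options.py; "sam package" and "sam sync" import the same set 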
+ image_repositories_option, + s3_prefix_option, + kms_key_id_option, + use_json_option, + force_upload_option, + resolve_s3_option, + role_arn_option, + resolve_image_repos_option, ) from samcli.commands.deploy.utils import sanitize_parameter_overrides from samcli.lib.telemetry.metric import track_command @@ -59,56 +67,6 @@ help="Specify this flag to allow SAM CLI to guide you through the deployment using guided prompts.", ) @template_click_option(include_build=True) -@click.option( - "--stack-name", - required=False, - callback=guided_deploy_stack_name, - help="The name of the AWS CloudFormation stack you're deploying to. " - "If you specify an existing stack, the command updates the stack. " - "If you specify a new stack, the command creates it.", -) -@click.option( - "--s3-bucket", - required=False, - help="The name of the S3 bucket where this command uploads your " - "CloudFormation template. This is required the deployments of " - "templates sized greater than 51,200 bytes", -) -@click.option( - "--image-repository", - type=ImageRepositoryType(), - required=False, - help="ECR repo uri where this command uploads the image artifacts that are referenced in your template.", -) -@click.option( - "--image-repositories", - multiple=True, - callback=image_repositories_callback, - type=ImageRepositoriesType(), - required=False, - help="Specify mapping of Function Logical ID to ECR Repo uri, of the form Function_Logical_ID=ECR_Repo_Uri." - "This option can be specified multiple times.", -) -@click.option( - "--force-upload", - required=False, - is_flag=True, - help="Indicates whether to override existing files in the S3 bucket. " - "Specify this flag to upload artifacts even if they " - "match existing artifacts in the S3 bucket.", -) -@click.option( - "--s3-prefix", - required=False, - help="A prefix name that the command adds to the " - "artifacts' name when it uploads them to the S3 bucket. " - "The prefix name is a path name (folder name) for the S3 bucket.", -) -@click.option( - "--kms-key-id", - required=False, - help="The ID of an AWS KMS key that the command uses to encrypt artifacts that are at rest in the S3 bucket.", -) @click.option( "--no-execute-changeset", required=False, @@ -120,13 +78,6 @@ "the changeset looks satisfactory, the stack changes can be made by " "running the same command without specifying `--no-execute-changeset`", ) -@click.option( - "--role-arn", - required=False, - help="The Amazon Resource Name (ARN) of an AWS Identity " - "and Access Management (IAM) role that AWS CloudFormation assumes when " - "executing the change set.", -) @click.option( "--fail-on-empty-changeset/--no-fail-on-empty-changeset", default=True, @@ -143,30 +94,6 @@ is_flag=True, help="Prompt to confirm if the computed changeset is to be deployed by SAM CLI.", ) -@click.option( - "--use-json", - required=False, - is_flag=True, - help="Indicates whether to use JSON as the format for " - "the output AWS CloudFormation template. YAML is used by default.", -) -@click.option( - "--resolve-s3", - required=False, - is_flag=True, - help="Automatically resolve s3 bucket for non-guided deployments. " - "Enabling this option will also create a managed default s3 bucket for you. " - "If you do not provide a --s3-bucket value, the managed bucket will be used. " - "Do not use --s3-guided parameter with this option.", -) -@click.option( - "--resolve-image-repos", - required=False, - is_flag=True, - help="Automatically create and delete ECR repositories for image-based functions in non-guided deployments. 
" - "A companion stack containing ECR repos for each function will be deployed along with the template stack. " - "Automatically created image repositories will be deleted if the corresponding functions are removed.", -) @click.option( "--disable-rollback/--no-disable-rollback", default=False, @@ -174,13 +101,24 @@ is_flag=True, help="Preserves the state of previously provisioned resources when an operation fails.", ) -@metadata_override_option -@notification_arns_override_option -@tags_override_option +@stack_name_option +@s3_bucket_option(guided=True) # pylint: disable=E1120 +@image_repository_option +@image_repositories_option +@force_upload_option +@s3_prefix_option +@kms_key_id_option +@role_arn_option +@use_json_option +@resolve_s3_option(guided=True) # pylint: disable=E1120 +@resolve_image_repos_option +@metadata_option +@notification_arns_option +@tags_option @parameter_override_option @signing_profiles_option @no_progressbar_option -@capabilities_override_option +@capabilities_option @aws_creds_options @common_options @image_repository_validation @@ -372,6 +310,7 @@ def do_cli( profile=profile, confirm_changeset=guided_context.confirm_changeset if guided else confirm_changeset, signing_profiles=guided_context.signing_profiles if guided else signing_profiles, + use_changeset=True, disable_rollback=guided_context.disable_rollback if guided else disable_rollback, ) as deploy_context: deploy_context.run() diff --git a/samcli/commands/deploy/deploy_context.py b/samcli/commands/deploy/deploy_context.py index 5ed3d83eb4..554ff00de6 100644 --- a/samcli/commands/deploy/deploy_context.py +++ b/samcli/commands/deploy/deploy_context.py @@ -33,7 +33,7 @@ from samcli.lib.intrinsic_resolver.intrinsics_symbol_table import IntrinsicsSymbolTable from samcli.lib.package.s3_uploader import S3Uploader from samcli.lib.providers.sam_stack_provider import SamLocalStackProvider -from samcli.lib.utils.botoconfig import get_boto_config_with_user_agent +from samcli.lib.utils.boto_utils import get_boto_config_with_user_agent from samcli.yamlhelper import yaml_parse LOG = logging.getLogger(__name__) @@ -70,6 +70,7 @@ def __init__( profile, confirm_changeset, signing_profiles, + use_changeset, disable_rollback, ): self.template_file = template_file @@ -98,6 +99,7 @@ def __init__( self.deployer = None self.confirm_changeset = confirm_changeset self.signing_profiles = signing_profiles + self.use_changeset = use_changeset self.disable_rollback = disable_rollback def __enter__(self): @@ -153,6 +155,7 @@ def run(self): display_parameter_overrides, self.confirm_changeset, self.signing_profiles, + self.use_changeset, self.disable_rollback, ) return self.deploy( @@ -168,6 +171,7 @@ def run(self): region, self.fail_on_empty_changeset, self.confirm_changeset, + self.use_changeset, self.disable_rollback, ) @@ -185,6 +189,7 @@ def deploy( region, fail_on_empty_changeset=True, confirm_changeset=False, + use_changeset=True, disable_rollback=False, ): """ @@ -218,6 +223,8 @@ def deploy( Should fail when changeset is empty confirm_changeset : bool Should wait for customer's confirm before executing the changeset + use_changeset : bool + Involve creation of changesets, false when using sam sync disable_rollback : bool Preserves the state of previously provisioned resources when an operation fails """ @@ -232,36 +239,55 @@ def deploy( if not authorization_required: click.secho(f"{resource} may not have authorization defined.", fg="yellow") - try: - result, changeset_type = self.deployer.create_and_wait_for_changeset( - 
stack_name=stack_name, - cfn_template=template_str, - parameter_values=parameters, - capabilities=capabilities, - role_arn=role_arn, - notification_arns=notification_arns, - s3_uploader=s3_uploader, - tags=tags, - ) - click.echo(self.MSG_SHOWCASE_CHANGESET.format(changeset_id=result["Id"])) - - if no_execute_changeset: - return - - if confirm_changeset: - click.secho(self.MSG_CONFIRM_CHANGESET_HEADER, fg="yellow") - click.secho("=" * len(self.MSG_CONFIRM_CHANGESET_HEADER), fg="yellow") - if not click.confirm(f"{self.MSG_CONFIRM_CHANGESET}", default=False): + if use_changeset: + try: + result, changeset_type = self.deployer.create_and_wait_for_changeset( + stack_name=stack_name, + cfn_template=template_str, + parameter_values=parameters, + capabilities=capabilities, + role_arn=role_arn, + notification_arns=notification_arns, + s3_uploader=s3_uploader, + tags=tags, + ) + click.echo(self.MSG_SHOWCASE_CHANGESET.format(changeset_id=result["Id"])) + + if no_execute_changeset: return - self.deployer.execute_changeset(result["Id"], stack_name, disable_rollback) - self.deployer.wait_for_execute(stack_name, changeset_type, disable_rollback) - click.echo(self.MSG_EXECUTE_SUCCESS.format(stack_name=stack_name, region=region)) - - except deploy_exceptions.ChangeEmptyError as ex: - if fail_on_empty_changeset: + if confirm_changeset: + click.secho(self.MSG_CONFIRM_CHANGESET_HEADER, fg="yellow") + click.secho("=" * len(self.MSG_CONFIRM_CHANGESET_HEADER), fg="yellow") + if not click.confirm(f"{self.MSG_CONFIRM_CHANGESET}", default=False): + return + + self.deployer.execute_changeset(result["Id"], stack_name, disable_rollback) + self.deployer.wait_for_execute(stack_name, changeset_type, disable_rollback) + click.echo(self.MSG_EXECUTE_SUCCESS.format(stack_name=stack_name, region=region)) + + except deploy_exceptions.ChangeEmptyError as ex: + if fail_on_empty_changeset: + raise + LOG.error(str(ex)) + + else: + try: + result = self.deployer.sync( + stack_name=stack_name, + cfn_template=template_str, + parameter_values=parameters, + capabilities=capabilities, + role_arn=role_arn, + notification_arns=notification_arns, + s3_uploader=s3_uploader, + tags=tags, + ) + LOG.info(result) + + except deploy_exceptions.DeployFailedError as ex: + LOG.error(str(ex)) raise - click.echo(str(ex)) @staticmethod def merge_parameters(template_dict: Dict, parameter_overrides: Dict) -> List[Dict]: diff --git a/samcli/commands/deploy/utils.py b/samcli/commands/deploy/utils.py index 647572d8b1..7a424d7181 100644 --- a/samcli/commands/deploy/utils.py +++ b/samcli/commands/deploy/utils.py @@ -18,6 +18,7 @@ def print_deploy_args( parameter_overrides, confirm_changeset, signing_profiles, + use_changeset, disable_rollback, ): """ @@ -45,6 +46,7 @@ def print_deploy_args( :param parameter_overrides: Cloudformation parameter overrides to be supplied based on the stack's template :param confirm_changeset: Prompt for changeset to be confirmed before going ahead with the deploy. :param signing_profiles: Signing profile details which will be used to sign functions/layers + :param use_changeset: Flag to use or skip the usage of changesets :param disable_rollback: Preserve the state of previously provisioned resources when an operation fails. 
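+    When use_changeset is False (the sam sync code path), the "Confirm changeset" line is skipped in the summary below. 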
""" _parameters = parameter_overrides.copy() @@ -65,7 +67,8 @@ def print_deploy_args( click.secho("\n\tDeploying with following values\n\t===============================", fg="yellow") click.echo(f"\tStack name : {stack_name}") click.echo(f"\tRegion : {region}") - click.echo(f"\tConfirm changeset : {confirm_changeset}") + if use_changeset: + click.echo(f"\tConfirm changeset : {confirm_changeset}") click.echo(f"\tDisable rollback : {disable_rollback}") if image_repository: msg = "Deployment image repository : " diff --git a/samcli/commands/init/init_templates.py b/samcli/commands/init/init_templates.py index 7b85ed3d26..a498e88c82 100644 --- a/samcli/commands/init/init_templates.py +++ b/samcli/commands/init/init_templates.py @@ -10,8 +10,8 @@ from typing import Dict import click +from samcli.cli.global_config import GlobalConfig -from samcli.cli.main import global_cfg from samcli.commands.exceptions import UserException, AppTemplateUpdateException from samcli.lib.utils.git_repo import GitRepo, CloneRepoException, CloneRepoUnstableStateException from samcli.lib.utils.packagetype import IMAGE @@ -105,7 +105,7 @@ def _check_app_template(entry: Dict, app_template: str) -> bool: def init_options(self, package_type, runtime, base_image, dependency_manager): if not self._git_repo.clone_attempted: - shared_dir: Path = global_cfg.config_dir + shared_dir: Path = GlobalConfig().config_dir try: self._git_repo.clone(clone_dir=shared_dir, clone_name=APP_TEMPLATES_REPO_NAME, replace_existing=True) except CloneRepoUnstableStateException as ex: diff --git a/samcli/commands/logs/command.py b/samcli/commands/logs/command.py index 7042970a3a..8d374544fc 100644 --- a/samcli/commands/logs/command.py +++ b/samcli/commands/logs/command.py @@ -3,12 +3,20 @@ """ import logging + import click +from samcli.cli.cli_config_file import configuration_option, TomlProvider from samcli.cli.main import pass_context, common_options as cli_framework_options, aws_creds_options, print_cmdline_args +from samcli.commands._utils.options import common_observability_options from samcli.lib.telemetry.metric import track_command -from samcli.cli.cli_config_file import configuration_option, TomlProvider from samcli.lib.utils.version_checker import check_newer_version +from samcli.commands._utils.experimental import ( + ExperimentalFlag, + force_experimental_option, + experimental, + prompt_experimental, +) LOG = logging.getLogger(__name__) @@ -38,9 +46,12 @@ @click.option( "--name", "-n", - required=True, - help="Name of your AWS Lambda function. If this function is a part of a CloudFormation stack, " - "this can be the LogicalID of function resource in the CloudFormation/SAM template.", + multiple=True, + help="Name(s) of your AWS Lambda function. If this function is a part of a CloudFormation stack, " + "this can be the LogicalID of function resource in the CloudFormation/SAM template. " + "[Beta Feature] Multiple names can be provided by repeating the parameter again. " + "If it is not provided and no --cw-log-group have been given, it will scan " + "given stack and find all possible resources, and start pulling log information from them.", ) @click.option("--stack-name", default=None, help="Name of the AWS CloudFormation stack that the function is a part of.") @click.option( @@ -52,40 +63,41 @@ "https://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/FilterAndPatternSyntax.html", ) @click.option( - "--start-time", - "-s", - default="10m ago", - help="Fetch logs starting at this time. 
Time can be relative values like '5mins ago', 'yesterday' or " - "formatted timestamp like '2018-01-01 10:10:10'. Defaults to '10mins ago'.", -) -@click.option( - "--end-time", - "-e", - default=None, - help="Fetch logs up to this time. Time can be relative values like '5mins ago', 'tomorrow' or " - "formatted timestamp like '2018-01-01 10:10:10'", + "--include-traces", + "-i", + is_flag=True, + help="[Beta Feature] Include the XRay traces in the log output.", ) @click.option( - "--tail", - "-t", - is_flag=True, - help="Tail the log output. This will ignore the end time argument and continue to fetch logs as they " - "become available.", + "--cw-log-group", + multiple=True, + help="[Beta Feature] " + "Additional CloudWatch Log group names that are not auto-discovered based upon the --name parameter. " + "When provided, it will only tail the given CloudWatch Log groups. If you want to tail log groups related " + "to resources, provide their names as well.", ) +@common_observability_options +@experimental @cli_framework_options @aws_creds_options @pass_context @track_command @check_newer_version @print_cmdline_args +@force_experimental_option("include_traces", config_entry=ExperimentalFlag.Accelerate) # pylint: disable=E1120 +@force_experimental_option("cw_log_group", config_entry=ExperimentalFlag.Accelerate) # pylint: disable=E1120 +@force_experimental_option("unformatted", config_entry=ExperimentalFlag.Accelerate) # pylint: disable=E1120 def cli( ctx, name, stack_name, filter, tail, + include_traces, start_time, end_time, + unformatted, + cw_log_group, config_file, config_env, ): # pylint: disable=redefined-builtin @@ -94,30 +106,72 @@ def cli( """ # All logic must be implemented in the ``do_cli`` method. This helps with easy unit testing - do_cli(name, stack_name, filter, tail, start_time, end_time) # pragma: no cover + do_cli( + name, + stack_name, + filter, + tail, + include_traces, + start_time, + end_time, + cw_log_group, + unformatted, + ctx.region, + ctx.profile, + ) # pragma: no cover -def do_cli(function_name, stack_name, filter_pattern, tailing, start_time, end_time): +def do_cli( + names, + stack_name, + filter_pattern, + tailing, + include_tracing, + start_time, + end_time, + cw_log_groups, + unformatted, + region, + profile, +): """ Implementation of the ``cli`` method """ - from .logs_context import LogsCommandContext - - LOG.debug("'logs' command is called") - - with LogsCommandContext( - function_name, - stack_name=stack_name, - filter_pattern=filter_pattern, - start_time=start_time, - end_time=end_time, - ) as context: - - if tailing: - context.fetcher.tail(start_time=context.start_time, filter_pattern=context.filter_pattern) - else: - context.fetcher.load_time_period( - start_time=context.start_time, - end_time=context.end_time, - filter_pattern=context.filter_pattern, - ) + + from datetime import datetime + + from samcli.commands.logs.logs_context import parse_time, ResourcePhysicalIdResolver + from samcli.commands.logs.puller_factory import generate_puller + from samcli.lib.utils.boto_utils import get_boto_client_provider_with_config, get_boto_resource_provider_with_config + + if not names or len(names) > 1: + if not prompt_experimental(ExperimentalFlag.Accelerate): + return + else: + click.echo( + "You can now use 'sam logs' without --name parameter, " + "which will pull the logs from all possible resources in your stack." 
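+            # illustrative example: "sam logs --stack-name mystack --tail" (no --name) now tails every supported resource in the stack named "mystack" 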
+ ) + + sanitized_start_time = parse_time(start_time, "start-time") + sanitized_end_time = parse_time(end_time, "end-time") or datetime.utcnow() + + boto_client_provider = get_boto_client_provider_with_config(region=region, profile=profile) + boto_resource_provider = get_boto_resource_provider_with_config(region=region, profile=profile) + resource_logical_id_resolver = ResourcePhysicalIdResolver(boto_resource_provider, stack_name, names) + + # only fetch all resources when no CloudWatch log group defined + fetch_all_when_no_resource_name_given = not cw_log_groups + puller = generate_puller( + boto_client_provider, + resource_logical_id_resolver.get_resource_information(fetch_all_when_no_resource_name_given), + filter_pattern, + cw_log_groups, + unformatted, + include_tracing, + ) + + if tailing: + puller.tail(sanitized_start_time, filter_pattern) + else: + puller.load_time_period(sanitized_start_time, sanitized_end_time, filter_pattern) diff --git a/samcli/commands/logs/console_consumers.py b/samcli/commands/logs/console_consumers.py index 2f77e34ab0..9881e11725 100644 --- a/samcli/commands/logs/console_consumers.py +++ b/samcli/commands/logs/console_consumers.py @@ -13,6 +13,16 @@ class CWConsoleEventConsumer(ObservabilityEventConsumer[CWLogEvent]): Consumer implementation that will consume given event as outputting into console """ - # pylint: disable=R0201 + def __init__(self, add_newline: bool = False): + """ + + Parameters + ---------- + add_newline : bool + If it is True, it will add a new line at the end of each echo operation. Otherwise it will always print + into same line when echo is called. + """ + self._add_newline = add_newline + def consume(self, event: CWLogEvent): - click.echo(event.message, nl=False) + click.echo(event.message, nl=self._add_newline) diff --git a/samcli/commands/logs/logs_context.py b/samcli/commands/logs/logs_context.py index 5504895a70..777a7791a2 100644 --- a/samcli/commands/logs/logs_context.py +++ b/samcli/commands/logs/logs_context.py @@ -3,274 +3,138 @@ """ import logging +from typing import List, Optional, Set, Any -import boto3 -import botocore - -from samcli.commands.exceptions import UserException -from samcli.commands.logs.console_consumers import CWConsoleEventConsumer -from samcli.lib.observability.cw_logs.cw_log_formatters import ( - CWColorizeErrorsFormatter, - CWJsonFormatter, - CWKeywordHighlighterFormatter, - CWPrettyPrintFormatter, +from samcli.lib.utils.resources import ( + AWS_LAMBDA_FUNCTION, + AWS_APIGATEWAY_RESTAPI, + AWS_APIGATEWAY_V2_API, + AWS_STEPFUNCTIONS_STATEMACHINE, ) -from samcli.lib.observability.cw_logs.cw_log_group_provider import LogGroupProvider -from samcli.lib.observability.cw_logs.cw_log_puller import CWLogPuller -from samcli.lib.observability.observability_info_puller import ObservabilityEventConsumerDecorator -from samcli.lib.utils.colors import Colored +from samcli.commands.exceptions import UserException +from samcli.lib.utils.boto_utils import BotoProviderType +from samcli.lib.utils.cloudformation import get_resource_summaries from samcli.lib.utils.time import to_utc, parse_date LOG = logging.getLogger(__name__) class InvalidTimestampError(UserException): - pass - - -class LogsCommandContext: """ - Sets up a context to run the Logs command by parsing the CLI arguments and creating necessary objects to be able - to fetch and display logs - - This class **must** be used inside a ``with`` statement as follows: - - with LogsCommandContext(**kwargs) as context: - context.fetcher.fetch(...) 
+ Used to indicate that given date time string is an invalid timestamp """ - def __init__( - self, function_name, stack_name=None, filter_pattern=None, start_time=None, end_time=None, output_file=None - ): - """ - Initializes the context - - Parameters - ---------- - function_name : str - Name of the function to fetch logs for - - stack_name : str - Name of the stack where the function is available - - filter_pattern : str - Optional pattern to filter the logs by - - start_time : str - Fetch logs starting at this time - - end_time : str - Fetch logs up to this time - - output_file : str - Write logs to this file instead of Terminal - """ - - self._function_name = function_name - self._stack_name = stack_name - self._filter_pattern = filter_pattern - self._start_time = start_time - self._end_time = end_time - self._output_file = output_file - self._output_file_handle = None - - # No colors when we write to a file. Otherwise use colors - self._must_print_colors = not self._output_file - - self._logs_client = boto3.client("logs") - self._cfn_client = boto3.client("cloudformation") - - def __enter__(self): - """ - Performs some basic checks and returns itself when everything is ready to invoke a Lambda function. - - Returns - ------- - LogsCommandContext - Returns this object - """ - - self._output_file_handle = self._setup_output_file(self._output_file) - return self - - def __exit__(self, *args): - """ - Cleanup any necessary opened files - """ - - if self._output_file_handle: - self._output_file_handle.close() - self._output_file_handle = None - - @property - def fetcher(self): - return CWLogPuller( - logs_client=self._logs_client, - consumer=ObservabilityEventConsumerDecorator( - mappers=[ - CWColorizeErrorsFormatter(self.colored), - CWJsonFormatter(), - CWKeywordHighlighterFormatter(self.colored, self._filter_pattern), - CWPrettyPrintFormatter(self.colored), - ], - consumer=CWConsoleEventConsumer(), - ), - cw_log_group=self.log_group_name, - resource_name=self._function_name, - ) - - @property - def start_time(self): - return self._parse_time(self._start_time, "start-time") - - @property - def end_time(self): - return self._parse_time(self._end_time, "end-time") - - @property - def log_group_name(self): - """ - Name of the AWS CloudWatch Log Group that we will be querying. It generates the name based on the - Lambda Function name and stack name provided. - - Returns - ------- - str - Name of the CloudWatch Log Group - """ +def parse_time(time_str: str, property_name: str): + """ + Parse the time from the given string, convert to UTC, and return the datetime object - function_id = self._function_name - if self._stack_name: - function_id = self._get_resource_id_from_stack(self._cfn_client, self._stack_name, self._function_name) - LOG.debug( - "Function with LogicalId '%s' in stack '%s' resolves to actual physical ID '%s'", - self._function_name, - self._stack_name, - function_id, - ) + Parameters + ---------- + time_str : str + The time to parse - return LogGroupProvider.for_lambda_function(function_id) + property_name : str + Name of the property where this time came from. 
Used in the exception raised if time is not parseable - @property - def colored(self): - """ - Instance of Colored object to colorize strings + Returns + ------- + datetime.datetime + Parsed datetime object - Returns - ------- - samcli.commands.utils.colors.Colored - """ - # No colors if we are writing output to a file - return Colored(colorize=self._must_print_colors) + Raises + ------ + InvalidTimestampError + If the string cannot be parsed as a timestamp + """ + if not time_str: + return None - @property - def filter_pattern(self): - return self._filter_pattern + parsed = parse_date(time_str) + if not parsed: + raise InvalidTimestampError("Unable to parse the time provided by '{}'".format(property_name)) - @property - def output_file_handle(self): - return self._output_file_handle + return to_utc(parsed) - @staticmethod - def _setup_output_file(output_file): - """ - Open a log file if necessary and return the file handle. This will create a file if it does not exist - Parameters - ---------- - output_file : str - Path to a file where the logs should be written to +class ResourcePhysicalIdResolver: + """ + Wrapper class that is used to extract information about stack resources whose logs can be tailed + """ - Returns - ------- - Handle to the opened log file, if necessary. None otherwise - """ - if not output_file: - return None + # list of resource types that are currently supported for pulling their logs + DEFAULT_SUPPORTED_RESOURCES: Set[str] = { + AWS_LAMBDA_FUNCTION, + AWS_APIGATEWAY_RESTAPI, + AWS_APIGATEWAY_V2_API, + AWS_STEPFUNCTIONS_STATEMACHINE, + } - return open(output_file, "wb") + def __init__( + self, + boto_resource_provider: BotoProviderType, + stack_name: str, + resource_names: Optional[List[str]] = None, + supported_resource_types: Optional[Set[str]] = None, + ): + self._boto_resource_provider = boto_resource_provider + self._stack_name = stack_name + if resource_names is None: + resource_names = [] + if supported_resource_types is None: + supported_resource_types = ResourcePhysicalIdResolver.DEFAULT_SUPPORTED_RESOURCES + self._supported_resource_types: Set[str] = supported_resource_types + self._resource_names = set(resource_names) - @staticmethod - def _parse_time(time_str, property_name): + def get_resource_information(self, fetch_all_when_no_resource_name_given: bool = True) -> List[Any]: """ - Parse the time from the given string, convert to UTC, and return the datetime object + Returns the list of resource information for the given stack. Parameters ---------- - time_str : str - The time to parse - - property_name : str - Name of the property where this time came from. 
Used in the exception raised if time is not parseable + fetch_all_when_no_resource_name_given : bool + When True, it will fetch all stack resources if no specific resource names are provided; default value is True Returns ------- - datetime.datetime - Parsed datetime object - - Raises - ------ - samcli.commands.exceptions.UserException - If the string cannot be parsed as a timestamp + List[StackResourceSummary] + List of resource information, which will be used to fetch the logs """ - if not time_str: - return None + if self._resource_names: + return self._fetch_resources_from_stack(self._resource_names) + if fetch_all_when_no_resource_name_given: + return self._fetch_resources_from_stack() + return [] - parsed = parse_date(time_str) - if not parsed: - raise InvalidTimestampError("Unable to parse the time provided by '{}'".format(property_name)) - - return to_utc(parsed) - - @staticmethod - def _get_resource_id_from_stack(cfn_client, stack_name, logical_id): + def _fetch_resources_from_stack(self, selected_resource_names: Optional[Set[str]] = None) -> List[Any]: """ - Given the LogicalID of a resource, call AWS CloudFormation to get physical ID of the resource within - the specified stack. + Returns the list of all supported resources from the given stack; + any resource that is not supported is discarded. Parameters ---------- - cfn_client : boto3.session.Session.client - CloudFormation client provided by AWS SDK - - stack_name : str - Name of the stack to query - - logical_id : str - LogicalId of the resource + selected_resource_names : Optional[Set[str]] + An optional set of strings used to filter resource names. If none is given, it defaults to + all resource names in the stack, which means there won't be any filtering by resource name. Returns ------- - str - Physical ID of the resource - - Raises - ------ - samcli.commands.exceptions.UserException - If the stack or resource does not exist + List[StackResourceSummary] + List of resource information, which will be used to fetch the logs """ - - LOG.debug( - "Getting resource's PhysicalId from AWS CloudFormation stack. 
StackName=%s, LogicalId=%s", - stack_name, - logical_id, + results = [] + LOG.debug("Getting logical ids of all resources for stack '%s'", self._stack_name) + stack_resources = get_resource_summaries( + self._boto_resource_provider, self._stack_name, ResourcePhysicalIdResolver.DEFAULT_SUPPORTED_RESOURCES ) - try: - response = cfn_client.describe_stack_resource(StackName=stack_name, LogicalResourceId=logical_id) - - LOG.debug("Response from AWS CloudFormation %s", response) - return response["StackResourceDetail"]["PhysicalResourceId"] - - except botocore.exceptions.ClientError as ex: - LOG.debug( - "Unable to fetch resource name from CloudFormation Stack: " - "StackName=%s, ResourceLogicalId=%s, Response=%s", - stack_name, - logical_id, - ex.response, - ) + if selected_resource_names is None: + selected_resource_names = {stack_resource.logical_resource_id for stack_resource in stack_resources} - # The exception message already has a well formatted error message that we can surface to user - raise UserException(str(ex), wrapped_from=ex.response["Error"]["Code"]) from ex + for resource in stack_resources: + # if resource name is not selected, continue + if resource.logical_resource_id not in selected_resource_names: + LOG.debug("Resource (%s) is not selected with given input", resource.logical_resource_id) + continue + results.append(resource) + return results diff --git a/samcli/commands/logs/puller_factory.py b/samcli/commands/logs/puller_factory.py new file mode 100644 index 0000000000..1926ce5562 --- /dev/null +++ b/samcli/commands/logs/puller_factory.py @@ -0,0 +1,178 @@ +""" +Keeps the factory methods that prepare the required puller instances +with their producers and consumers +""" +import logging +from typing import List, Optional + +from samcli.commands.exceptions import UserException +from samcli.commands.logs.console_consumers import CWConsoleEventConsumer +from samcli.commands.traces.traces_puller_factory import generate_trace_puller +from samcli.lib.observability.cw_logs.cw_log_formatters import ( + CWColorizeErrorsFormatter, + CWJsonFormatter, + CWKeywordHighlighterFormatter, + CWPrettyPrintFormatter, + CWAddNewLineIfItDoesntExist, + CWLogEventJSONMapper, +) +from samcli.lib.observability.cw_logs.cw_log_group_provider import LogGroupProvider +from samcli.lib.observability.cw_logs.cw_log_puller import CWLogPuller +from samcli.lib.observability.observability_info_puller import ( + ObservabilityPuller, + ObservabilityEventConsumerDecorator, + ObservabilityEventConsumer, + ObservabilityCombinedPuller, +) +from samcli.lib.utils.boto_utils import BotoProviderType +from samcli.lib.utils.cloudformation import CloudFormationResourceSummary +from samcli.lib.utils.colors import Colored + +LOG = logging.getLogger(__name__) + + +class NoPullerGeneratedException(UserException): + """ + Used to indicate that no puller has been generated and + therefore there is no observability information (logs, xray) to pull + """ + + +def generate_puller( + boto_client_provider: BotoProviderType, + resource_information_list: List[CloudFormationResourceSummary], + filter_pattern: Optional[str] = None, + additional_cw_log_groups: Optional[List[str]] = None, + unformatted: bool = False, + include_tracing: bool = False, +) -> ObservabilityPuller: + """ + This function generates a generic puller which can be used to + pull information from various observability resources. 
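+    For example, a stack with one Lambda function and one REST API yields two CWLogPuller instances + (plus an X-Ray trace puller when include_tracing is True), combined into a single ObservabilityCombinedPuller. 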
+ + Parameters + ---------- + boto_client_provider: BotoProviderType + Boto3 client generator, which will create a new instance of the client with a new session that could be + used within different threads/coroutines + resource_information_list : List[CloudFormationResourceSummary] + List of resource information, which keeps logical id, physical id and type of the resources + filter_pattern : Optional[str] + Optional filter pattern which will be used to filter incoming events + additional_cw_log_groups : Optional[List[str]] + Optional list of additional CloudWatch log groups which will be used to fetch + log events from. + unformatted : bool + By default, logs and traces are printed formatted for the terminal. If this option is provided, the events + will be printed unformatted in JSON. + include_tracing: bool + A flag indicating whether the XRay traces should also be pulled + + Returns + ------- + Puller instance that can be used to pull information. + + Parameters + ---------- + boto_client_provider: BotoProviderType + Boto3 client generator, which will create a new instance of the client with a new session that could be + used within different threads/coroutines + resource_information_list : List[CloudFormationResourceSummary] + List of resource information, which keeps logical id, physical id and type of the resources + filter_pattern : Optional[str] + Optional filter pattern which will be used to filter incoming events + additional_cw_log_groups : Optional[List[str]] + Optional list of additional CloudWatch log groups which will be used to fetch + log events from. + unformatted : bool + By default, logs and traces are printed formatted for the terminal. If this option is provided, the events + will be printed unformatted in JSON. + include_tracing: bool + A flag indicating whether the XRay traces should also be pulled + + Returns + ------- + Puller instance that can be used to pull information. + """ + if additional_cw_log_groups is None: + additional_cw_log_groups = [] + pullers: List[ObservabilityPuller] = [] + + # populate all puller instances for given resources + for resource_information in resource_information_list: + cw_log_group_name = LogGroupProvider.for_resource( + boto_client_provider, + resource_information.resource_type, + resource_information.physical_resource_id, + ) + if not cw_log_group_name: + LOG.debug("Can't find CloudWatch LogGroup name for resource (%s)", resource_information.logical_resource_id) + continue + + consumer = generate_consumer(filter_pattern, unformatted, resource_information.logical_resource_id) + pullers.append( + CWLogPuller( + boto_client_provider("logs"), + consumer, + cw_log_group_name, + resource_information.logical_resource_id, + ) + ) + + # populate puller instances for the additional CloudWatch log groups + for cw_log_group in additional_cw_log_groups: + consumer = generate_consumer(filter_pattern, unformatted) + pullers.append( + CWLogPuller( + boto_client_provider("logs"), + consumer, + cw_log_group, + ) + ) + + # if tracing flag is set, add the xray traces puller to fetch debug traces + if include_tracing: + trace_puller = generate_trace_puller(boto_client_provider("xray"), unformatted) + pullers.append(trace_puller) + + # if no pullers have been collected, raise an exception since there is nothing to pull + if not pullers: + raise NoPullerGeneratedException("No valid resources found to pull information from") + + # return the combined puller instance, which will pull from all pullers collected + return ObservabilityCombinedPuller(pullers) + + +def generate_consumer( + filter_pattern: Optional[str] = None, unformatted: bool = False, resource_name: Optional[str] = None +): + """ + Generates a consumer instance with the given variables. + If unformatted is True, it will return a consumer that prints events as unformatted JSON. 
Otherwise, it will return the console consumer. + """ + if unformatted: + return generate_unformatted_consumer() + + return generate_console_consumer(filter_pattern) + + +def generate_unformatted_consumer() -> ObservabilityEventConsumer: + """ + Creates an event consumer, which prints CW Log Events unformatted as JSON into the terminal + + Returns + ------- + ObservabilityEventConsumer which will print events into the terminal as unformatted JSON + """ + return ObservabilityEventConsumerDecorator( + [ + CWLogEventJSONMapper(), + ], + CWConsoleEventConsumer(True), + ) + + +def generate_console_consumer(filter_pattern: Optional[str]) -> ObservabilityEventConsumer: + """ + Creates a console event consumer, which is used to display events in the user's console + + Parameters + ---------- + filter_pattern : Optional[str] + Filter pattern is used to display certain words in a different style than + the rest of the messages. + + Returns + ------- + A consumer which will display events in the console + """ + colored = Colored() + return ObservabilityEventConsumerDecorator( + [ + CWColorizeErrorsFormatter(colored), + CWJsonFormatter(), + CWKeywordHighlighterFormatter(colored, filter_pattern), + CWPrettyPrintFormatter(colored), + CWAddNewLineIfItDoesntExist(), + ], + CWConsoleEventConsumer(), + ) diff --git a/samcli/commands/package/command.py b/samcli/commands/package/command.py index cc0dc35c5d..ab5a00c063 100644 --- a/samcli/commands/package/command.py +++ b/samcli/commands/package/command.py @@ -1,24 +1,24 @@ """ CLI command for "package" command """ -from functools import partial - import click from samcli.cli.cli_config_file import configuration_option, TomlProvider from samcli.cli.main import pass_context, common_options, aws_creds_options, print_cmdline_args -from samcli.cli.types import ImageRepositoryType, ImageRepositoriesType -from samcli.commands.package.exceptions import PackageResolveS3AndS3SetError, PackageResolveS3AndS3NotSetError from samcli.lib.cli_validation.image_repository_validation import image_repository_validation -from samcli.lib.utils.packagetype import ZIP, IMAGE from samcli.commands._utils.options import ( - artifact_callback, - resolve_s3_callback, signing_profiles_option, - image_repositories_callback, + s3_bucket_option, + image_repository_option, + image_repositories_option, + s3_prefix_option, + kms_key_id_option, + use_json_option, + force_upload_option, + resolve_s3_option, ) -from samcli.commands._utils.options import metadata_override_option, template_click_option, no_progressbar_option -from samcli.commands._utils.resources import resources_generator +from samcli.commands._utils.options import metadata_option, template_click_option, no_progressbar_option +from samcli.lib.utils.resources import resources_generator from samcli.lib.bootstrap.bootstrap import manage_stack from samcli.lib.telemetry.metric import track_command, track_template_warnings from samcli.lib.utils.version_checker import check_newer_version @@ -54,40 +54,6 @@ def resources_and_properties_help_string(): @click.command("package", short_help=SHORT_HELP, help=HELP_TEXT, context_settings=dict(max_content_width=120)) @configuration_option(provider=TomlProvider(section="parameters")) @template_click_option(include_build=True) -@click.option( - "--s3-bucket", - required=False, - callback=partial(artifact_callback, artifact=ZIP), - help="The name of the S3 bucket where this command uploads the artifacts that are referenced in your template.", -) -@click.option( - "--image-repository", - callback=partial(artifact_callback, artifact=IMAGE), - 
type=ImageRepositoryType(), - required=False, - help="ECR repo uri where this command uploads the image artifacts that are referenced in your template.", -) -@click.option( - "--image-repositories", - multiple=True, - callback=image_repositories_callback, - type=ImageRepositoriesType(), - required=False, - help="Specify mapping of Function Logical ID to ECR Repo uri, of the form Function_Logical_ID=ECR_Repo_Uri." - "This option can be specified multiple times.", -) -@click.option( - "--s3-prefix", - required=False, - help="A prefix name that the command adds to the artifacts " - "name when it uploads them to the S3 bucket. The prefix name is a " - "path name (folder name) for the S3 bucket.", -) -@click.option( - "--kms-key-id", - required=False, - help="The ID of an AWS KMS key that the command uses to encrypt artifacts that are at rest in the S3 bucket.", -) @click.option( "--output-template-file", required=False, @@ -96,37 +62,15 @@ def resources_and_properties_help_string(): "writes the output AWS CloudFormation template. If you don't specify a " "path, the command writes the template to the standard output.", ) -@click.option( - "--use-json", - required=False, - is_flag=True, - help="Indicates whether to use JSON as the format for " - "the output AWS CloudFormation template. YAML is used by default.", -) -@click.option( - "--force-upload", - required=False, - is_flag=True, - help="Indicates whether to override existing files " - "in the S3 bucket. Specify this flag to upload artifacts even if they " - "match existing artifacts in the S3 bucket.", -) -@click.option( - "--resolve-s3", - required=False, - is_flag=True, - callback=partial( - resolve_s3_callback, - artifact=ZIP, - exc_set=PackageResolveS3AndS3SetError, - exc_not_set=PackageResolveS3AndS3NotSetError, - ), - help="Automatically resolve s3 bucket for non-guided deployments. " - "Enabling this option will also create a managed default s3 bucket for you. " - "If you do not provide a --s3-bucket value, the managed bucket will be used. 
" - "Do not use --s3-guided parameter with this option.", -) -@metadata_override_option +@s3_bucket_option +@image_repository_option +@image_repositories_option +@s3_prefix_option +@kms_key_id_option +@use_json_option +@force_upload_option +@resolve_s3_option +@metadata_option @signing_profiles_option @no_progressbar_option @common_options diff --git a/samcli/commands/package/package_context.py b/samcli/commands/package/package_context.py index 0a26577333..8aaca8241c 100644 --- a/samcli/commands/package/package_context.py +++ b/samcli/commands/package/package_context.py @@ -30,7 +30,7 @@ from samcli.lib.package.code_signer import CodeSigner from samcli.lib.package.s3_uploader import S3Uploader from samcli.lib.package.uploaders import Uploaders -from samcli.lib.utils.botoconfig import get_boto_config_with_user_agent +from samcli.lib.utils.boto_utils import get_boto_config_with_user_agent from samcli.yamlhelper import yaml_dump LOG = logging.getLogger(__name__) diff --git a/samcli/commands/pipeline/init/interactive_init_flow.py b/samcli/commands/pipeline/init/interactive_init_flow.py index d4e989ebfa..31920b474e 100644 --- a/samcli/commands/pipeline/init/interactive_init_flow.py +++ b/samcli/commands/pipeline/init/interactive_init_flow.py @@ -11,8 +11,8 @@ from typing import Dict, List, Tuple import click +from samcli.cli.global_config import GlobalConfig -from samcli.cli.main import global_cfg from samcli.commands.exceptions import ( AppPipelineTemplateMetadataException, PipelineTemplateCloneException, @@ -34,7 +34,7 @@ ) LOG = logging.getLogger(__name__) -shared_path: Path = global_cfg.config_dir +shared_path: Path = GlobalConfig().config_dir APP_PIPELINE_TEMPLATES_REPO_URL = "https://github.com/aws/aws-sam-cli-pipeline-init-templates.git" APP_PIPELINE_TEMPLATES_REPO_LOCAL_NAME = "aws-sam-cli-app-pipeline-templates" CUSTOM_PIPELINE_TEMPLATE_REPO_LOCAL_NAME = "custom-pipeline-template" diff --git a/samcli/commands/sync/__init__.py b/samcli/commands/sync/__init__.py new file mode 100644 index 0000000000..e849905d58 --- /dev/null +++ b/samcli/commands/sync/__init__.py @@ -0,0 +1,4 @@ +"""`sam sync` command.""" + +# Expose the cli object here +from .command import cli # noqa diff --git a/samcli/commands/sync/command.py b/samcli/commands/sync/command.py new file mode 100644 index 0000000000..4f6add0f35 --- /dev/null +++ b/samcli/commands/sync/command.py @@ -0,0 +1,415 @@ +"""CLI command for "sync" command.""" +import os +import logging +from typing import List, Set, TYPE_CHECKING, Optional, Tuple + +import click + +from samcli.cli.main import pass_context, common_options as cli_framework_options, aws_creds_options, print_cmdline_args +from samcli.commands._utils.options import ( + template_option_without_build, + parameter_override_option, + capabilities_option, + metadata_option, + notification_arns_option, + tags_option, + stack_name_option, + base_dir_option, + image_repository_option, + image_repositories_option, + s3_prefix_option, + kms_key_id_option, + role_arn_option, + DEFAULT_BUILD_DIR, + DEFAULT_CACHE_DIR, + DEFAULT_BUILD_DIR_WITH_AUTO_DEPENDENCY_LAYER, +) +from samcli.cli.cli_config_file import configuration_option, TomlProvider +from samcli.lib.utils.colors import Colored +from samcli.lib.utils.version_checker import check_newer_version +from samcli.lib.bootstrap.bootstrap import manage_stack +from samcli.lib.cli_validation.image_repository_validation import image_repository_validation +from samcli.lib.telemetry.metric import track_command, track_template_warnings +from 
samcli.lib.warnings.sam_cli_warning import CodeDeployWarning, CodeDeployConditionWarning
+from samcli.commands.build.command import _get_mode_value_from_envvar
+from samcli.lib.sync.sync_flow_factory import SyncFlowFactory
+from samcli.lib.sync.sync_flow_executor import SyncFlowExecutor
+from samcli.lib.providers.sam_stack_provider import SamLocalStackProvider
+from samcli.lib.providers.provider import (
+    ResourceIdentifier,
+    get_all_resource_ids,
+    get_unique_resource_ids,
+)
+from samcli.cli.context import Context
+from samcli.lib.sync.watch_manager import WatchManager
+from samcli.commands._utils.experimental import (
+    ExperimentalFlag,
+    experimental,
+    is_experimental_enabled,
+    set_experimental,
+)
+
+if TYPE_CHECKING:  # pragma: no cover
+    from samcli.commands.deploy.deploy_context import DeployContext
+    from samcli.commands.package.package_context import PackageContext
+    from samcli.commands.build.build_context import BuildContext
+
+LOG = logging.getLogger(__name__)
+
+HELP_TEXT = """
+[Beta Feature] Update/sync local artifacts to AWS
+
+By default, the sync command runs a full stack update. You can specify --code or --watch to switch modes.
+"""
+
+SYNC_CONFIRMATION_TEXT = """
+The SAM CLI will use the AWS Lambda, Amazon API Gateway, and AWS Step Functions APIs to upload your code without
+performing a CloudFormation deployment. This will cause drift in your CloudFormation stack.
+**The sync command should only be used against a development stack**.
+Confirm that you are synchronizing a development stack.
+
+Enter Y to proceed with the command, or enter N to cancel:
+"""
+
+SYNC_CONFIRMATION_TEXT_WITH_BETA = """
+This feature is currently in beta. Visit the docs page to learn more about the AWS Beta terms https://aws.amazon.com/service-terms/.
+
+The SAM CLI will use the AWS Lambda, Amazon API Gateway, and AWS Step Functions APIs to upload your code without
+performing a CloudFormation deployment. This will cause drift in your CloudFormation stack.
+**The sync command should only be used against a development stack**.
+
+Confirm that you are synchronizing a development stack and want to turn on beta features.
+
+Enter Y to proceed with the command, or enter N to cancel:
+"""
+
+
+SHORT_HELP = "[Beta Feature] Sync a project to AWS"
+
+DEFAULT_TEMPLATE_NAME = "template.yaml"
+DEFAULT_CAPABILITIES = ("CAPABILITY_NAMED_IAM", "CAPABILITY_AUTO_EXPAND")
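do_cli below shows one of the two confirmation texts above and flips the Accelerate experimental flag only after the user agrees. A minimal sketch of that gate, reusing names from this patch (sketch only, assuming the flag helpers behave as imported above):

import click

from samcli.lib.utils.colors import Colored
from samcli.commands._utils.experimental import ExperimentalFlag, is_experimental_enabled, set_experimental

def confirm_and_enable_accelerate() -> bool:
    # Show the longer beta notice when the flag has not been enabled yet.
    if is_experimental_enabled(ExperimentalFlag.Accelerate):
        text = SYNC_CONFIRMATION_TEXT
    else:
        text = SYNC_CONFIRMATION_TEXT_WITH_BETA
    if not click.confirm(Colored().yellow(text), default=False):
        return False
    set_experimental(ExperimentalFlag.Accelerate)
    return True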
+@click.command("sync", help=HELP_TEXT, short_help=SHORT_HELP)
+@configuration_option(provider=TomlProvider(section="parameters"))
+@template_option_without_build
+@click.option(
+    "--code",
+    is_flag=True,
+    help="Sync code resources. This includes Lambda Functions, API Gateway, and Step Functions.",
+)
+@click.option(
+    "--watch",
+    is_flag=True,
+    help="Watch local files and automatically sync with remote.",
+)
+@click.option(
+    "--resource-id",
+    multiple=True,
+    help="Sync code for all the resources with the given resource ID.",
+)
+@click.option(
+    "--resource",
+    multiple=True,
+    help="Sync code for all resources of the given resource type.",
+)
+@click.option(
+    "--dependency-layer/--no-dependency-layer",
+    default=True,
+    is_flag=True,
+    help="This option separates the dependencies of individual functions into another layer, to speed up the sync "
+    "process.",
+)
+@stack_name_option(required=True)  # pylint: disable=E1120
+@base_dir_option
+@image_repository_option
+@image_repositories_option
+@s3_prefix_option
+@kms_key_id_option
+@role_arn_option
+@parameter_override_option
+@cli_framework_options
+@aws_creds_options
+@metadata_option
+@notification_arns_option
+@tags_option
+@capabilities_option(default=DEFAULT_CAPABILITIES)  # pylint: disable=E1120
+@experimental
+@pass_context
+@track_command
+@image_repository_validation
+@track_template_warnings([CodeDeployWarning.__name__, CodeDeployConditionWarning.__name__])
+@check_newer_version
+@print_cmdline_args
+def cli(
+    ctx: Context,
+    template_file: str,
+    code: bool,
+    watch: bool,
+    resource_id: Optional[Tuple[str]],
+    resource: Optional[Tuple[str]],
+    dependency_layer: bool,
+    stack_name: str,
+    base_dir: Optional[str],
+    parameter_overrides: dict,
+    image_repository: str,
+    image_repositories: Optional[Tuple[str]],
+    s3_prefix: str,
+    kms_key_id: str,
+    capabilities: Optional[List[str]],
+    role_arn: Optional[str],
+    notification_arns: Optional[List[str]],
+    tags: dict,
+    metadata: dict,
+    config_file: str,
+    config_env: str,
+) -> None:
+    """
+    `sam sync` command entry point
+    """
+    mode = _get_mode_value_from_envvar("SAM_BUILD_MODE", choices=["debug"])
+    # All logic must be implemented in the ``do_cli`` method. This helps with easy unit testing.
+
+    do_cli(
+        template_file,
+        code,
+        watch,
+        resource_id,
+        resource,
+        dependency_layer,
+        stack_name,
+        ctx.region,
+        ctx.profile,
+        base_dir,
+        parameter_overrides,
+        mode,
+        image_repository,
+        image_repositories,
+        s3_prefix,
+        kms_key_id,
+        capabilities,
+        role_arn,
+        notification_arns,
+        tags,
+        metadata,
+        config_file,
+        config_env,
+    )  # pragma: no cover
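Keeping the command body in do_cli (as the comment above notes) means tests can patch it and exercise only click's argument plumbing. A hedged sketch of such a test; the patch ships its real tests in tests/unit/commands/sync/test_command.py, and the arguments here are illustrative:

from unittest.mock import patch

from click.testing import CliRunner

from samcli.commands.sync.command import cli

with patch("samcli.commands.sync.command.do_cli") as do_cli_mock:
    # CliRunner parses the command line exactly as the decorators define it,
    # while the patched do_cli keeps the test away from any AWS calls.
    result = CliRunner().invoke(cli, ["--stack-name", "my-dev-stack", "--code"])
    # On a clean parse, do_cli receives the parsed values in positional order.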
+
+
+def do_cli(
+    template_file: str,
+    code: bool,
+    watch: bool,
+    resource_id: Optional[Tuple[str]],
+    resource: Optional[Tuple[str]],
+    dependency_layer: bool,
+    stack_name: str,
+    region: str,
+    profile: str,
+    base_dir: Optional[str],
+    parameter_overrides: dict,
+    mode: Optional[str],
+    image_repository: str,
+    image_repositories: Optional[Tuple[str]],
+    s3_prefix: str,
+    kms_key_id: str,
+    capabilities: Optional[List[str]],
+    role_arn: Optional[str],
+    notification_arns: Optional[List[str]],
+    tags: dict,
+    metadata: dict,
+    config_file: str,
+    config_env: str,
+) -> None:
+    """
+    Implementation of the ``cli`` method
+    """
+    from samcli.lib.utils import osutils
+    from samcli.commands.build.build_context import BuildContext
+    from samcli.commands.package.package_context import PackageContext
+    from samcli.commands.deploy.deploy_context import DeployContext
+
+    s3_bucket = manage_stack(profile=profile, region=region)
+    click.echo(f"\n\t\tManaged S3 bucket: {s3_bucket}")
+    click.echo("\t\tA different default S3 bucket can be set in samconfig.toml")
+    click.echo("\t\tOr by specifying --s3-bucket explicitly.")
+
+    click.echo(f"\n\t\tDefault capabilities applied: {DEFAULT_CAPABILITIES}")
+    click.echo("To override with customized capabilities, use the --capabilities flag or set it in samconfig.toml")
+
+    build_dir = DEFAULT_BUILD_DIR_WITH_AUTO_DEPENDENCY_LAYER if dependency_layer else DEFAULT_BUILD_DIR
+    LOG.debug("Using build directory as %s", build_dir)
+
+    confirmation_text = SYNC_CONFIRMATION_TEXT
+
+    if not is_experimental_enabled(ExperimentalFlag.Accelerate):
+        confirmation_text = SYNC_CONFIRMATION_TEXT_WITH_BETA
+
+    if not click.confirm(Colored().yellow(confirmation_text), default=False):
+        return
+
+    set_experimental(ExperimentalFlag.Accelerate)
+
+    with BuildContext(
+        resource_identifier=None,
+        template_file=template_file,
+        base_dir=base_dir,
+        build_dir=build_dir,
+        cache_dir=DEFAULT_CACHE_DIR,
+        clean=True,
+        use_container=False,
+        cached=True,
+        parallel=True,
+        parameter_overrides=parameter_overrides,
+        mode=mode,
+        create_auto_dependency_layer=dependency_layer,
+        stack_name=stack_name,
+    ) as build_context:
+        built_template = os.path.join(build_dir, DEFAULT_TEMPLATE_NAME)
+
+        with osutils.tempfile_platform_independent() as output_template_file:
+            with PackageContext(
+                template_file=built_template,
+                s3_bucket=s3_bucket,
+                image_repository=image_repository,
+                image_repositories=image_repositories,
+                s3_prefix=s3_prefix,
+                kms_key_id=kms_key_id,
+                output_template_file=output_template_file.name,
+                no_progressbar=True,
+                metadata=metadata,
+                region=region,
+                profile=profile,
+                use_json=False,
+                force_upload=True,
+            ) as package_context:
+
+                with DeployContext(
+                    template_file=output_template_file.name,
+                    stack_name=stack_name,
+                    s3_bucket=s3_bucket,
+                    image_repository=image_repository,
+                    image_repositories=image_repositories,
+                    no_progressbar=True,
+                    s3_prefix=s3_prefix,
+                    kms_key_id=kms_key_id,
+                    parameter_overrides=parameter_overrides,
+                    capabilities=capabilities,
+                    role_arn=role_arn,
+                    notification_arns=notification_arns,
+                    tags=tags,
+                    region=region,
+                    profile=profile,
+                    no_execute_changeset=True,
+                    fail_on_empty_changeset=True,
+                    confirm_changeset=False,
+                    use_changeset=False,
+                    force_upload=True,
+                    signing_profiles=None,
+                    disable_rollback=False,
+                ) as deploy_context:
+                    if watch:
+                        execute_watch(template_file, build_context, package_context, deploy_context, dependency_layer)
+                    elif code:
+                        execute_code_sync(
+                            template_file, build_context, deploy_context, resource_id, resource, dependency_layer
+                        )
+                    else:
+                        execute_infra_contexts(build_context, package_context, deploy_context)
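The three nested with blocks above hold the build, package, and deploy stages open for the whole sync session, so watch mode can re-run them without re-entering the contexts. As a design note, the same lifetime management could be written with contextlib.ExitStack; this is an illustration, not how the patch is implemented:

from contextlib import ExitStack

def run_infra_sync(build_context, package_context, deploy_context):
    # Enter all three contexts up front; ExitStack unwinds them in reverse
    # order when the block exits, mirroring the nested with statements.
    with ExitStack() as stack:
        for context in (build_context, package_context, deploy_context):
            stack.enter_context(context)
        build_context.run()
        package_context.run()
        deploy_context.run()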
+
+
+def execute_infra_contexts(
+    build_context: "BuildContext",
+    package_context: "PackageContext",
+    deploy_context: "DeployContext",
+) -> None:
+    """Executes the full infrastructure sync: build, package, and deploy.
+
+    Parameters
+    ----------
+    build_context : BuildContext
+        BuildContext
+    package_context : PackageContext
+        PackageContext
+    deploy_context : DeployContext
+        DeployContext
+    """
+    LOG.debug("Executing the build using build context.")
+    build_context.run()
+    LOG.debug("Executing the packaging using package context.")
+    package_context.run()
+    LOG.debug("Executing the deployment using deploy context.")
+    deploy_context.run()
+
+
+def execute_code_sync(
+    template: str,
+    build_context: "BuildContext",
+    deploy_context: "DeployContext",
+    resource_ids: Optional[Tuple[str]],
+    resource_types: Optional[Tuple[str]],
+    auto_dependency_layer: bool,
+) -> None:
+    """Executes the sync flow for code.
+
+    Parameters
+    ----------
+    template : str
+        Template file name
+    build_context : BuildContext
+        BuildContext
+    deploy_context : DeployContext
+        DeployContext
+    resource_ids : Optional[Tuple[str]]
+        IDs of the resources to be synced.
+    resource_types : Optional[Tuple[str]]
+        Types of the resources to be synced.
+    auto_dependency_layer : bool
+        Boolean flag for whether to enable certain sync flows for the auto dependency layer feature
+    """
+    stacks = SamLocalStackProvider.get_stacks(template)[0]
+    factory = SyncFlowFactory(build_context, deploy_context, stacks, auto_dependency_layer)
+    factory.load_physical_id_mapping()
+    executor = SyncFlowExecutor()
+
+    sync_flow_resource_ids: Set[ResourceIdentifier] = (
+        get_unique_resource_ids(stacks, resource_ids, resource_types)
+        if resource_ids or resource_types
+        else set(get_all_resource_ids(stacks))
+    )
+
+    for resource_id in sync_flow_resource_ids:
+        sync_flow = factory.create_sync_flow(resource_id)
+        if sync_flow:
+            executor.add_sync_flow(sync_flow)
+        else:
+            LOG.warning("Cannot create SyncFlow for %s. Skipping.", resource_id)
+    executor.execute()
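execute_code_sync above only narrows the SyncFlow set when --resource-id or --resource was passed; otherwise every resource in the loaded stacks is considered. A small illustration with hypothetical inputs (the helpers are the ones imported at the top of this file):

# Hypothetical CLI inputs: one explicit function plus every serverless function.
resource_ids = ("HelloWorldFunction",)
resource_types = ("AWS::Serverless::Function",)

stacks = SamLocalStackProvider.get_stacks("template.yaml")[0]
selected = (
    get_unique_resource_ids(stacks, resource_ids, resource_types)
    if resource_ids or resource_types
    else set(get_all_resource_ids(stacks))
)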
Skipping.", resource_id) + executor.execute() + + +def execute_watch( + template: str, + build_context: "BuildContext", + package_context: "PackageContext", + deploy_context: "DeployContext", + auto_dependency_layer: bool, +): + """Start sync watch execution + + Parameters + ---------- + template : str + Template file path + build_context : BuildContext + BuildContext + package_context : PackageContext + PackageContext + deploy_context : DeployContext + DeployContext + """ + watch_manager = WatchManager(template, build_context, package_context, deploy_context, auto_dependency_layer) + watch_manager.start() diff --git a/samcli/commands/traces/__init__.py b/samcli/commands/traces/__init__.py new file mode 100644 index 0000000000..596c05f79b --- /dev/null +++ b/samcli/commands/traces/__init__.py @@ -0,0 +1,6 @@ +""" +`sam traces` command +""" + +# Expose the cli object here +from samcli.commands.traces.command import cli diff --git a/samcli/commands/traces/command.py b/samcli/commands/traces/command.py new file mode 100644 index 0000000000..b2fab56f5b --- /dev/null +++ b/samcli/commands/traces/command.py @@ -0,0 +1,78 @@ +""" +CLI command for "traces" command +""" +import logging + +import click + +from samcli.cli.cli_config_file import configuration_option, TomlProvider +from samcli.cli.main import pass_context, common_options as cli_framework_options, aws_creds_options, print_cmdline_args +from samcli.commands._utils.options import common_observability_options +from samcli.lib.telemetry.metric import track_command +from samcli.lib.utils.version_checker import check_newer_version +from samcli.commands._utils.experimental import ExperimentalFlag, force_experimental + +LOG = logging.getLogger(__name__) + +HELP_TEXT = """ +[Beta Feature] Use this command to fetch AWS X-Ray traces generated by your stack.\n +""" + + +@click.command("traces", help=HELP_TEXT, short_help="[Beta Feature] Fetch AWS X-Ray traces") +@configuration_option(provider=TomlProvider(section="parameters")) +@click.option( + "--trace-id", + "-ti", + multiple=True, + help="Fetch specific trace by providing its id", +) +@common_observability_options +@cli_framework_options +@force_experimental(config_entry=ExperimentalFlag.Accelerate) # pylint: disable=E1120 +@aws_creds_options +@pass_context +@track_command +@check_newer_version +@print_cmdline_args +def cli( + ctx, + trace_id, + start_time, + end_time, + tail, + unformatted, + config_file, + config_env, +): + """ + `sam traces` command entry point + """ + do_cli(trace_id, start_time, end_time, tail, unformatted, ctx.region) + + +def do_cli(trace_ids, start_time, end_time, tailing, unformatted, region): + """ + Implementation of the ``cli`` method + """ + from datetime import datetime + import boto3 + from samcli.commands.logs.logs_context import parse_time + from samcli.commands.traces.traces_puller_factory import generate_trace_puller + from samcli.lib.utils.boto_utils import get_boto_config_with_user_agent + + sanitized_start_time = parse_time(start_time, "start-time") + sanitized_end_time = parse_time(end_time, "end-time") or datetime.utcnow() + + boto_config = get_boto_config_with_user_agent(region_name=region) + xray_client = boto3.client("xray", config=boto_config) + + # generate puller depending on the parameters + puller = generate_trace_puller(xray_client, unformatted) + + if trace_ids: + puller.load_events(trace_ids) + elif tailing: + puller.tail(sanitized_start_time) + else: + puller.load_time_period(sanitized_start_time, sanitized_end_time) diff --git 
diff --git a/samcli/commands/traces/trace_console_consumers.py b/samcli/commands/traces/trace_console_consumers.py
new file mode 100644
index 0000000000..9d84383a35
--- /dev/null
+++ b/samcli/commands/traces/trace_console_consumers.py
@@ -0,0 +1,18 @@
+"""
+Contains console consumers for outputting XRay information back to console/terminal
+"""
+
+import click
+
+from samcli.lib.observability.observability_info_puller import ObservabilityEventConsumer
+from samcli.lib.observability.xray_traces.xray_events import XRayTraceEvent
+
+
+class XRayTraceConsoleConsumer(ObservabilityEventConsumer[XRayTraceEvent]):
+    """
+    An XRayTraceEvent consumer which prints incoming XRayTraceEvents back to the console
+    """
+
+    # pylint: disable=R0201
+    def consume(self, event: XRayTraceEvent):
+        click.echo(event.message)
diff --git a/samcli/commands/traces/traces_puller_factory.py b/samcli/commands/traces/traces_puller_factory.py
new file mode 100644
index 0000000000..7c3f5a4860
--- /dev/null
+++ b/samcli/commands/traces/traces_puller_factory.py
@@ -0,0 +1,112 @@
+"""
+Factory methods which generate puller and consumer instances for XRay events
+"""
+from typing import Any, List
+
+from samcli.commands.traces.trace_console_consumers import XRayTraceConsoleConsumer
+from samcli.lib.observability.observability_info_puller import (
+    ObservabilityPuller,
+    ObservabilityEventConsumer,
+    ObservabilityEventConsumerDecorator,
+    ObservabilityCombinedPuller,
+)
+from samcli.lib.observability.xray_traces.xray_event_mappers import (
+    XRayTraceConsoleMapper,
+    XRayServiceGraphConsoleMapper,
+    XRayServiceGraphJSONMapper,
+    XRayTraceJSONMapper,
+)
+from samcli.lib.observability.xray_traces.xray_event_puller import XRayTracePuller
+from samcli.lib.observability.xray_traces.xray_service_graph_event_puller import XRayServiceGraphPuller
+
+
+def generate_trace_puller(
+    xray_client: Any,
+    unformatted: bool = False,
+) -> ObservabilityPuller:
+    """
+    Generates a puller instance with the correct consumer and/or mapper configuration
+
+    Parameters
+    ----------
+    xray_client : Any
+        boto3 xray client to be used in XRayTracePuller instance
+    unformatted : bool
+        By default, logs and traces are printed with a format for terminal. If this option is provided, the events
+        will be printed unformatted in JSON.
+
+    Returns
+    -------
+        Puller instance with desired configuration
+    """
+    pullers: List[ObservabilityPuller] = []
+    pullers.append(XRayTracePuller(xray_client, generate_xray_event_consumer(unformatted)))
+    pullers.append(XRayServiceGraphPuller(xray_client, generate_xray_service_graph_consumer(unformatted)))
+
+    return ObservabilityCombinedPuller(pullers)
+
+
+def generate_unformatted_xray_event_consumer() -> ObservabilityEventConsumer:
+    """
+    Generates an unformatted consumer, which will print X-Ray events to the terminal as unformatted JSON
+
+    Returns
+    -------
+        Console consumer instance with desired mapper configuration
+    """
+    return ObservabilityEventConsumerDecorator([XRayTraceJSONMapper()], XRayTraceConsoleConsumer())
+
+
+def generate_xray_event_console_consumer() -> ObservabilityEventConsumer:
+    """
+    Generates an instance of event consumer which will print events into the console
+
+    Returns
+    -------
+        Console consumer instance with desired mapper configuration
+    """
+    return ObservabilityEventConsumerDecorator([XRayTraceConsoleMapper()], XRayTraceConsoleConsumer())
+
+
+def generate_xray_event_consumer(unformatted: bool = False) -> ObservabilityEventConsumer:
+    """
+    Generates a consumer instance with the given variables.
+    If unformatted is True, it returns a consumer that prints events as unformatted JSON.
+    Otherwise, it returns the console consumer.
+    """
+    if unformatted:
+        return generate_unformatted_xray_event_consumer()
+    return generate_xray_event_console_consumer()
+
+
+def generate_unformatted_xray_service_graph_consumer() -> ObservabilityEventConsumer:
+    """
+    Generates an unformatted consumer, which will print X-Ray service graph events to the terminal as
+    unformatted JSON
+
+    Returns
+    -------
+        Console consumer instance with desired mapper configuration
+    """
+    return ObservabilityEventConsumerDecorator([XRayServiceGraphJSONMapper()], XRayTraceConsoleConsumer())
+
+
+def generate_xray_service_graph_console_consumer() -> ObservabilityEventConsumer:
+    """
+    Generates an instance of event consumer which will print events into the console
+
+    Returns
+    -------
+        Console consumer instance with desired mapper configuration
+    """
+    return ObservabilityEventConsumerDecorator([XRayServiceGraphConsoleMapper()], XRayTraceConsoleConsumer())
+
+
+def generate_xray_service_graph_consumer(unformatted: bool = False) -> ObservabilityEventConsumer:
+    """
+    Generates a consumer instance with the given variables.
+    If unformatted is True, it returns a consumer that prints events as unformatted JSON.
+    Otherwise, it returns the console consumer.
+    """
+    if unformatted:
+        return generate_unformatted_xray_service_graph_consumer()
+    return generate_xray_service_graph_console_consumer()
diff --git a/samcli/commands/validate/lib/sam_template_validator.py b/samcli/commands/validate/lib/sam_template_validator.py
index d9de756674..ca27ac8c56 100644
--- a/samcli/commands/validate/lib/sam_template_validator.py
+++ b/samcli/commands/validate/lib/sam_template_validator.py
@@ -10,7 +10,7 @@
 from boto3.session import Session
 from samcli.lib.utils.packagetype import ZIP, IMAGE
-from samcli.commands._utils.resources import AWS_SERVERLESS_FUNCTION
+from samcli.lib.utils.resources import AWS_SERVERLESS_FUNCTION
 from samcli.yamlhelper import yaml_dump
 from .exceptions import InvalidSamDocumentException
diff --git a/samcli/lib/bootstrap/companion_stack/companion_stack_builder.py b/samcli/lib/bootstrap/companion_stack/companion_stack_builder.py
index 85280c2513..9d4f6fc705 100644
--- a/samcli/lib/bootstrap/companion_stack/companion_stack_builder.py
+++ b/samcli/lib/bootstrap/companion_stack/companion_stack_builder.py
@@ -1,15 +1,13 @@
 """
 Companion stack template builder
 """
-import json
-
-from typing import Dict
+from typing import Dict, cast
 
 from samcli.lib.bootstrap.companion_stack.data_types import CompanionStack, ECRRepo
-from samcli import __version__ as VERSION
+from samcli.lib.bootstrap.stack_builder import AbstractStackBuilder
 
 
-class CompanionStackBuilder:
+class CompanionStackBuilder(AbstractStackBuilder):
     """
     CFN template builder for the companion stack
     """
@@ -19,8 +17,10 @@ class CompanionStackBuilder:
     _repo_mapping: Dict[str, ECRRepo]
 
     def __init__(self, companion_stack: CompanionStack) -> None:
+        super().__init__("AWS SAM CLI Managed ECR Repo Stack")
         self._companion_stack = companion_stack
         self._repo_mapping: Dict[str, ECRRepo] = dict()
+        self.add_metadata("CompanionStackname", self._companion_stack.stack_name)
 
     def add_function(self, function_logical_id: str) -> None:
         """
@@ -42,30 +42,11 @@ def build(self) -> str:
         str
             CFN template for companions stack
         """
-        template_dict = self._build_template_dict()
         for _, ecr_repo in self._repo_mapping.items():
-            template_dict["Resources"][ecr_repo.logical_id] = self._build_repo_dict(ecr_repo)
-
template_dict["Outputs"][ecr_repo.output_logical_id] = CompanionStackBuilder._build_output_dict(ecr_repo) - - return json.dumps(template_dict) + self.add_resource(cast(str, ecr_repo.logical_id), self._build_repo_dict(ecr_repo)) + self.add_output(cast(str, ecr_repo.output_logical_id), CompanionStackBuilder._build_output_dict(ecr_repo)) - def _build_template_dict(self) -> Dict: - """ - Build Companion stack template dictionary with Resources and Outputs not filled - Returns - ------- - dict - Companion stack template dictionary - """ - template = { - "AWSTemplateFormatVersion": "2010-09-09", - "Transform": "AWS::Serverless-2016-10-31", - "Description": "AWS SAM CLI Managed ECR Repo Stack", - "Metadata": {"SamCliInfo": VERSION, "CompanionStackname": self._companion_stack.stack_name}, - "Resources": {}, - "Outputs": {}, - } - return template + return super().build() def _build_repo_dict(self, repo: ECRRepo) -> Dict: """ @@ -104,7 +85,7 @@ def _build_repo_dict(self, repo: ECRRepo) -> Dict: } @staticmethod - def _build_output_dict(repo: ECRRepo) -> Dict: + def _build_output_dict(repo: ECRRepo) -> str: """ Build a single ECR repo output resource dictionary @@ -118,9 +99,7 @@ def _build_output_dict(repo: ECRRepo) -> Dict: dict ECR repo output resource dictionary """ - return { - "Value": f"!Sub ${{AWS::AccountId}}.dkr.ecr.${{AWS::Region}}.${{AWS::URLSuffix}}/${{{repo.logical_id}}}" - } + return f"!Sub ${{AWS::AccountId}}.dkr.ecr.${{AWS::Region}}.${{AWS::URLSuffix}}/${{{repo.logical_id}}}" @property def repo_mapping(self) -> Dict[str, ECRRepo]: diff --git a/samcli/lib/bootstrap/nested_stack/__init__.py b/samcli/lib/bootstrap/nested_stack/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/samcli/lib/bootstrap/nested_stack/nested_stack_builder.py b/samcli/lib/bootstrap/nested_stack/nested_stack_builder.py new file mode 100644 index 0000000000..a668226bd1 --- /dev/null +++ b/samcli/lib/bootstrap/nested_stack/nested_stack_builder.py @@ -0,0 +1,78 @@ +""" +StackBuilder implementation for nested stack +""" +from typing import cast + +from samcli.lib.bootstrap.stack_builder import AbstractStackBuilder +from samcli.lib.providers.provider import Function +from samcli.lib.utils.hash import str_checksum +from samcli.lib.utils.resources import AWS_SERVERLESS_LAYERVERSION, AWS_CLOUDFORMATION_STACK + +CREATED_BY_METADATA_KEY = "CreatedBy" +CREATED_BY_METADATA_VALUE = "AWS SAM CLI sync command" + + +class NestedStackBuilder(AbstractStackBuilder): + """ + CFN/SAM Template creator for nested stack + """ + + def __init__(self): + super().__init__("AWS SAM CLI Nested Stack for Auto Dependency Layer Creation") + self.add_metadata(CREATED_BY_METADATA_KEY, CREATED_BY_METADATA_VALUE) + + def is_any_function_added(self) -> bool: + return bool(self._template_dict.get("Resources", {})) + + def add_function( + self, + stack_name: str, + layer_contents_folder: str, + function: Function, + ) -> str: + layer_logical_id = self.get_layer_logical_id(function.name) + layer_name = self.get_layer_name(stack_name, function.name) + + self.add_resource( + layer_logical_id, + self._get_layer_dict(function.name, layer_name, layer_contents_folder, cast(str, function.runtime)), + ) + self.add_output(layer_logical_id, {"Ref": layer_logical_id}) + return layer_logical_id + + @staticmethod + def get_layer_logical_id(function_logical_id: str) -> str: + function_logical_id_hash = str_checksum(function_logical_id) + return f"{function_logical_id[:48]}{function_logical_id_hash[:8]}DepLayer" + + @staticmethod + def 
get_layer_name(stack_name: str, function_logical_id: str) -> str: + function_logical_id_hash = str_checksum(function_logical_id) + stack_name_hash = str_checksum(stack_name) + return ( + f"{stack_name[:16]}{stack_name_hash[:8]}-{function_logical_id[:22]}{function_logical_id_hash[:8]}" + f"-DepLayer" + ) + + @staticmethod + def _get_layer_dict(function_logical_id: str, layer_name: str, layer_contents_folder: str, function_runtime: str): + return { + "Type": AWS_SERVERLESS_LAYERVERSION, + "Properties": { + "LayerName": layer_name, + "Description": f"Auto created layer for dependencies of function {function_logical_id}", + "ContentUri": layer_contents_folder, + "RetentionPolicy": "Delete", + "CompatibleRuntimes": [function_runtime], + }, + "Metadata": {CREATED_BY_METADATA_KEY: CREATED_BY_METADATA_VALUE}, + } + + @staticmethod + def get_nested_stack_reference_resource(nested_template_location): + return { + "Type": AWS_CLOUDFORMATION_STACK, + "DeletionPolicy": "Delete", + "Properties": {"TemplateURL": nested_template_location}, + "Metadata": {CREATED_BY_METADATA_KEY: CREATED_BY_METADATA_VALUE}, + } diff --git a/samcli/lib/bootstrap/nested_stack/nested_stack_manager.py b/samcli/lib/bootstrap/nested_stack/nested_stack_manager.py new file mode 100644 index 0000000000..63d241ad43 --- /dev/null +++ b/samcli/lib/bootstrap/nested_stack/nested_stack_manager.py @@ -0,0 +1,194 @@ +""" +nested stack manager to generate nested stack information and update original template with it +""" +import logging +import os +import shutil +from copy import deepcopy +from pathlib import Path +from typing import Dict, Optional, cast + +from samcli.commands._utils.template import move_template +from samcli.lib.bootstrap.nested_stack.nested_stack_builder import NestedStackBuilder +from samcli.lib.build.app_builder import ApplicationBuildResult +from samcli.lib.build.workflow_config import get_layer_subfolder +from samcli.lib.providers.provider import Stack, Function +from samcli.lib.providers.sam_function_provider import SamFunctionProvider +from samcli.lib.sync.exceptions import InvalidRuntimeDefinitionForFunction +from samcli.lib.utils import osutils +from samcli.lib.utils.osutils import BUILD_DIR_PERMISSIONS +from samcli.lib.utils.packagetype import ZIP +from samcli.lib.utils.resources import AWS_SERVERLESS_FUNCTION, AWS_LAMBDA_FUNCTION + +LOG = logging.getLogger(__name__) + +# Resource name of the CFN stack +NESTED_STACK_NAME = "AwsSamAutoDependencyLayerNestedStack" + +# Resources which we support creating dependency layer +SUPPORTED_RESOURCES = {AWS_SERVERLESS_FUNCTION, AWS_LAMBDA_FUNCTION} + +# Languages which we support creating dependency layer +SUPPORTED_LANGUAGES = ("python", "nodejs", "java") + + +class NestedStackManager: + + _stack_name: str + _build_dir: str + _stack_location: str + _current_template: Dict + _app_build_result: ApplicationBuildResult + _nested_stack_builder: NestedStackBuilder + + def __init__( + self, + stack_name: str, + build_dir: str, + stack_location: str, + current_template: Dict, + app_build_result: ApplicationBuildResult, + ): + """ + Parameters + ---------- + stack_name : str + Original stack name, which is used to generate layer name + build_dir : str + Build directory for storing the new nested stack template + stack_location : str + Used to move template and its resources' relative path information + current_template : Dict + Current template of the project + app_build_result: ApplicationBuildResult + Application build result, which contains build graph, and built artifacts 
information
+        """
+        self._stack_name = stack_name
+        self._build_dir = build_dir
+        self._stack_location = stack_location
+        self._current_template = current_template
+        self._app_build_result = app_build_result
+        self._nested_stack_builder = NestedStackBuilder()
+
+    def generate_auto_dependency_layer_stack(self) -> Dict:
+        """
+        Loops through all resources and, for the supported ones (see SUPPORTED_RESOURCES and SUPPORTED_LANGUAGES),
+        creates a layer for their dependencies in a nested stack, and adds a reference to the nested stack back
+        into the original stack
+        """
+        template = deepcopy(self._current_template)
+        resources = template.get("Resources", {})
+
+        stack = Stack("", self._stack_name, self._stack_location, {}, template_dict=template)
+        function_provider = SamFunctionProvider([stack], ignore_code_extraction_warnings=True)
+        zip_functions = [function for function in function_provider.get_all() if function.packagetype == ZIP]
+
+        for zip_function in zip_functions:
+            if not self._is_function_supported(zip_function):
+                continue
+
+            dependencies_dir = self._get_dependencies_dir(zip_function.name)
+            if not dependencies_dir:
+                LOG.debug(
+                    "Dependency folder can't be found for %s, skipping auto dependency layer creation",
+                    zip_function.name,
+                )
+                continue
+
+            self._add_layer(dependencies_dir, zip_function, resources)
+
+        if not self._nested_stack_builder.is_any_function_added():
+            LOG.debug("No function has been added for auto dependency layer creation")
+            return template
+
+        nested_template_location = os.path.join(self._build_dir, "nested_template.yaml")
+        move_template(self._stack_location, nested_template_location, self._nested_stack_builder.build_as_dict())
+
+        resources[NESTED_STACK_NAME] = self._nested_stack_builder.get_nested_stack_reference_resource(
+            nested_template_location
+        )
+        return template
+
+    def _add_layer(self, dependencies_dir: str, function: Function, resources: Dict):
+        layer_logical_id = NestedStackBuilder.get_layer_logical_id(function.name)
+        layer_location = self.update_layer_folder(
+            self._build_dir, dependencies_dir, layer_logical_id, function.name, function.runtime
+        )
+
+        layer_output_key = self._nested_stack_builder.add_function(self._stack_name, layer_location, function)
+
+        # add layer reference back to function
+        function_properties = cast(Dict, resources.get(function.name)).get("Properties", {})
+        function_layers = function_properties.get("Layers", [])
+        function_layers.append({"Fn::GetAtt": [NESTED_STACK_NAME, f"Outputs.{layer_output_key}"]})
+        function_properties["Layers"] = function_layers
+
+    @staticmethod
+    def _add_layer_readme_info(dependencies_dir: str, function_name: str):
+        # add a simple README file for discoverability
+        with open(os.path.join(dependencies_dir, "AWS_SAM_CLI_README"), "w+") as f:
+            f.write(
+                f"This layer contains dependencies of function {function_name} "
+                "and was automatically added by the AWS SAM CLI command 'sam sync'"
+            )
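The effect of _add_layer on the parent template is easiest to see on a concrete resource. A sketch of the rewritten function properties, with hypothetical logical IDs (the nested stack name is the NESTED_STACK_NAME constant above):

# Before: a plain function with no auto-created layers.
resources = {"HelloWorldFunction": {"Properties": {"Runtime": "python3.9"}}}

# After _add_layer: the function references the dependency layer that the
# nested stack exports as an output (the layer logical id is illustrative).
resources["HelloWorldFunction"]["Properties"]["Layers"] = [
    {"Fn::GetAtt": ["AwsSamAutoDependencyLayerNestedStack", "Outputs.HelloWorldFunc12ab34cdDepLayer"]}
]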
+    @staticmethod
+    def update_layer_folder(
+        build_dir: str,
+        dependencies_dir: str,
+        layer_logical_id: str,
+        function_logical_id: str,
+        function_runtime: Optional[str],
+    ) -> str:
+        """
+        Creates the build folder for the auto dependency layer by moving dependencies into a sub folder which is
+        defined by the runtime
+        """
+        if not function_runtime:
+            raise InvalidRuntimeDefinitionForFunction(function_logical_id)
+
+        layer_root_folder = Path(build_dir).joinpath(layer_logical_id)
+        if layer_root_folder.exists():
+            shutil.rmtree(layer_root_folder)
+        layer_contents_folder = layer_root_folder.joinpath(get_layer_subfolder(function_runtime))
+        layer_contents_folder.mkdir(BUILD_DIR_PERMISSIONS, parents=True)
+        if os.path.isdir(dependencies_dir):
+            osutils.copytree(dependencies_dir, str(layer_contents_folder))
+        NestedStackManager._add_layer_readme_info(str(layer_root_folder), function_logical_id)
+        return str(layer_root_folder)
+
+    def _is_function_supported(self, function: Function):
+        """
+        Checks if the function was built in the current session and its runtime is supported
+        """
+        # check if function is built
+        if function.name not in self._app_build_result.artifacts.keys():
+            LOG.debug(
+                "Function %s is not built within SAM CLI, skipping for auto dependency layer creation",
+                function.name,
+            )
+            return False
+
+        return self.is_runtime_supported(function.runtime)
+
+    @staticmethod
+    def is_runtime_supported(runtime: Optional[str]) -> bool:
+        # check if runtime/language is supported
+        if not runtime or not runtime.startswith(SUPPORTED_LANGUAGES):
+            LOG.debug(
+                "Runtime %s is not supported for auto dependency layer creation",
+                runtime,
+            )
+            return False
+
+        return True
+
+    def _get_dependencies_dir(self, function_logical_id: str) -> Optional[str]:
+        """
+        Returns dependency directory information for function
+        """
+        function_build_definition = self._app_build_result.build_graph.get_function_build_definition_with_logical_id(
+            function_logical_id
+        )
+
+        return function_build_definition.dependencies_dir if function_build_definition else None
diff --git a/samcli/lib/bootstrap/stack_builder.py b/samcli/lib/bootstrap/stack_builder.py
new file mode 100644
index 0000000000..0654dd011f
--- /dev/null
+++ b/samcli/lib/bootstrap/stack_builder.py
@@ -0,0 +1,58 @@
+"""
+Abstract definitions for stack builder
+"""
+import json
+from abc import ABC
+from copy import deepcopy
+from typing import Dict, Union, cast
+
+from samcli import __version__ as VERSION
+
+METADATA_FIELD = "Metadata"
+RESOURCES_FIELD = "Resources"
+OUTPUTS_FIELD = "Outputs"
+
+DEFAULT_TEMPLATE_BEGINNER = {
+    "AWSTemplateFormatVersion": "2010-09-09",
+    "Transform": "AWS::Serverless-2016-10-31",
+    METADATA_FIELD: {"SamCliInfo": VERSION},
+    RESOURCES_FIELD: {},
+    OUTPUTS_FIELD: {},
+}
+
+
+class AbstractStackBuilder(ABC):
+    """
+    AbstractStackBuilder implementation which holds common methods for adding resources/properties
+    and generating SAM template
+    """
+
+    _template_dict: Dict
+
+    def __init__(self, description: str):
+        self._template_dict = deepcopy(DEFAULT_TEMPLATE_BEGINNER)
+        self._template_dict["Description"] = description
+
+    def add_metadata(self, key: str, value: Union[str, Dict]) -> None:
+        if METADATA_FIELD not in self._template_dict:
+            self._template_dict[METADATA_FIELD] = {}
+        metadata = cast(Dict, self._template_dict.get(METADATA_FIELD))
+        metadata[key] = value
+
+    def add_resource(self, resource_name: str, resource_dict: Dict) -> None:
+        if RESOURCES_FIELD not in self._template_dict:
+            self._template_dict[RESOURCES_FIELD] = {}
+        resources = cast(Dict, self._template_dict.get(RESOURCES_FIELD))
+        resources[resource_name] = resource_dict
+
+    def add_output(self, output_name: str, output_value: Union[Dict, str]) -> None:
+        if OUTPUTS_FIELD not in self._template_dict:
+            self._template_dict[OUTPUTS_FIELD] = {}
+        outputs = cast(Dict, self._template_dict.get(OUTPUTS_FIELD))
+        outputs[output_name] = {"Value": output_value}
+
+    def build_as_dict(self) -> Dict:
+        return deepcopy(self._template_dict)
+
+    def build(self) -> str:
+        return json.dumps(self._template_dict, indent=2)
diff --git
a/samcli/lib/build/app_builder.py b/samcli/lib/build/app_builder.py index 5930e3f685..0cf3b67954 100644 --- a/samcli/lib/build/app_builder.py +++ b/samcli/lib/build/app_builder.py @@ -6,11 +6,14 @@ import json import logging import pathlib -from typing import List, Optional, Dict, cast, Union +from typing import List, Optional, Dict, cast, Union, NamedTuple import docker import docker.errors -from aws_lambda_builders import RPC_PROTOCOL_VERSION as lambda_builders_protocol_version +from aws_lambda_builders import ( + RPC_PROTOCOL_VERSION as lambda_builders_protocol_version, + __version__ as lambda_builders_version, +) from aws_lambda_builders.builder import LambdaBuilder from aws_lambda_builders.exceptions import LambdaBuilderError @@ -18,13 +21,20 @@ from samcli.lib.build.build_graph import FunctionBuildDefinition, LayerBuildDefinition, BuildGraph from samcli.lib.build.build_strategy import ( DefaultBuildStrategy, - CachedBuildStrategy, + CachedOrIncrementalBuildStrategyWrapper, ParallelBuildStrategy, BuildStrategy, ) +from samcli.lib.utils.resources import ( + AWS_CLOUDFORMATION_STACK, + AWS_LAMBDA_FUNCTION, + AWS_LAMBDA_LAYERVERSION, + AWS_SERVERLESS_APPLICATION, + AWS_SERVERLESS_FUNCTION, + AWS_SERVERLESS_LAYERVERSION, +) from samcli.lib.docker.log_streamer import LogStreamer, LogStreamError from samcli.lib.providers.provider import ResourcesToBuildCollector, Function, get_full_path, Stack, LayerVersion -from samcli.lib.providers.sam_base_provider import SamBaseProvider from samcli.lib.utils.colors import Colored from samcli.lib.utils import osutils from samcli.lib.utils.packagetype import IMAGE, ZIP @@ -47,6 +57,15 @@ LOG = logging.getLogger(__name__) +class ApplicationBuildResult(NamedTuple): + """ + Result of the application build, build_graph and the built artifacts in dictionary + """ + + build_graph: BuildGraph + artifacts: Dict[str, str] + + class ApplicationBuilder: """ Class to build an entire application. Currently, this class builds Lambda functions only, but there is nothing that @@ -71,6 +90,7 @@ def __init__( container_env_var: Optional[Dict] = None, container_env_var_file: Optional[str] = None, build_images: Optional[Dict] = None, + combine_dependencies: bool = True, ) -> None: """ Initialize the class @@ -109,6 +129,9 @@ def __init__( An optional path to file that contains environment variables to pass to the container build_images : Optional[Dict] An optional dictionary of build images to be used for building functions + combine_dependencies: bool + An optional bool parameter to inform lambda builders whether we should separate the source code and + dependencies or not. 
""" self._resources_to_build = resources_to_build self._build_dir = build_dir @@ -129,15 +152,17 @@ def __init__( self._container_env_var = container_env_var self._container_env_var_file = container_env_var_file self._build_images = build_images or {} + self._combine_dependencies = combine_dependencies - def build(self) -> Dict[str, str]: + def build(self) -> ApplicationBuildResult: """ Build the entire application Returns ------- - dict - Returns the path to where each resource was built as a map of resource's LogicalId to the path string + ApplicationBuildResult + Returns the build graph and the path to where each resource was built as a map of resource's LogicalId + to the path string """ build_graph = self._get_build_graph(self._container_env_var, self._container_env_var_file) build_strategy: BuildStrategy = DefaultBuildStrategy( @@ -148,28 +173,30 @@ def build(self) -> Dict[str, str]: if self._cached: build_strategy = ParallelBuildStrategy( build_graph, - CachedBuildStrategy( + CachedOrIncrementalBuildStrategyWrapper( build_graph, build_strategy, self._base_dir, self._build_dir, self._cache_dir, + self._manifest_path_override, self._is_building_specific_resource, ), ) else: build_strategy = ParallelBuildStrategy(build_graph, build_strategy) elif self._cached: - build_strategy = CachedBuildStrategy( + build_strategy = CachedOrIncrementalBuildStrategyWrapper( build_graph, build_strategy, self._base_dir, self._build_dir, self._cache_dir, + self._manifest_path_override, self._is_building_specific_resource, ) - return build_strategy.build() + return ApplicationBuildResult(build_graph, build_strategy.build()) def _get_build_graph( self, inline_env_vars: Optional[Dict] = None, env_vars_file: Optional[str] = None @@ -279,26 +306,26 @@ def update_template( store_path = os.path.relpath(absolute_output_path, original_dir) if has_build_artifact: - if resource_type == SamBaseProvider.SERVERLESS_FUNCTION and properties.get("PackageType", ZIP) == ZIP: + if resource_type == AWS_SERVERLESS_FUNCTION and properties.get("PackageType", ZIP) == ZIP: properties["CodeUri"] = store_path - if resource_type == SamBaseProvider.LAMBDA_FUNCTION and properties.get("PackageType", ZIP) == ZIP: + if resource_type == AWS_LAMBDA_FUNCTION and properties.get("PackageType", ZIP) == ZIP: properties["Code"] = store_path - if resource_type in [SamBaseProvider.SERVERLESS_LAYER, SamBaseProvider.LAMBDA_LAYER]: + if resource_type in [AWS_SERVERLESS_LAYERVERSION, AWS_LAMBDA_LAYERVERSION]: properties["ContentUri"] = store_path - if resource_type == SamBaseProvider.LAMBDA_FUNCTION and properties.get("PackageType", ZIP) == IMAGE: + if resource_type == AWS_LAMBDA_FUNCTION and properties.get("PackageType", ZIP) == IMAGE: properties["Code"] = built_artifacts[full_path] - if resource_type == SamBaseProvider.SERVERLESS_FUNCTION and properties.get("PackageType", ZIP) == IMAGE: + if resource_type == AWS_SERVERLESS_FUNCTION and properties.get("PackageType", ZIP) == IMAGE: properties["ImageUri"] = built_artifacts[full_path] if is_stack: - if resource_type == SamBaseProvider.SERVERLESS_APPLICATION: + if resource_type == AWS_SERVERLESS_APPLICATION: properties["Location"] = store_path - if resource_type == SamBaseProvider.CLOUDFORMATION_STACK: + if resource_type == AWS_CLOUDFORMATION_STACK: properties["TemplateURL"] = store_path return template_dict @@ -398,6 +425,8 @@ def _build_layer( architecture: str, artifact_dir: str, container_env_vars: Optional[Dict] = None, + dependencies_dir: Optional[str] = None, + download_dependencies: bool = True, 
) -> str: """ Given the layer information, this method will build the Lambda layer. Depending on the configuration @@ -407,25 +436,25 @@ def _build_layer( ---------- layer_name : str Name or LogicalId of the function - codeuri : str Path to where the code lives - specified_workflow : str The specified workflow - compatible_runtimes : List[str] List of runtimes the layer build is compatible with - architecture : str The architecture type 'x86_64' and 'arm64' in AWS - artifact_dir : str Path to where layer will be build into. A subfolder will be created in this directory depending on the specified workflow. - container_env_vars : Optional[Dict] An optional dictionary of environment variables to pass to the container. + dependencies_dir: Optional[str] + An optional string parameter which will be used in lambda builders for downloading dependencies into + separate folder + download_dependencies: bool + An optional boolean parameter to inform lambda builders whether download dependencies or use previously + downloaded ones. Default value is True. Returns ------- @@ -480,6 +509,8 @@ def _build_layer( build_runtime, architecture, options, + dependencies_dir, + download_dependencies, ) # Not including subfolder in return so that we copy subfolder, instead of copying artifacts inside it. @@ -496,6 +527,8 @@ def _build_function( # pylint: disable=R1710 artifact_dir: str, metadata: Optional[Dict] = None, container_env_vars: Optional[Dict] = None, + dependencies_dir: Optional[str] = None, + download_dependencies: bool = True, ) -> str: """ Given the function information, this method will build the Lambda function. Depending on the configuration @@ -521,6 +554,12 @@ def _build_function( # pylint: disable=R1710 AWS Lambda function metadata container_env_vars : Optional[Dict] An optional dictionary of environment variables to pass to the container. + dependencies_dir: Optional[str] + An optional string parameter which will be used in lambda builders for downloading dependencies into + separate folder + download_dependencies: bool + An optional boolean parameter to inform lambda builders whether download dependencies or use previously + downloaded ones. Default value is True. 
Returns ------- @@ -575,7 +614,16 @@ def _build_function( # pylint: disable=R1710 ) return self._build_function_in_process( - config, code_dir, artifact_dir, scratch_dir, manifest_path, runtime, architecture, options + config, + code_dir, + artifact_dir, + scratch_dir, + manifest_path, + runtime, + architecture, + options, + dependencies_dir, + download_dependencies, ) # pylint: disable=fixme @@ -615,6 +663,8 @@ def _build_function_in_process( runtime: str, architecture: str, options: Optional[Dict], + dependencies_dir: Optional[str], + download_dependencies: bool, ) -> str: builder = LambdaBuilder( @@ -636,6 +686,9 @@ def _build_function_in_process( mode=self._mode, options=options, architecture=architecture, + dependencies_dir=dependencies_dir, + download_dependencies=download_dependencies, + combine_dependencies=self._combine_dependencies, ) except LambdaBuilderError as ex: raise BuildError(wrapped_from=ex.__class__.__name__, msg=str(ex)) from ex diff --git a/samcli/lib/build/build_graph.py b/samcli/lib/build/build_graph.py index d719cedb83..bec5df43cd 100644 --- a/samcli/lib/build/build_graph.py +++ b/samcli/lib/build/build_graph.py @@ -2,10 +2,13 @@ Holds classes and utility methods related to build graph """ +import copy import logging -from copy import deepcopy +import os +import threading from pathlib import Path -from typing import Tuple, List, Any, Optional, Dict, cast +from typing import Sequence, Tuple, List, Any, Optional, Dict, cast, NamedTuple +from copy import deepcopy from uuid import uuid4 import tomlkit @@ -20,13 +23,16 @@ DEFAULT_BUILD_GRAPH_FILE_NAME = "build.toml" +DEFAULT_DEPENDENCIES_DIR = os.path.join(".aws-sam", "deps") + # filed names for the toml table PACKAGETYPE_FIELD = "packagetype" CODE_URI_FIELD = "codeuri" RUNTIME_FIELD = "runtime" METADATA_FIELD = "metadata" FUNCTIONS_FIELD = "functions" -SOURCE_MD5_FIELD = "source_md5" +SOURCE_HASH_FIELD = "source_hash" +MANIFEST_HASH_FIELD = "manifest_hash" ENV_VARS_FIELD = "env_vars" LAYER_NAME_FIELD = "layer_name" BUILD_METHOD_FIELD = "build_method" @@ -55,8 +61,10 @@ def _function_build_definition_to_toml_table( if function_build_definition.packagetype == ZIP: toml_table[CODE_URI_FIELD] = function_build_definition.codeuri toml_table[RUNTIME_FIELD] = function_build_definition.runtime - toml_table[SOURCE_MD5_FIELD] = function_build_definition.source_md5 toml_table[ARCHITECTURE_FIELD] = function_build_definition.architecture + if function_build_definition.source_hash: + toml_table[SOURCE_HASH_FIELD] = function_build_definition.source_hash + toml_table[MANIFEST_HASH_FIELD] = function_build_definition.manifest_hash toml_table[PACKAGETYPE_FIELD] = function_build_definition.packagetype toml_table[FUNCTIONS_FIELD] = [f.full_path for f in function_build_definition.functions] @@ -90,7 +98,8 @@ def _toml_table_to_function_build_definition(uuid: str, toml_table: tomlkit.api. 
toml_table.get(PACKAGETYPE_FIELD, ZIP), toml_table.get(ARCHITECTURE_FIELD, X86_64), dict(toml_table.get(METADATA_FIELD, {})), - toml_table.get(SOURCE_MD5_FIELD, ""), + toml_table.get(SOURCE_HASH_FIELD, ""), + toml_table.get(MANIFEST_HASH_FIELD, ""), dict(toml_table.get(ENV_VARS_FIELD, {})), ) function_build_definition.uuid = uuid @@ -116,9 +125,10 @@ def _layer_build_definition_to_toml_table(layer_build_definition: "LayerBuildDef toml_table[CODE_URI_FIELD] = layer_build_definition.codeuri toml_table[BUILD_METHOD_FIELD] = layer_build_definition.build_method toml_table[COMPATIBLE_RUNTIMES_FIELD] = layer_build_definition.compatible_runtimes - toml_table[SOURCE_MD5_FIELD] = layer_build_definition.source_md5 - toml_table[LAYER_FIELD] = layer_build_definition.layer.name toml_table[ARCHITECTURE_FIELD] = layer_build_definition.architecture + if layer_build_definition.source_hash: + toml_table[SOURCE_HASH_FIELD] = layer_build_definition.source_hash + toml_table[MANIFEST_HASH_FIELD] = layer_build_definition.manifest_hash if layer_build_definition.env_vars: toml_table[ENV_VARS_FIELD] = layer_build_definition.env_vars toml_table[LAYER_FIELD] = layer_build_definition.layer.full_path @@ -148,18 +158,31 @@ def _toml_table_to_layer_build_definition(uuid: str, toml_table: tomlkit.api.Tab toml_table.get(BUILD_METHOD_FIELD), toml_table.get(COMPATIBLE_RUNTIMES_FIELD), toml_table.get(ARCHITECTURE_FIELD, X86_64), - toml_table.get(SOURCE_MD5_FIELD, ""), + toml_table.get(SOURCE_HASH_FIELD, ""), + toml_table.get(MANIFEST_HASH_FIELD, ""), dict(toml_table.get(ENV_VARS_FIELD, {})), ) layer_build_definition.uuid = uuid return layer_build_definition +class BuildHashingInformation(NamedTuple): + """ + Holds hashing information for the source folder and the manifest file + """ + + source_hash: str + manifest_hash: str + + class BuildGraph: """ Contains list of build definitions, with ability to read and write them into build.toml file """ + # private lock for build.toml reads and writes + __toml_lock = threading.Lock() + # global table build definitions key FUNCTION_BUILD_DEFINITIONS = "function_build_definitions" LAYER_BUILD_DEFINITIONS = "layer_build_definitions" @@ -169,7 +192,7 @@ def __init__(self, build_dir: str) -> None: self._filepath = Path(build_dir).parent.joinpath(DEFAULT_BUILD_GRAPH_FILE_NAME) self._function_build_definitions: List["FunctionBuildDefinition"] = [] self._layer_build_definitions: List["LayerBuildDefinition"] = [] - self._read() + self._atomic_read() def get_function_build_definitions(self) -> Tuple["FunctionBuildDefinition", ...]: return tuple(self._function_build_definitions) @@ -177,6 +200,29 @@ def get_function_build_definitions(self) -> Tuple["FunctionBuildDefinition", ... def get_layer_build_definitions(self) -> Tuple["LayerBuildDefinition", ...]: return tuple(self._layer_build_definitions) + def get_function_build_definition_with_logical_id( + self, function_logial_id: str + ) -> Optional["FunctionBuildDefinition"]: + """ + Returns FunctionBuildDefinition instance of given function logical id. 
+ + Parameters + ---------- + function_logial_id : str + Function logical id that will be searched in the function build definitions + + Returns + ------- + Optional[FunctionBuildDefinition] + If a function build definition found returns it, otherwise returns None + + """ + for function_build_definition in self._function_build_definitions: + for build_definition_function in function_build_definition.functions: + if build_definition_function.name == function_logial_id: + return function_build_definition + return None + def put_function_build_definition( self, function_build_definition: "FunctionBuildDefinition", function: Function ) -> None: @@ -261,7 +307,82 @@ def clean_redundant_definitions_and_update(self, persist: bool) -> None: ] self._layer_build_definitions[:] = [bd for bd in self._layer_build_definitions if bd.layer] if persist: - self._write() + self._atomic_write() + + def update_definition_hash(self) -> None: + """ + Updates the build.toml file with the newest source_hash values of the partial build's definitions + + This operation is atomic, that no other thread accesses build.toml + during the process of reading and modifying the hash value + """ + with BuildGraph.__toml_lock: + stored_definitions = copy.deepcopy(self._function_build_definitions) + stored_layers = copy.deepcopy(self._layer_build_definitions) + self._read() + + function_content = BuildGraph._compare_hash_changes(stored_definitions, self._function_build_definitions) + layer_content = BuildGraph._compare_hash_changes(stored_layers, self._layer_build_definitions) + + if function_content or layer_content: + self._write_source_hash(function_content, layer_content) + + @staticmethod + def _compare_hash_changes( + input_list: Sequence["AbstractBuildDefinition"], compared_list: Sequence["AbstractBuildDefinition"] + ) -> Dict[str, BuildHashingInformation]: + """ + Helper to compare the function and layer definition changes in hash value + + Returns a dictionary that has uuid as key, updated hash value as value + """ + content = {} + for compared_def in compared_list: + for stored_def in input_list: + if stored_def == compared_def: + old_hash = compared_def.source_hash + updated_hash = stored_def.source_hash + old_manifest_hash = compared_def.manifest_hash + updated_manifest_hash = stored_def.manifest_hash + uuid = stored_def.uuid + if old_hash != updated_hash or old_manifest_hash != updated_manifest_hash: + content[uuid] = BuildHashingInformation(updated_hash, updated_manifest_hash) + compared_def.download_dependencies = old_manifest_hash != updated_manifest_hash + return content + + def _write_source_hash( + self, function_content: Dict[str, BuildHashingInformation], layer_content: Dict[str, BuildHashingInformation] + ) -> None: + """ + Helper to write source_hash values to build.toml file + """ + document = {} + if not self._filepath.exists(): + open(self._filepath, "a+").close() + + txt = self._filepath.read_text() + # .loads() returns a TOMLDocument, + # and it behaves like a standard dictionary according to https://github.com/sdispater/tomlkit. + # in tomlkit 0.7.2, the types are broken (tomlkit#128, #130, #134) so here we convert it to Dict. 
+ document = cast(Dict, tomlkit.loads(txt)) + + for function_uuid, hashing_info in function_content.items(): + if function_uuid in document.get(BuildGraph.FUNCTION_BUILD_DEFINITIONS, {}): + function_build_definition = document[BuildGraph.FUNCTION_BUILD_DEFINITIONS][function_uuid] + function_build_definition[SOURCE_HASH_FIELD] = hashing_info.source_hash + function_build_definition[MANIFEST_HASH_FIELD] = hashing_info.manifest_hash + LOG.info( + "Updated source_hash and manifest_hash field in build.toml for function with UUID %s", function_uuid + ) + + for layer_uuid, hashing_info in layer_content.items(): + if layer_uuid in document.get(BuildGraph.LAYER_BUILD_DEFINITIONS, {}): + layer_build_definition = document[BuildGraph.LAYER_BUILD_DEFINITIONS][layer_uuid] + layer_build_definition[SOURCE_HASH_FIELD] = hashing_info.source_hash + layer_build_definition[MANIFEST_HASH_FIELD] = hashing_info.manifest_hash + LOG.info("Updated source_hash and manifest_hash field in build.toml for layer with UUID %s", layer_uuid) + + self._filepath.write_text(tomlkit.dumps(document)) # type: ignore def _read(self) -> None: """ @@ -280,20 +401,29 @@ def _read(self) -> None: document = cast(Dict, tomlkit.loads(txt)) except OSError: LOG.debug("No previous build graph found, generating new one") - function_build_definitions_table = document.get(BuildGraph.FUNCTION_BUILD_DEFINITIONS, []) + function_build_definitions_table = document.get(BuildGraph.FUNCTION_BUILD_DEFINITIONS, {}) for function_build_definition_key in function_build_definitions_table: function_build_definition = _toml_table_to_function_build_definition( function_build_definition_key, function_build_definitions_table[function_build_definition_key] ) self._function_build_definitions.append(function_build_definition) - layer_build_definitions_table = document.get(BuildGraph.LAYER_BUILD_DEFINITIONS, []) + layer_build_definitions_table = document.get(BuildGraph.LAYER_BUILD_DEFINITIONS, {}) for layer_build_definition_key in layer_build_definitions_table: layer_build_definition = _toml_table_to_layer_build_definition( layer_build_definition_key, layer_build_definitions_table[layer_build_definition_key] ) self._layer_build_definitions.append(layer_build_definition) + def _atomic_read(self) -> None: + """ + Performs the _read() method with a global lock acquired + It makes sure no other thread accesses build.toml when a read is happening + """ + + with BuildGraph.__toml_lock: + self._read() + def _write(self) -> None: """ Writes build definition details into build.toml file, which would be used by the next build. 
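The _atomic_read above is one half of a simple pattern: a single class-level lock serializes every read and write of build.toml. A minimal, self-contained sketch of the same pattern (illustrative names, not the SAM CLI implementation):

```python
import threading
from pathlib import Path


class LockedTomlFile:
    # one lock shared by every instance, mirroring BuildGraph.__toml_lock
    _lock = threading.Lock()

    def __init__(self, path: str) -> None:
        self._path = Path(path)

    def _read(self) -> str:
        # unlocked helper; callers are expected to hold _lock
        return self._path.read_text() if self._path.exists() else ""

    def _write(self, text: str) -> None:
        self._path.write_text(text)

    def atomic_read(self) -> str:
        # serialize against every other read/write of the same file
        with LockedTomlFile._lock:
            return self._read()

    def atomic_write(self, text: str) -> None:
        with LockedTomlFile._lock:
            self._write(text)
```

The _atomic_write in the hunk below pairs with it; both delegate to the unlocked helpers so the lock is taken exactly once per operation.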
@@ -324,6 +454,15 @@ def _write(self) -> None: self._filepath.write_text(tomlkit.dumps(document)) + def _atomic_write(self) -> None: + """ + Performs the _write() method with a global lock acquired + It makes sure no other thread accesses build.toml when a write is happening + """ + + with BuildGraph.__toml_lock: + self._write() + class AbstractBuildDefinition: """ @@ -331,11 +470,20 @@ class AbstractBuildDefinition: Build definition holds information about each unique build """ - def __init__(self, source_md5: str, env_vars: Optional[Dict] = None, architecture: str = X86_64) -> None: + def __init__( + self, source_hash: str, manifest_hash: str, env_vars: Optional[Dict] = None, architecture: str = X86_64 + ) -> None: self.uuid = str(uuid4()) - self.source_md5 = source_md5 + self.source_hash = source_hash + self.manifest_hash = manifest_hash self._env_vars = env_vars if env_vars else {} self.architecture = architecture + # following properties are used during build time and they don't serialize into build.toml file + self.download_dependencies: bool = True + + @property + def dependencies_dir(self) -> str: + return str(os.path.join(DEFAULT_DEPENDENCIES_DIR, self.uuid)) @property def env_vars(self) -> Dict: @@ -354,10 +502,11 @@ def __init__( build_method: Optional[str], compatible_runtimes: Optional[List[str]], architecture: str, - source_md5: str = "", + source_hash: str = "", + manifest_hash: str = "", env_vars: Optional[Dict] = None, ): - super().__init__(source_md5, env_vars, architecture) + super().__init__(source_hash, manifest_hash, env_vars, architecture) self.name = name self.codeuri = codeuri self.build_method = build_method @@ -368,7 +517,7 @@ def __init__( def __str__(self) -> str: return ( - f"LayerBuildDefinition({self.name}, {self.codeuri}, {self.source_md5}, {self.uuid}, " + f"LayerBuildDefinition({self.name}, {self.codeuri}, {self.source_hash}, {self.uuid}, " f"{self.build_method}, {self.compatible_runtimes}, {self.architecture}, {self.env_vars})" ) @@ -411,10 +560,11 @@ def __init__( packagetype: str, architecture: str, metadata: Optional[Dict], - source_md5: str = "", + source_hash: str = "", + manifest_hash: str = "", env_vars: Optional[Dict] = None, ) -> None: - super().__init__(source_md5, env_vars, architecture) + super().__init__(source_hash, manifest_hash, env_vars, architecture) self.runtime = runtime self.codeuri = codeuri self.packagetype = packagetype @@ -453,8 +603,8 @@ def _validate_functions(self) -> None: def __str__(self) -> str: return ( "BuildDefinition(" - f"{self.runtime}, {self.codeuri}, {self.packagetype}, {self.architecture}, " - f"{self.source_md5}, {self.uuid}, {self.metadata}, {self.env_vars}, " + f"{self.runtime}, {self.codeuri}, {self.packagetype}, {self.source_hash}, " + f"{self.uuid}, {self.metadata}, {self.env_vars}, {self.architecture}, " f"{[f.functionname for f in self.functions]})" ) diff --git a/samcli/lib/build/build_strategy.py b/samcli/lib/build/build_strategy.py index e04e6058df..c12cc9b40d 100644 --- a/samcli/lib/build/build_strategy.py +++ b/samcli/lib/build/build_strategy.py @@ -1,23 +1,54 @@ """ Keeps implementation of different build strategies """ +import hashlib import logging import pathlib import shutil from abc import abstractmethod, ABC -from typing import Callable, Dict, List, Any, Optional, cast +from copy import deepcopy +from typing import Callable, Dict, List, Any, Optional, cast, Set -from samcli.commands.build.exceptions import MissingBuildMethodException +from samcli.commands._utils.experimental import 
is_experimental_enabled, ExperimentalFlag from samcli.lib.utils import osutils from samcli.lib.utils.async_utils import AsyncContext from samcli.lib.utils.hash import dir_checksum from samcli.lib.utils.packagetype import ZIP, IMAGE -from samcli.lib.build.build_graph import BuildGraph, FunctionBuildDefinition, LayerBuildDefinition +from samcli.lib.build.dependency_hash_generator import DependencyHashGenerator +from samcli.lib.build.build_graph import ( + BuildGraph, + FunctionBuildDefinition, + LayerBuildDefinition, + AbstractBuildDefinition, + DEFAULT_DEPENDENCIES_DIR, +) +from samcli.lib.build.exceptions import MissingBuildMethodException LOG = logging.getLogger(__name__) +def clean_redundant_folders(base_dir: str, uuids: Set[str]) -> None: + """ + Compares existing folders inside base_dir and removes the ones which is not in the uuids set. + + Parameters + ---------- + base_dir : str + Base directory that it will be operating + uuids : Set[str] + Expected folder names. If any folder name in the base_dir is not present in this Set, it will be deleted. + """ + base_dir_path = pathlib.Path(base_dir) + + if not base_dir_path.exists(): + return + + for full_dir_path in pathlib.Path(base_dir).iterdir(): + if full_dir_path.name not in uuids: + shutil.rmtree(pathlib.Path(base_dir, full_dir_path.name)) + + class BuildStrategy(ABC): """ Base class for BuildStrategy @@ -88,8 +119,8 @@ def __init__( self, build_graph: BuildGraph, build_dir: str, - build_function: Callable[[str, str, str, str, str, Optional[str], str, dict, dict], str], - build_layer: Callable[[str, str, str, List[str], str, str, dict], str], + build_function: Callable[[str, str, str, str, str, Optional[str], str, dict, dict, Optional[str], bool], str], + build_layer: Callable[[str, str, str, List[str], str, str, dict, Optional[str], bool], str], ) -> None: super().__init__(build_graph) self._build_dir = build_dir @@ -116,6 +147,10 @@ def build_single_function_definition(self, build_definition: FunctionBuildDefini LOG.debug("Building to following folder %s", single_build_dir) + # we should create a copy and pass it down, otherwise additional env vars like LAMBDA_BUILDERS_LOG_LEVEL + # will make cache invalid all the time + container_env_vars = deepcopy(build_definition.env_vars) + # when a function is passed here, it is ZIP function, codeuri and runtime are not None result = self._build_function( build_definition.get_function_name(), @@ -126,7 +161,9 @@ def build_single_function_definition(self, build_definition: FunctionBuildDefini build_definition.get_handler_name(), single_build_dir, build_definition.metadata, - build_definition.env_vars, + container_env_vars, + build_definition.dependencies_dir, + build_definition.download_dependencies, ) function_build_results[single_full_path] = result @@ -172,6 +209,8 @@ def build_single_layer_definition(self, layer_definition: LayerBuildDefinition) layer.build_architecture, single_build_dir, layer_definition.env_vars, + layer_definition.dependencies_dir, + layer_definition.download_dependencies, ) } @@ -180,7 +219,7 @@ class CachedBuildStrategy(BuildStrategy): """ Cached implementation of Build Strategy For each function and layer, it first checks if there is a valid cache, and if there is, it copies from previous - build. If caching is invalid, it builds function or layer from scratch and updates cache folder and md5 of the + build. If caching is invalid, it builds function or layer from scratch and updates cache folder and hash of the function or layer. 
For actual building, it uses delegate implementation """ @@ -192,21 +231,16 @@ def __init__( base_dir: str, build_dir: str, cache_dir: str, - is_building_specific_resource: bool, ) -> None: super().__init__(build_graph) self._delegate_build_strategy = delegate_build_strategy self._base_dir = base_dir self._build_dir = build_dir self._cache_dir = cache_dir - self._is_building_specific_resource = is_building_specific_resource - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - self._clean_redundant_cached() def build(self) -> Dict[str, str]: result = {} - with self, self._delegate_build_strategy: + with self._delegate_build_strategy: result.update(super().build()) return result @@ -218,11 +252,11 @@ def build_single_function_definition(self, build_definition: FunctionBuildDefini return self._delegate_build_strategy.build_single_function_definition(build_definition) code_dir = str(pathlib.Path(self._base_dir, cast(str, build_definition.codeuri)).resolve()) - source_md5 = dir_checksum(code_dir, ignore_list=[".aws-sam"]) + source_hash = dir_checksum(code_dir, ignore_list=[".aws-sam"], hash_generator=hashlib.sha256()) cache_function_dir = pathlib.Path(self._cache_dir, build_definition.uuid) function_build_results = {} - if not cache_function_dir.exists() or build_definition.source_md5 != source_md5: + if not cache_function_dir.exists() or build_definition.source_hash != source_hash: LOG.info( "Cache is invalid, running build and copying resources to function build definition of %s", build_definition.uuid, @@ -233,7 +267,7 @@ def build_single_function_definition(self, build_definition: FunctionBuildDefini if cache_function_dir.exists(): shutil.rmtree(str(cache_function_dir)) - build_definition.source_md5 = source_md5 + build_definition.source_hash = source_hash # Since all the build contents are same for a build definition, just copy any one of them into the cache for _, value in build_result.items(): osutils.copytree(value, cache_function_dir) @@ -257,11 +291,11 @@ def build_single_layer_definition(self, layer_definition: LayerBuildDefinition) Builds single layer definition with caching """ code_dir = str(pathlib.Path(self._base_dir, cast(str, layer_definition.codeuri)).resolve()) - source_md5 = dir_checksum(code_dir, ignore_list=[".aws-sam"]) + source_hash = dir_checksum(code_dir, ignore_list=[".aws-sam"], hash_generator=hashlib.sha256()) cache_function_dir = pathlib.Path(self._cache_dir, layer_definition.uuid) layer_build_result = {} - if not cache_function_dir.exists() or layer_definition.source_md5 != source_md5: + if not cache_function_dir.exists() or layer_definition.source_hash != source_hash: LOG.info( "Cache is invalid, running build and copying resources to layer build definition of %s", layer_definition.uuid, @@ -272,7 +306,7 @@ def build_single_layer_definition(self, layer_definition: LayerBuildDefinition) if cache_function_dir.exists(): shutil.rmtree(str(cache_function_dir)) - layer_definition.source_md5 = source_md5 + layer_definition.source_hash = source_hash # Since all the build contents are same for a build definition, just copy any one of them into the cache for _, value in build_result.items(): osutils.copytree(value, cache_function_dir) @@ -294,12 +328,9 @@ def _clean_redundant_cached(self) -> None: """ clean the redundant cached folder """ - self._build_graph.clean_redundant_definitions_and_update(not self._is_building_specific_resource) uuids = {bd.uuid for bd in self._build_graph.get_function_build_definitions()} uuids.update({ld.uuid for ld in 
self._build_graph.get_layer_build_definitions()}) - for cache_dir in pathlib.Path(self._cache_dir).iterdir(): - if cache_dir.name not in uuids: - shutil.rmtree(pathlib.Path(self._cache_dir, cache_dir.name)) + clean_redundant_folders(self._cache_dir, uuids) class ParallelBuildStrategy(BuildStrategy): @@ -313,18 +344,18 @@ def __init__( self, build_graph: BuildGraph, delegate_build_strategy: BuildStrategy, - async_context: AsyncContext = AsyncContext(), + async_context: Optional[AsyncContext] = None, ) -> None: super().__init__(build_graph) self._delegate_build_strategy = delegate_build_strategy - self._async_context = async_context + self._async_context = async_context if async_context else AsyncContext() def build(self) -> Dict[str, str]: """ Runs all build and collects results from async context """ result = {} - with self, self._delegate_build_strategy: + with self._delegate_build_strategy: # ignore result super().build() # wait for other executions to complete @@ -352,3 +383,181 @@ def build_single_layer_definition(self, layer_definition: LayerBuildDefinition) self._delegate_build_strategy.build_single_layer_definition, layer_definition ) return {} + + +class IncrementalBuildStrategy(BuildStrategy): + """ + Incremental build is supported for certain runtimes in aws-lambda-builders, with dependencies_dir (str) + and download_dependencies (bool) options. + + This build strategy sets whether we need to download dependencies again (download_dependencies option) by comparing + the hash of the manifest file of the given runtime as well as the dependencies directory location + (dependencies_dir option). + """ + + def __init__( + self, + build_graph: BuildGraph, + delegate_build_strategy: BuildStrategy, + base_dir: str, + manifest_path_override: Optional[str], + ): + super().__init__(build_graph) + self._delegate_build_strategy = delegate_build_strategy + self._base_dir = base_dir + self._manifest_path_override = manifest_path_override + + def build(self) -> Dict[str, str]: + result = {} + with self, self._delegate_build_strategy: + result.update(super().build()) + return result + + def build_single_function_definition(self, build_definition: FunctionBuildDefinition) -> Dict[str, str]: + self._check_whether_manifest_is_changed(build_definition, build_definition.codeuri, build_definition.runtime) + return self._delegate_build_strategy.build_single_function_definition(build_definition) + + def build_single_layer_definition(self, layer_definition: LayerBuildDefinition) -> Dict[str, str]: + self._check_whether_manifest_is_changed( + layer_definition, layer_definition.codeuri, layer_definition.build_method + ) + return self._delegate_build_strategy.build_single_layer_definition(layer_definition) + + def _check_whether_manifest_is_changed( + self, + build_definition: AbstractBuildDefinition, + codeuri: Optional[str], + runtime: Optional[str], + ) -> None: + """ + Checks whether the manifest file have been changed by comparing its hash with previously stored one and updates + download_dependencies property of build definition to True, if it is changed + """ + manifest_hash = DependencyHashGenerator( + cast(str, codeuri), self._base_dir, cast(str, runtime), self._manifest_path_override + ).hash + + is_manifest_changed = True + if manifest_hash: + is_manifest_changed = manifest_hash != build_definition.manifest_hash + if is_manifest_changed: + build_definition.manifest_hash = manifest_hash + LOG.info( + "Manifest is changed for %s, downloading dependencies and copying/building source", + 
build_definition.uuid,
+            )
+        else:
+            LOG.info("Manifest is not changed for %s, running incremental build", build_definition.uuid)
+
+        build_definition.download_dependencies = is_manifest_changed
+
+    def _clean_redundant_dependencies(self) -> None:
+        """
+        Clean the redundant dependency folders that no longer match any build definition in the build graph
+        """
+        uuids = {bd.uuid for bd in self._build_graph.get_function_build_definitions()}
+        uuids.update({ld.uuid for ld in self._build_graph.get_layer_build_definitions()})
+        clean_redundant_folders(DEFAULT_DEPENDENCIES_DIR, uuids)
+
+
+class CachedOrIncrementalBuildStrategyWrapper(BuildStrategy):
+    """
+    A wrapper class which holds instances of CachedBuildStrategy and IncrementalBuildStrategy,
+    and selects one of them for each function or layer build depending on the runtime in use
+    """
+
+    SUPPORTED_RUNTIME_PREFIXES: Set[str] = {
+        "python",
+        "ruby",
+        "nodejs",
+    }
+
+    def __init__(
+        self,
+        build_graph: BuildGraph,
+        delegate_build_strategy: BuildStrategy,
+        base_dir: str,
+        build_dir: str,
+        cache_dir: str,
+        manifest_path_override: Optional[str],
+        is_building_specific_resource: bool,
+    ):
+        super().__init__(build_graph)
+        self._incremental_build_strategy = IncrementalBuildStrategy(
+            build_graph,
+            delegate_build_strategy,
+            base_dir,
+            manifest_path_override,
+        )
+        self._cached_build_strategy = CachedBuildStrategy(
+            build_graph,
+            delegate_build_strategy,
+            base_dir,
+            build_dir,
+            cache_dir,
+        )
+        self._is_building_specific_resource = is_building_specific_resource
+
+    def build(self) -> Dict[str, str]:
+        result = {}
+        with self._cached_build_strategy, self._incremental_build_strategy:
+            result.update(super().build())
+        return result
+
+    def build_single_function_definition(self, build_definition: FunctionBuildDefinition) -> Dict[str, str]:
+        if self._is_incremental_build_supported(build_definition.runtime):
+            LOG.debug(
+                "Running incremental build for runtime %s for build definition %s",
+                build_definition.runtime,
+                build_definition.uuid,
+            )
+            return self._incremental_build_strategy.build_single_function_definition(build_definition)
+
+        LOG.debug(
+            "Running cached build for runtime %s for build definition %s",
+            build_definition.runtime,
+            build_definition.uuid,
+        )
+        return self._cached_build_strategy.build_single_function_definition(build_definition)
+
+    def build_single_layer_definition(self, layer_definition: LayerBuildDefinition) -> Dict[str, str]:
+        if self._is_incremental_build_supported(layer_definition.build_method):
+            LOG.debug(
+                "Running incremental build for runtime %s for build definition %s",
+                layer_definition.build_method,
+                layer_definition.uuid,
+            )
+            return self._incremental_build_strategy.build_single_layer_definition(layer_definition)
+
+        LOG.debug(
+            "Running cached build for runtime %s for build definition %s",
+            layer_definition.build_method,
+            layer_definition.uuid,
+        )
+        return self._cached_build_strategy.build_single_layer_definition(layer_definition)
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        """
+        After the build is complete, this method cleans up redundant folders in the cached directory as well as in
+        the dependencies directory. It also updates the hashes of the functions and layers when only a single
+        function or layer is being built.
+
+        If SAM CLI switches to using only IncrementalBuildStrategy, the contents of this method should be moved
+        into IncrementalBuildStrategy so that it continues to clean up redundant folders.
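+
+        For example (illustrative uuid): if only the definition "uuid-1" survives the update, every other
+        sub-folder of the cache directory and of the dependencies directory is removed.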
+ """ + if self._is_building_specific_resource: + self._build_graph.update_definition_hash() + else: + self._build_graph.clean_redundant_definitions_and_update(not self._is_building_specific_resource) + self._cached_build_strategy._clean_redundant_cached() + self._incremental_build_strategy._clean_redundant_dependencies() + + @staticmethod + def _is_incremental_build_supported(runtime: Optional[str]) -> bool: + if not runtime or not is_experimental_enabled(ExperimentalFlag.Accelerate): + return False + + for supported_runtime_prefix in CachedOrIncrementalBuildStrategyWrapper.SUPPORTED_RUNTIME_PREFIXES: + if runtime.startswith(supported_runtime_prefix): + return True + + return False diff --git a/samcli/lib/build/dependency_hash_generator.py b/samcli/lib/build/dependency_hash_generator.py new file mode 100644 index 0000000000..d92b1460a8 --- /dev/null +++ b/samcli/lib/build/dependency_hash_generator.py @@ -0,0 +1,86 @@ +"""Utility Class for Getting Function or Layer Manifest Dependency Hashes""" +import pathlib + +from typing import Any, Optional + +from samcli.lib.build.workflow_config import get_workflow_config +from samcli.lib.utils.hash import file_checksum + +# TODO Expand this class to hash specific sections of the manifest +class DependencyHashGenerator: + _code_uri: str + _base_dir: str + _code_dir: str + _runtime: str + _manifest_path_override: Optional[str] + _hash_generator: Any + _calculated: bool + _hash: Optional[str] + + def __init__( + self, + code_uri: str, + base_dir: str, + runtime: str, + manifest_path_override: Optional[str] = None, + hash_generator: Any = None, + ): + """ + Parameters + ---------- + code_uri : str + Relative path specified in the function/layer resource + base_dir : str + Absolute path which the function/layer dir is located + runtime : str + Runtime of the function/layer + manifest_path_override : Optional[str], optional + Override default manifest path for each runtime, by default None + hash_generator : Any, optional + Hash generation function. Can be hashlib.md5(), hashlib.sha256(), etc, by default None + """ + self._code_uri = code_uri + self._base_dir = base_dir + self._code_dir = str(pathlib.Path(self._base_dir, self._code_uri).resolve()) + self._runtime = runtime + self._manifest_path_override = manifest_path_override + self._hash_generator = hash_generator + self._calculated = False + self._hash = None + + def _calculate_dependency_hash(self) -> Optional[str]: + """Calculate the manifest file hash + + Returns + ------- + Optional[str] + Returns manifest hash. If manifest does not exist or not supported, None will be returned. + """ + if self._manifest_path_override: + manifest_file = self._manifest_path_override + else: + config = get_workflow_config(self._runtime, self._code_dir, self._base_dir) + manifest_file = config.manifest_name + + if not manifest_file: + return None + + manifest_path = pathlib.Path(self._code_dir, manifest_file).resolve() + if not manifest_path.is_file(): + return None + + return file_checksum(str(manifest_path), hash_generator=self._hash_generator) + + @property + def hash(self) -> Optional[str]: + """ + Returns + ------- + Optional[str] + Hash for dependencies in the manifest. + If the manifest does not exist or not supported, this value will be None. 
+ """ + if not self._calculated: + self._hash = self._calculate_dependency_hash() + self._calculated = True + return self._hash diff --git a/samcli/lib/build/exceptions.py b/samcli/lib/build/exceptions.py index 7b5fc265d4..321302c677 100644 --- a/samcli/lib/build/exceptions.py +++ b/samcli/lib/build/exceptions.py @@ -41,6 +41,11 @@ def __init__(self, msg: str) -> None: BuildError.__init__(self, "DockerBuildFailed", msg) +class MissingBuildMethodException(BuildError): + def __init__(self, msg: str) -> None: + BuildError.__init__(self, "MissingBuildMethodException", msg) + + class InvalidBuildGraphException(Exception): def __init__(self, msg: str) -> None: Exception.__init__(self, msg) diff --git a/samcli/lib/deploy/deployer.py b/samcli/lib/deploy/deployer.py index ec73af71d7..97a8e7c32d 100644 --- a/samcli/lib/deploy/deployer.py +++ b/samcli/lib/deploy/deployer.py @@ -21,7 +21,7 @@ import logging import time from datetime import datetime -from typing import Dict, List +from typing import Dict, List, Optional import botocore @@ -53,7 +53,7 @@ } ) -DESCRIBE_STACK_EVENTS_TABLE_HEADER_NAME = "CloudFormation events from changeset" +DESCRIBE_STACK_EVENTS_TABLE_HEADER_NAME = "CloudFormation events from stack operations" DESCRIBE_CHANGESET_FORMAT_STRING = "{Operation:<{0}} {LogicalResourceId:<{1}} {ResourceType:<{2}} {Replacement:<{3}}" DESCRIBE_CHANGESET_DEFAULT_ARGS = OrderedDict( @@ -174,6 +174,17 @@ def create_changeset( "Tags": tags, } + kwargs = self._process_kwargs(kwargs, s3_uploader, capabilities, role_arn, notification_arns) + return self._create_change_set(stack_name=stack_name, changeset_type=changeset_type, **kwargs) + + @staticmethod + def _process_kwargs( + kwargs: dict, + s3_uploader: Optional[S3Uploader], + capabilities: Optional[List[str]], + role_arn: Optional[str], + notification_arns: Optional[List[str]], + ) -> dict: # If an S3 uploader is available, use TemplateURL to deploy rather than # TemplateBody. This is required for large templates. if s3_uploader: @@ -194,7 +205,7 @@ def create_changeset( kwargs["RoleARN"] = role_arn if notification_arns is not None: kwargs["NotificationARNs"] = notification_arns - return self._create_change_set(stack_name=stack_name, changeset_type=changeset_type, **kwargs) + return kwargs def _create_change_set(self, stack_name, changeset_type, **kwargs): try: @@ -410,17 +421,17 @@ def describe_stack_events(self, stack_name, time_stamp_marker, **kwargs): def _check_stack_not_in_progress(status: str) -> bool: return "IN_PROGRESS" not in status - def wait_for_execute(self, stack_name, changeset_type, disable_rollback): + def wait_for_execute(self, stack_name: str, stack_operation: str, disable_rollback: bool) -> None: """ - Wait for changeset to execute and return when execution completes. + Wait for stack operation to execute and return when execution completes. If the stack has "Outputs," they will be printed. 
Parameters ---------- stack_name : str The name of the stack - changeset_type : str - The type of the changeset, 'CREATE' or 'UPDATE' + stack_operation : str + The type of the stack operation, 'CREATE' or 'UPDATE' disable_rollback : bool Preserves the state of previously provisioned resources when an operation fails """ @@ -433,12 +444,12 @@ def wait_for_execute(self, stack_name, changeset_type, disable_rollback): self.describe_stack_events(stack_name, self.get_last_event_time(stack_name)) # Pick the right waiter - if changeset_type == "CREATE": + if stack_operation == "CREATE": waiter = self._client.get_waiter("stack_create_complete") - elif changeset_type == "UPDATE": + elif stack_operation == "UPDATE": waiter = self._client.get_waiter("stack_update_complete") else: - raise RuntimeError("Invalid changeset type {0}".format(changeset_type)) + raise RuntimeError("Invalid stack operation type {0}".format(stack_operation)) # Poll every 30 seconds. Polling too frequently risks hitting rate limits # on CloudFormation's DescribeStacks API @@ -447,7 +458,7 @@ def wait_for_execute(self, stack_name, changeset_type, disable_rollback): try: waiter.wait(StackName=stack_name, WaiterConfig=waiter_config) except botocore.exceptions.WaiterError as ex: - LOG.debug("Execute changeset waiter exception", exc_info=ex) + LOG.debug("Execute stack waiter exception", exc_info=ex) if disable_rollback: msg = self._gen_deploy_failed_with_rollback_disabled_msg(stack_name) LOG.info(self._colored.red(msg)) @@ -471,6 +482,99 @@ def create_and_wait_for_changeset( except botocore.exceptions.ClientError as ex: raise DeployFailedError(stack_name=stack_name, msg=str(ex)) from ex + def create_stack(self, **kwargs): + stack_name = kwargs.get("StackName") + try: + resp = self._client.create_stack(**kwargs) + return resp + except botocore.exceptions.ClientError as ex: + if "The bucket you are attempting to access must be addressed using the specified endpoint" in str(ex): + raise DeployBucketInDifferentRegionError(f"Failed to create/update stack {stack_name}") from ex + raise DeployFailedError(stack_name=stack_name, msg=str(ex)) from ex + + except Exception as ex: + LOG.debug("Unable to create stack", exc_info=ex) + raise DeployFailedError(stack_name=stack_name, msg=str(ex)) from ex + + def update_stack(self, **kwargs): + stack_name = kwargs.get("StackName") + try: + resp = self._client.update_stack(**kwargs) + return resp + except botocore.exceptions.ClientError as ex: + if "The bucket you are attempting to access must be addressed using the specified endpoint" in str(ex): + raise DeployBucketInDifferentRegionError(f"Failed to create/update stack {stack_name}") from ex + raise DeployFailedError(stack_name=stack_name, msg=str(ex)) from ex + + except Exception as ex: + LOG.debug("Unable to update stack", exc_info=ex) + raise DeployFailedError(stack_name=stack_name, msg=str(ex)) from ex + + def sync( + self, + stack_name: str, + cfn_template: str, + parameter_values: List[Dict], + capabilities: Optional[List[str]], + role_arn: Optional[str], + notification_arns: Optional[List[str]], + s3_uploader: Optional[S3Uploader], + tags: Optional[Dict], + ): + """ + Call the sync command to directly update stack or create stack + + Parameters + ---------- + :param stack_name: The name of the stack + :param cfn_template: CloudFormation template string + :param parameter_values: Template parameters object + :param capabilities: Array of capabilities passed to CloudFormation + :param role_arn: the Arn of the role to create changeset + :param 
notification_arns: Arns for sending notifications + :param s3_uploader: S3Uploader object to upload files to S3 buckets + :param tags: Array of tags passed to CloudFormation + :return: + """ + exists = self.has_stack(stack_name) + + if not exists: + # When creating a new stack, UsePreviousValue=True is invalid. + # For such parameters, users should either override with new value, + # or set a Default value in template to successfully create a stack. + parameter_values = [x for x in parameter_values if not x.get("UsePreviousValue", False)] + else: + summary = self._client.get_template_summary(StackName=stack_name) + existing_parameters = [parameter["ParameterKey"] for parameter in summary["Parameters"]] + parameter_values = [ + x + for x in parameter_values + if not (x.get("UsePreviousValue", False) and x["ParameterKey"] not in existing_parameters) + ] + + kwargs = { + "StackName": stack_name, + "TemplateBody": cfn_template, + "Parameters": parameter_values, + "Tags": tags, + } + + kwargs = self._process_kwargs(kwargs, s3_uploader, capabilities, role_arn, notification_arns) + + try: + if exists: + result = self.update_stack(**kwargs) + self.wait_for_execute(stack_name, "UPDATE", False) + LOG.info("\nStack update succeeded. Sync infra completed.\n") + else: + result = self.create_stack(**kwargs) + self.wait_for_execute(stack_name, "CREATE", False) + LOG.info("\nStack creation succeeded. Sync infra completed.\n") + + return result + except botocore.exceptions.ClientError as ex: + raise DeployFailedError(stack_name=stack_name, msg=str(ex)) from ex + @staticmethod @pprint_column_names( format_string=OUTPUTS_FORMAT_STRING, format_kwargs=OUTPUTS_DEFAULTS_ARGS, table_header=OUTPUTS_TABLE_HEADER_NAME diff --git a/samcli/lib/iac/cfn/cfn_iac.py b/samcli/lib/iac/cfn/cfn_iac.py index cd069c1bd1..141be0e8de 100644 --- a/samcli/lib/iac/cfn/cfn_iac.py +++ b/samcli/lib/iac/cfn/cfn_iac.py @@ -24,11 +24,15 @@ LookupPath, ) from samcli.lib.providers.sam_base_provider import SamBaseProvider -from samcli.commands._utils.resources import ( +from samcli.lib.utils.resources import ( METADATA_WITH_LOCAL_PATHS, RESOURCES_WITH_IMAGE_COMPONENT, RESOURCES_WITH_LOCAL_PATHS, NESTED_STACKS_RESOURCES, + AWS_SERVERLESS_FUNCTION, + AWS_LAMBDA_FUNCTION, + AWS_SERVERLESS_LAYERVERSION, + AWS_LAMBDA_LAYERVERSION, ) from samcli.commands._utils.template import get_template_data from samcli.lib.providers.sam_stack_provider import SamLocalStackProvider, is_local_path, get_local_path @@ -40,10 +44,10 @@ TEMPLATE_BUILD_PATH_KEY = "template_build_path" BASE_DIR_RESOURCES = [ - SamLocalStackProvider.SERVERLESS_FUNCTION, - SamLocalStackProvider.LAMBDA_FUNCTION, - SamLocalStackProvider.SERVERLESS_LAYER, - SamLocalStackProvider.LAMBDA_LAYER, + AWS_SERVERLESS_FUNCTION, + AWS_LAMBDA_FUNCTION, + AWS_SERVERLESS_LAYERVERSION, + AWS_LAMBDA_LAYERVERSION, ] diff --git a/samcli/lib/observability/cw_logs/cw_log_formatters.py b/samcli/lib/observability/cw_logs/cw_log_formatters.py index f0d35a18a6..63b2ffd983 100644 --- a/samcli/lib/observability/cw_logs/cw_log_formatters.py +++ b/samcli/lib/observability/cw_logs/cw_log_formatters.py @@ -4,6 +4,7 @@ import json import logging from json import JSONDecodeError +from typing import Any from samcli.lib.observability.cw_logs.cw_log_event import CWLogEvent from samcli.lib.observability.observability_info_puller import ObservabilityEventMapper @@ -92,3 +93,31 @@ def map(self, event: CWLogEvent) -> CWLogEvent: log_stream_name = self._colored.cyan(event.log_stream_name) event.message = f"{log_stream_name} 
{timestamp} {event.message}" return event + + +class CWAddNewLineIfItDoesntExist(ObservabilityEventMapper): + """ + Mapper implementation which will add new lines at the end of events if it is not already there + """ + + def map(self, event: Any) -> Any: + # if it is a CWLogEvent, append new line at the end of event.message + if isinstance(event, CWLogEvent) and not event.message.endswith("\n"): + event.message = f"{event.message}\n" + return event + # if event is a string, then append new line at the end of the string + if isinstance(event, str) and not event.endswith("\n"): + return f"{event}\n" + # no-action for unknown events + return event + + +class CWLogEventJSONMapper(ObservabilityEventMapper[CWLogEvent]): + """ + Converts given CWLogEvent into JSON string + """ + + # pylint: disable=no-self-use + def map(self, event: CWLogEvent) -> CWLogEvent: + event.message = json.dumps(event.event) + return event diff --git a/samcli/lib/observability/cw_logs/cw_log_group_provider.py b/samcli/lib/observability/cw_logs/cw_log_group_provider.py index 90893e5238..0bb8e4a8e0 100644 --- a/samcli/lib/observability/cw_logs/cw_log_group_provider.py +++ b/samcli/lib/observability/cw_logs/cw_log_group_provider.py @@ -1,6 +1,19 @@ """ Discover & provide the log group name """ +import logging +from typing import Optional + +from samcli.commands._utils.experimental import force_experimental, ExperimentalFlag +from samcli.lib.utils.resources import ( + AWS_LAMBDA_FUNCTION, + AWS_APIGATEWAY_RESTAPI, + AWS_APIGATEWAY_V2_API, + AWS_STEPFUNCTIONS_STATEMACHINE, +) +from samcli.lib.utils.boto_utils import BotoProviderType + +LOG = logging.getLogger(__name__) class LogGroupProvider: @@ -9,7 +22,21 @@ class LogGroupProvider: """ @staticmethod - def for_lambda_function(function_name): + def for_resource(boto_client_provider: BotoProviderType, resource_type: str, name: str) -> Optional[str]: + log_group = None + if resource_type == AWS_LAMBDA_FUNCTION: + log_group = LogGroupProvider.for_lambda_function(name) + elif resource_type == AWS_APIGATEWAY_RESTAPI: + log_group = LogGroupProvider.for_apigw_rest_api(name) + elif resource_type == AWS_APIGATEWAY_V2_API: + log_group = LogGroupProvider.for_apigwv2_http_api(boto_client_provider, name) + elif resource_type == AWS_STEPFUNCTIONS_STATEMACHINE: + log_group = LogGroupProvider.for_step_functions(boto_client_provider, name) + + return log_group + + @staticmethod + def for_lambda_function(function_name: str) -> str: """ Returns the CloudWatch Log Group Name created by default for the AWS Lambda function with given name @@ -24,3 +51,102 @@ def for_lambda_function(function_name): Default Log Group name used by this function """ return "/aws/lambda/{}".format(function_name) + + @staticmethod + @force_experimental(config_entry=ExperimentalFlag.Accelerate) # pylint: disable=E1120 + def for_apigw_rest_api(rest_api_id: str, stage: str = "Prod") -> str: + """ + Returns the CloudWatch Log Group Name created by default for the AWS Api gateway rest api with given id + + Parameters + ---------- + rest_api_id : str + Id of the rest api + stage: str + Stage of the rest api (the default value is "Prod") + + Returns + ------- + str + Default Log Group name used by this rest api + """ + + # TODO: A rest api may have multiple stage, here just log out the prod stage and can be extended to log out + # all stages or a specific stage if needed. 
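+        # e.g. for_apigw_rest_api("a1b2c3") -> "API-Gateway-Execution-Logs_a1b2c3/Prod"
+        #      (illustrative rest api id; "Prod" is the default stage above)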
+ return "API-Gateway-Execution-Logs_{}/{}".format(rest_api_id, stage) + + @staticmethod + @force_experimental(config_entry=ExperimentalFlag.Accelerate) # pylint: disable=E1120 + def for_apigwv2_http_api( + boto_client_provider: BotoProviderType, http_api_id: str, stage: str = "$default" + ) -> Optional[str]: + """ + Returns the CloudWatch Log Group Name created by default for the AWS Api gatewayv2 http api with given id + + Parameters + ---------- + boto_client_provider: BotoProviderType + Boto client provider which contains region and other configurations + http_api_id : str + Id of the http api + stage: str + Stage of the rest api (the default value is "$default") + + Returns + ------- + str + Default Log Group name used by this http api + """ + apigw2_client = boto_client_provider("apigatewayv2") + + # TODO: A http api may have multiple stage, here just log out the default stage and can be extended to log out + # all stages or a specific stage if needed. + stage_info = apigw2_client.get_stage(ApiId=http_api_id, StageName=stage) + log_setting = stage_info.get("AccessLogSettings", None) + if not log_setting: + LOG.warning("Access logging is disabled for HTTP API ID (%s)", http_api_id) + return None + log_group_name = str(log_setting.get("DestinationArn").split(":")[-1]) + return log_group_name + + @staticmethod + @force_experimental(config_entry=ExperimentalFlag.Accelerate) # pylint: disable=E1120 + def for_step_functions( + boto_client_provider: BotoProviderType, + step_function_name: str, + ) -> Optional[str]: + """ + Calls describe_state_machine API to get details of the State Machine, + then extracts logging information to find the configured CW log group. + If nothing is configured it will return None + + Parameters + ---------- + boto_client_provider : BotoProviderType + Boto client provider which contains region and other configurations + step_function_name : str + Name of the step functions resource + + Returns + ------- + CW log group name if logging is configured, None otherwise + """ + sfn_client = boto_client_provider("stepfunctions") + + state_machine_info = sfn_client.describe_state_machine(stateMachineArn=step_function_name) + LOG.debug("State machine info: %s", state_machine_info) + + logging_destinations = state_machine_info.get("loggingConfiguration", {}).get("destinations", []) + LOG.debug("State Machine logging destinations: %s", logging_destinations) + + # users may configure multiple log groups to send state machine logs, find one and return it + for logging_destination in logging_destinations: + log_group_arn = logging_destination.get("cloudWatchLogsLogGroup", {}).get("logGroupArn") + LOG.debug("Log group ARN: %s", log_group_arn) + if log_group_arn: + log_group_arn_parts = log_group_arn.split(":") + log_group_name = log_group_arn_parts[6] + return str(log_group_name) + LOG.warning("Logging is not configured for StepFunctions (%s)") + + return None diff --git a/samcli/lib/observability/cw_logs/cw_log_puller.py b/samcli/lib/observability/cw_logs/cw_log_puller.py index e7d8b7fb10..e990860370 100644 --- a/samcli/lib/observability/cw_logs/cw_log_puller.py +++ b/samcli/lib/observability/cw_logs/cw_log_puller.py @@ -4,7 +4,9 @@ import logging import time from datetime import datetime -from typing import Optional, Any +from typing import Optional, Any, List + +from botocore.exceptions import ClientError from samcli.lib.observability.cw_logs.cw_log_event import CWLogEvent from samcli.lib.observability.observability_info_puller import ObservabilityPuller, 
ObservabilityEventConsumer @@ -30,7 +32,7 @@ def __init__( """ Parameters ---------- - logs_client: Any + logs_client: CloudWatchLogsClient boto3 logs client instance consumer : ObservabilityEventConsumer Consumer instance that will process pulled events @@ -51,17 +53,37 @@ def __init__( self._poll_interval = poll_interval self.latest_event_time = 0 self.had_data = False + self._invalid_log_group = False def tail(self, start_time: Optional[datetime] = None, filter_pattern: Optional[str] = None): if start_time: self.latest_event_time = to_timestamp(start_time) counter = self._max_retries - while counter > 0: + while counter > 0 and not self.cancelled: LOG.debug("Tailing logs from %s starting at %s", self.cw_log_group, str(self.latest_event_time)) counter -= 1 - self.load_time_period(to_datetime(self.latest_event_time), filter_pattern=filter_pattern) + try: + self.load_time_period(to_datetime(self.latest_event_time), filter_pattern=filter_pattern) + except ClientError as err: + error_code = err.response.get("Error", {}).get("Code") + if error_code == "ThrottlingException": + # if throttled, increase poll interval by 1 second each time + if self._poll_interval == 1: + self._poll_interval += 1 + else: + self._poll_interval **= 2 + LOG.warning( + "Throttled by CloudWatch Logs API, consider pulling logs for certain resources. " + "Increasing the poll interval time for resource %s to %s seconds", + self.cw_log_group, + self._poll_interval, + ) + else: + # if error is other than throttling, re-raise it + LOG.error("Failed while fetching new log events", exc_info=err) + raise err # This poll fetched logs. Reset the retry counter and set the timestamp for next poll if self.had_data: @@ -92,12 +114,23 @@ def load_time_period( while True: LOG.debug("Fetching logs from CloudWatch with parameters %s", kwargs) - result = self.logs_client.filter_log_events(**kwargs) + try: + result = self.logs_client.filter_log_events(**kwargs) + self._invalid_log_group = False + except self.logs_client.exceptions.ResourceNotFoundException: + if not self._invalid_log_group: + LOG.debug( + "The specified log group %s does not exist. " + "This may be due to your resource have not been invoked yet.", + self.cw_log_group, + ) + self._invalid_log_group = True + break - # Several events will be returned. Yield one at a time + # Several events will be returned. 
Consume one at a time for event in result.get("events", []): self.had_data = True - cw_event = CWLogEvent(self.cw_log_group, event, self.resource_name) + cw_event = CWLogEvent(self.cw_log_group, dict(event), self.resource_name) if cw_event.timestamp > self.latest_event_time: self.latest_event_time = cw_event.timestamp @@ -109,3 +142,6 @@ def load_time_period( kwargs["nextToken"] = next_token if not next_token: break + + def load_events(self, event_ids: List[Any]): + LOG.debug("Loading specific events are not supported via CloudWatch Log Group") diff --git a/samcli/lib/observability/observability_info_puller.py b/samcli/lib/observability/observability_info_puller.py index b6d6f2b906..30103e95bb 100644 --- a/samcli/lib/observability/observability_info_puller.py +++ b/samcli/lib/observability/observability_info_puller.py @@ -4,7 +4,9 @@ import logging from abc import ABC, abstractmethod from datetime import datetime -from typing import List, Optional, Generic, TypeVar, Any +from typing import List, Optional, Generic, TypeVar, Any, Sequence + +from samcli.lib.utils.async_utils import AsyncContext LOG = logging.getLogger(__name__) @@ -43,6 +45,9 @@ class ObservabilityPuller(ABC): Interface definition for pulling observability information. """ + # used to cancel indefinitely running processes (eg: tail) + cancelled: bool = False + @abstractmethod def tail(self, start_time: Optional[datetime] = None, filter_pattern: Optional[str] = None): """ @@ -72,6 +77,17 @@ def load_time_period( Optional parameter to filter events with given string """ + @abstractmethod + def load_events(self, event_ids: List[Any]): + """ + This method will load specific events which is given by the event_ids parameter + + Parameters + ---------- + event_ids : List[str] + List of event ids that will be pulled + """ + # pylint: disable=fixme # fixme add ABC parent class back once we bump the pylint to a version 2.8.2 or higher @@ -141,3 +157,64 @@ def consume(self, event: ObservabilityEvent): event = mapper.map(event) LOG.debug("Calling consumer (%s) for event (%s)", self._consumer, event) self._consumer.consume(event) + + +class ObservabilityCombinedPuller(ObservabilityPuller): + """ + A decorator class which will contain multiple ObservabilityPuller instance and pull information from each of them + """ + + def __init__(self, pullers: Sequence[ObservabilityPuller]): + """ + Parameters + ---------- + pullers : List[ObservabilityPuller] + List of pullers which will be managed by this class + """ + self._pullers = pullers + + def tail(self, start_time: Optional[datetime] = None, filter_pattern: Optional[str] = None): + """ + Implementation of ObservabilityPuller.tail method with AsyncContext. + It will create tasks by calling tail methods of all given pullers, and execute them in async + """ + async_context = AsyncContext() + for puller in self._pullers: + LOG.debug("Adding task 'tail' for puller (%s)", puller) + async_context.add_async_task(puller.tail, start_time, filter_pattern) + LOG.debug("Running all 'tail' tasks in parallel") + try: + async_context.run_async() + except KeyboardInterrupt: + LOG.info(" CTRL+C received, cancelling...") + for puller in self._pullers: + puller.cancelled = True + + def load_time_period( + self, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + filter_pattern: Optional[str] = None, + ): + """ + Implementation of ObservabilityPuller.load_time_period method with AsyncContext. 
+        It will create tasks by calling the load_time_period method of each given puller, and execute them
+        asynchronously
+        """
+        async_context = AsyncContext()
+        for puller in self._pullers:
+            LOG.debug("Adding task 'load_time_period' for puller (%s)", puller)
+            async_context.add_async_task(puller.load_time_period, start_time, end_time, filter_pattern)
+        LOG.debug("Running all 'load_time_period' tasks in parallel")
+        async_context.run_async()
+
+    def load_events(self, event_ids: List[Any]):
+        """
+        Implementation of ObservabilityPuller.load_events method with AsyncContext.
+        It will create tasks by calling the load_events method of each given puller, and execute them
+        asynchronously
+        """
+        async_context = AsyncContext()
+        for puller in self._pullers:
+            LOG.debug("Adding task 'load_events' for puller (%s)", puller)
+            async_context.add_async_task(puller.load_events, event_ids)
+        LOG.debug("Running all 'load_events' tasks in parallel")
+        async_context.run_async()
diff --git a/samcli/lib/observability/xray_traces/__init__.py b/samcli/lib/observability/xray_traces/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/samcli/lib/observability/xray_traces/xray_event_mappers.py b/samcli/lib/observability/xray_traces/xray_event_mappers.py
new file mode 100644
index 0000000000..126e275067
--- /dev/null
+++ b/samcli/lib/observability/xray_traces/xray_event_mappers.py
@@ -0,0 +1,166 @@
+"""
+Contains mapper implementations of XRay events
+"""
+import json
+from copy import deepcopy
+from datetime import datetime
+from typing import List
+
+from samcli.lib.observability.observability_info_puller import ObservabilityEventMapper
+from samcli.lib.observability.xray_traces.xray_events import (
+    XRayTraceEvent,
+    XRayTraceSegment,
+    XRayServiceGraphEvent,
+    XRayGraphServiceInfo,
+)
+from samcli.lib.utils.time import to_utc, utc_to_timestamp, timestamp_to_iso
+
+
+class XRayTraceConsoleMapper(ObservabilityEventMapper[XRayTraceEvent]):
+    """
+    Maps given XRayTraceEvent.message field into printable format to use it in the console consumer
+    """
+
+    def map(self, event: XRayTraceEvent) -> XRayTraceEvent:
+        formatted_segments = self.format_segments(event.segments)
+        iso_formatted_timestamp = datetime.fromtimestamp(event.timestamp).isoformat()
+        mapped_message = (
+            f"\nXRay Event at ({iso_formatted_timestamp}) with id ({event.id}) and duration ({event.duration:.3f}s)"
+            f"{formatted_segments}"
+        )
+        event.message = mapped_message
+
+        return event
+
+    def format_segments(self, segments: List[XRayTraceSegment], level: int = 0) -> str:
+        """
+        Formats given segment information to be printed to the console.
+
+        Parameters
+        ----------
+        segments : List[XRayTraceSegment]
+            List of segments which will be printed into console
+        level : int
+            Optional level value which will be used to make the indentation of each segment. Default value is 0
+        """
+        formatted_str = ""
+        for segment in segments:
+            formatted_str += f"\n{' ' * level} - {segment.get_duration():.3f}s - {segment.name}"
+            if segment.http_status:
+                formatted_str += f" [HTTP: {segment.http_status}]"
+            formatted_str += self.format_segments(segment.sub_segments, (level + 1))
+
+        return formatted_str
+
+
+class XRayTraceJSONMapper(ObservabilityEventMapper[XRayTraceEvent]):
+    """
+    The original response from the XRay client contains JSON in an escaped string. This mapper re-constructs the
+    JSON object and converts it into a JSON string that can be printed to the console.
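+
+    For instance (assumed shape), an event like {"Segments": [{"Document": "{\"id\": \"a1\"}"}]} is re-emitted
+    with its "Segments" entries as real JSON objects rather than escaped strings.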
+ """ + + # pylint: disable=R0201 + def map(self, event: XRayTraceEvent) -> XRayTraceEvent: + mapped_event = deepcopy(event.event) + segments = [segment.document for segment in event.segments] + mapped_event["Segments"] = segments + event.event = mapped_event + event.message = json.dumps(mapped_event) + return event + + +class XRayServiceGraphConsoleMapper(ObservabilityEventMapper[XRayServiceGraphEvent]): + """ + Maps given XRayServiceGraphEvent.message field into printable format to use it in the console consumer + """ + + def map(self, event: XRayServiceGraphEvent) -> XRayServiceGraphEvent: + formatted_services = self.format_services(event.services) + mapped_message = "\nNew XRay Service Graph" + mapped_message += f"\n Start time: {event.start_time}" + mapped_message += f"\n End time: {event.end_time}" + mapped_message += formatted_services + event.message = mapped_message + + return event + + def format_services(self, services: List[XRayGraphServiceInfo]) -> str: + """ + Prints given services information back to console. + + Parameters + ---------- + services : List[XRayGraphServiceInfo] + List of services which will be printed into console + """ + formatted_str = "" + for service in services: + formatted_str += f"\n Reference Id: {service.id}" + formatted_str += f"{ ' - (Root)' if service.is_root else ' -'}" + formatted_str += f" {service.type} - {service.name}" + formatted_str += f" - Edges: {self.format_edges(service)}" + formatted_str += self.format_summary_statistics(service, 1) + + return formatted_str + + @staticmethod + def format_edges(service: XRayGraphServiceInfo) -> str: + edge_ids = service.edge_ids + return str(edge_ids) + + @staticmethod + def format_summary_statistics(service: XRayGraphServiceInfo, level) -> str: + """ + Prints given summary statistics information back to console. + + Parameters + ---------- + service: XRayGraphServiceInfo + summary statistics of the service which will be printed into console + level : int + Optional level value which will be used to make the indentation of each segment. Default value is 0 + """ + formatted_str = f"\n{' ' * level} Summary_statistics:" + formatted_str += f"\n{' ' * (level + 1)} - total requests: {service.total_count}" + formatted_str += f"\n{' ' * (level + 1)} - ok count(2XX): {service.ok_count}" + formatted_str += f"\n{' ' * (level + 1)} - error count(4XX): {service.error_count}" + formatted_str += f"\n{' ' * (level + 1)} - fault count(5XX): {service.fault_count}" + formatted_str += f"\n{' ' * (level + 1)} - total response time: {service.response_time}" + return formatted_str + + +class XRayServiceGraphJSONMapper(ObservabilityEventMapper[XRayServiceGraphEvent]): + """ + Original response from xray client contains datetime object. This mapper convert datetime object to iso string, + and converts final JSON object into string. 
+ """ + + def map(self, event: XRayServiceGraphEvent) -> XRayServiceGraphEvent: + mapped_event = deepcopy(event.event) + + self._convert_start_and_end_time_to_iso(mapped_event) + services = mapped_event.get("Services", []) + for service in services: + self._convert_start_and_end_time_to_iso(service) + edges = service.get("Edges", []) + for edge in edges: + self._convert_start_and_end_time_to_iso(edge) + + event.event = mapped_event + event.message = json.dumps(mapped_event) + return event + + def _convert_start_and_end_time_to_iso(self, event): + self.convert_event_datetime_to_iso(event, "StartTime") + self.convert_event_datetime_to_iso(event, "EndTime") + + def convert_event_datetime_to_iso(self, event, datetime_key): + event_datetime = event.get(datetime_key, None) + if event_datetime: + event[datetime_key] = self.convert_local_datetime_to_iso(event_datetime) + + @staticmethod + def convert_local_datetime_to_iso(local_datetime): + utc_datetime = to_utc(local_datetime) + time_stamp = utc_to_timestamp(utc_datetime) + return timestamp_to_iso(time_stamp) diff --git a/samcli/lib/observability/xray_traces/xray_event_puller.py b/samcli/lib/observability/xray_traces/xray_event_puller.py new file mode 100644 index 0000000000..2543af034d --- /dev/null +++ b/samcli/lib/observability/xray_traces/xray_event_puller.py @@ -0,0 +1,152 @@ +""" +This file contains puller implementations for XRay +""" +import logging +import time +from datetime import datetime +from itertools import zip_longest +from typing import Optional, Any, List, Set, Dict + +from botocore.exceptions import ClientError + +from samcli.lib.observability.observability_info_puller import ObservabilityPuller, ObservabilityEventConsumer +from samcli.lib.observability.xray_traces.xray_events import XRayTraceEvent +from samcli.lib.utils.time import to_timestamp, to_datetime + +LOG = logging.getLogger(__name__) + + +class AbstractXRayPuller(ObservabilityPuller): + def __init__( + self, + max_retries: int = 1000, + poll_interval: int = 1, + ): + """ + Parameters + ---------- + max_retries : int + Optional maximum number of retries which can be used to pull information. Default value is 1000 + poll_interval : int + Optional interval value that will be used to wait between calls in tail operation. 
Default value is 1 + """ + self._max_retries = max_retries + self._poll_interval = poll_interval + self._had_data = False + self.latest_event_time = 0 + + def tail(self, start_time: Optional[datetime] = None, filter_pattern: Optional[str] = None): + if start_time: + self.latest_event_time = to_timestamp(start_time) + + counter = self._max_retries + while counter > 0 and not self.cancelled: + LOG.debug("Tailing XRay traces starting at %s", self.latest_event_time) + + counter -= 1 + try: + self.load_time_period(to_datetime(self.latest_event_time), datetime.utcnow()) + except ClientError as err: + error_code = err.response.get("Error", {}).get("Code") + if error_code == "ThrottlingException": + # if throttled, increase poll interval by 1 second each time + if self._poll_interval == 1: + self._poll_interval += 1 + else: + self._poll_interval **= 2 + LOG.warning( + "Throttled by XRay API, increasing the poll interval time to %s seconds", + self._poll_interval, + ) + else: + # if exception is other than throttling re-raise + LOG.error("Failed while fetching new AWS X-Ray events", exc_info=err) + raise err + + if self._had_data: + counter = self._max_retries + self.latest_event_time += 1 + self._had_data = False + + time.sleep(self._poll_interval) + + +class XRayTracePuller(AbstractXRayPuller): + """ + ObservabilityPuller implementation which pulls XRay trace information by summarizing XRay traces first + and then getting them as a batch later. + """ + + def __init__( + self, xray_client: Any, consumer: ObservabilityEventConsumer, max_retries: int = 1000, poll_interval: int = 1 + ): + """ + Parameters + ---------- + xray_client : boto3.client + XRay boto3 client instance + consumer : ObservabilityEventConsumer + Consumer instance which will process pulled events + max_retries : int + Optional maximum number of retries which can be used to pull information. Default value is 1000 + poll_interval : int + Optional interval value that will be used to wait between calls in tail operation. 
Default value is 1
+ """
+ super().__init__(max_retries, poll_interval)
+ self.xray_client = xray_client
+ self.consumer = consumer
+ self._previous_trace_ids: Set[str] = set()
+
+ def load_time_period(
+ self,
+ start_time: Optional[datetime] = None,
+ end_time: Optional[datetime] = None,
+ filter_pattern: Optional[str] = None,
+ ):
+ kwargs = {"TimeRangeType": "TraceId", "StartTime": start_time, "EndTime": end_time}
+
+ # first, collect all trace IDs in the given period
+ trace_ids = []
+ LOG.debug("Fetching XRay trace summaries %s", kwargs)
+ result_paginator = self.xray_client.get_paginator("get_trace_summaries")
+ result_iterator = result_paginator.paginate(**kwargs)
+ for result in result_iterator:
+ trace_summaries = result.get("TraceSummaries", [])
+ for trace_summary in trace_summaries:
+ trace_id = trace_summary.get("Id", None)
+ is_partial = trace_summary.get("IsPartial", False)
+ if not is_partial and trace_id not in self._previous_trace_ids:
+ trace_ids.append(trace_id)
+ self._previous_trace_ids.add(trace_id)
+
+ # now load collected events
+ self.load_events(trace_ids)
+
+ def load_events(self, event_ids: List[str]):
+ if not event_ids:
+ LOG.debug("Nothing to fetch, empty event_id list given (%s)", event_ids)
+ return
+
+ # the XRay BatchGetTraces API accepts at most 5 trace IDs per call, so group IDs into batches of 5
+ event_batches = zip_longest(*([iter(event_ids)] * 5))
+
+ for event_batch in event_batches:
+ kwargs: Dict[str, Any] = {"TraceIds": list(filter(None, event_batch))}
+ result_paginator = self.xray_client.get_paginator("batch_get_traces")
+ result_iterator = result_paginator.paginate(**kwargs)
+ for result in result_iterator:
+ traces = result.get("Traces", [])
+
+ if not traces:
+ LOG.debug("No event found with given trace ids %s", str(event_ids))
+
+ for trace in traces:
+ self._had_data = True
+ xray_trace_event = XRayTraceEvent(trace)
+
+ # update the latest fetched event time
+ latest_event_time = xray_trace_event.get_latest_event_time()
+ if latest_event_time > self.latest_event_time:
+ self.latest_event_time = latest_event_time
+
+ self.consumer.consume(xray_trace_event)
diff --git a/samcli/lib/observability/xray_traces/xray_events.py b/samcli/lib/observability/xray_traces/xray_events.py
new file mode 100644
index 0000000000..281c073c95
--- /dev/null
+++ b/samcli/lib/observability/xray_traces/xray_events.py
@@ -0,0 +1,160 @@
+"""
+Keeps XRay event definitions
+"""
+import json
+import operator
+from typing import List
+
+from samcli.lib.observability.observability_info_puller import ObservabilityEvent
+from samcli.lib.utils.hash import str_checksum
+
+
+start_time_getter = operator.attrgetter("start_time")
+
+
+class XRayTraceEvent(ObservabilityEvent[dict]):
+ """
+ Represents a result of each XRay trace event, which is returned by the boto3 client by calling 'batch_get_traces'.
+ See XRayTracePuller
+ """
+
+ def __init__(self, event: dict):
+ super().__init__(event, 0)
+ self.id = event.get("Id", "")
+ self.duration = event.get("Duration", 0.0)
+ self.message = json.dumps(event)
+ self.segments: List[XRayTraceSegment] = []
+
+ self._construct_segments(event)
+ if self.segments:
+ self.timestamp = self.segments[0].start_time
+
+ def _construct_segments(self, event_dict):
+ """
+ Each event is represented by a tree of segments (each segment may also have subsegments).
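+
+ Illustrative input shape (hypothetical values; each "Document" is a JSON-encoded segment
+ string, as returned by the 'batch_get_traces' API)::
+
+ {
+ "Segments": [
+ {"Document": "{\"Id\": \"a1\", \"name\": \"MyFunction\", \"start_time\": 1.0, \"end_time\": 2.0}"}
+ ]
+ }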
+ """ + raw_segments = event_dict.get("Segments", []) + for raw_segment in raw_segments: + segment_document = raw_segment.get("Document", "{}") + self.segments.append(XRayTraceSegment(json.loads(segment_document))) + self.segments.sort(key=start_time_getter) + + def get_latest_event_time(self): + """ + Returns the latest event time for this specific XRayTraceEvent by calling get_latest_event_time for each segment + """ + latest_event_time = 0 + for segment in self.segments: + segment_latest_event_time = segment.get_latest_event_time() + if segment_latest_event_time > latest_event_time: + latest_event_time = segment_latest_event_time + + return latest_event_time + + +class XRayTraceSegment: + """ + Represents each segment information for a XRayTraceEvent + """ + + def __init__(self, document: dict): + self.id = document.get("Id", "") + self.document = document + self.name = document.get("name", "") + self.start_time = document.get("start_time", 0) + self.end_time = document.get("end_time", 0) + self.http_status = document.get("http", {}).get("response", {}).get("status", None) + self.sub_segments: List[XRayTraceSegment] = [] + + sub_segments = document.get("subsegments", []) + for sub_segment in sub_segments: + self.sub_segments.append(XRayTraceSegment(sub_segment)) + self.sub_segments.sort(key=start_time_getter) + + def get_duration(self): + return self.end_time - self.start_time + + def get_latest_event_time(self): + """ + Gets the latest event time by comparing all timestamps (end_time) from current segment and all sub-segments + """ + latest_event_time = self.end_time + for sub_segment in self.sub_segments: + sub_segment_latest_time = sub_segment.get_latest_event_time() + if sub_segment_latest_time > latest_event_time: + latest_event_time = sub_segment_latest_time + + return latest_event_time + + +class XRayServiceGraphEvent(ObservabilityEvent[dict]): + """ + Represents a result of each XRay service graph event, which is returned by boto3 client by calling + 'get_service_graph' See XRayServiceGraphPuller + """ + + def __init__(self, event: dict): + self.services: List[XRayGraphServiceInfo] = [] + self.message = str(event) + self._construct_service(event) + self.start_time = event.get("StartTime", None) + self.end_time = event.get("EndTime", None) + super().__init__(event, 0) + + def _construct_service(self, event_dict): + services = event_dict.get("Services", []) + for service in services: + self.services.append(XRayGraphServiceInfo(service)) + + def get_hash(self): + """ + get the hash of the containing services + """ + services = self.event.get("Services", []) + return str_checksum(str(services)) + + +class XRayGraphServiceInfo: + """ + Represents each services information for a XRayServiceGraphEvent + """ + + def __init__(self, service: dict): + self.id = service.get("ReferenceId", "") + self.document = service + self.name = service.get("Name", "") + self.is_root = service.get("Root", False) + self.type = service.get("Type") + self.edge_ids: List[int] = [] + self.ok_count = 0 + self.error_count = 0 + self.fault_count = 0 + self.total_count = 0 + self.response_time = 0 + self._construct_edge_ids(service.get("Edges", [])) + self._set_summary_statistics(service.get("SummaryStatistics", None)) + + def _construct_edge_ids(self, edges): + """ + covert the edges information to a list of edge reference ids + """ + edge_ids: List[int] = [] + for edge in edges: + edge_ids.append(edge.get("ReferenceId", -1)) + self.edge_ids = edge_ids + + def _set_summary_statistics(self, summary_statistics): + 
""" + get some useful information from summary statistics + """ + if not summary_statistics: + return + self.ok_count = summary_statistics.get("OkCount", 0) + error_statistics = summary_statistics.get("ErrorStatistics", None) + if error_statistics: + self.error_count = error_statistics.get("TotalCount", 0) + fault_statistics = summary_statistics.get("FaultStatistics", None) + if fault_statistics: + self.fault_count = fault_statistics.get("TotalCount", 0) + self.total_count = summary_statistics.get("TotalCount", 0) + self.response_time = summary_statistics.get("TotalResponseTime", 0) diff --git a/samcli/lib/observability/xray_traces/xray_service_graph_event_puller.py b/samcli/lib/observability/xray_traces/xray_service_graph_event_puller.py new file mode 100644 index 0000000000..4f2b0aa1fc --- /dev/null +++ b/samcli/lib/observability/xray_traces/xray_service_graph_event_puller.py @@ -0,0 +1,72 @@ +""" +This file contains puller implementations for XRay +""" +import logging +from datetime import datetime +from typing import Optional, Any, List, Set + +from samcli.lib.observability.observability_info_puller import ObservabilityEventConsumer +from samcli.lib.observability.xray_traces.xray_event_puller import AbstractXRayPuller +from samcli.lib.observability.xray_traces.xray_events import XRayServiceGraphEvent +from samcli.lib.utils.time import to_utc, utc_to_timestamp + +LOG = logging.getLogger(__name__) + + +class XRayServiceGraphPuller(AbstractXRayPuller): + """ + ObservabilityPuller implementation which pulls XRay Service Graph + """ + + def __init__( + self, xray_client: Any, consumer: ObservabilityEventConsumer, max_retries: int = 1000, poll_interval: int = 1 + ): + """ + Parameters + ---------- + xray_client : boto3.client + XRay boto3 client instance + consumer : ObservabilityEventConsumer + Consumer instance which will process pulled events + max_retries : int + Optional maximum number of retries which can be used to pull information. Default value is 1000 + poll_interval : int + Optional interval value that will be used to wait between calls in tail operation. 
Default value is 1
+ """
+ super().__init__(max_retries, poll_interval)
+ self.xray_client = xray_client
+ self.consumer = consumer
+ self._previous_xray_service_graphs: Set[str] = set()
+
+ def load_time_period(
+ self,
+ start_time: Optional[datetime] = None,
+ end_time: Optional[datetime] = None,
+ filter_pattern: Optional[str] = None,
+ ):
+ # pull xray traces service graph
+ kwargs = {"StartTime": start_time, "EndTime": end_time}
+ result_paginator = self.xray_client.get_paginator("get_service_graph")
+ result_iterator = result_paginator.paginate(**kwargs)
+ for result in result_iterator:
+ services = result.get("Services", [])
+
+ if not services:
+ LOG.debug("No service graph found")
+ else:
+ # update the latest fetched event time
+ event_end_time = result.get("EndTime", None)
+ if event_end_time:
+ utc_end_time = to_utc(event_end_time)
+ latest_event_time = utc_to_timestamp(utc_end_time)
+ if latest_event_time > self.latest_event_time:
+ self.latest_event_time = latest_event_time + 1
+
+ self._had_data = True
+ xray_service_graph_event = XRayServiceGraphEvent(result)
+ if xray_service_graph_event.get_hash() not in self._previous_xray_service_graphs:
+ self.consumer.consume(xray_service_graph_event)
+ self._previous_xray_service_graphs.add(xray_service_graph_event.get_hash())
+
+ def load_events(self, event_ids: List[str]):
+ LOG.debug("Loading specific service graph events is not supported by the XRay service graph API")
diff --git a/samcli/lib/package/artifact_exporter.py b/samcli/lib/package/artifact_exporter.py
index bfa4ad9bf8..7b464b69ce 100644
--- a/samcli/lib/package/artifact_exporter.py
+++ b/samcli/lib/package/artifact_exporter.py
@@ -20,7 +20,7 @@
 from botocore.utils import set_value_from_jmespath
-from samcli.commands._utils.resources import (
+from samcli.lib.utils.resources import (
 AWS_SERVERLESS_FUNCTION,
 AWS_CLOUDFORMATION_STACK,
 RESOURCES_WITH_LOCAL_PATHS,
diff --git a/samcli/lib/package/packageable_resources.py b/samcli/lib/package/packageable_resources.py
index 140aeaedf2..ec7e0b7c0b 100644
--- a/samcli/lib/package/packageable_resources.py
+++ b/samcli/lib/package/packageable_resources.py
@@ -25,7 +25,7 @@
 is_ecr_url,
 )
-from samcli.commands._utils.resources import (
+from samcli.lib.utils.resources import (
 AWS_SERVERLESSREPO_APPLICATION,
 AWS_SERVERLESS_FUNCTION,
 AWS_SERVERLESS_API,
diff --git a/samcli/lib/providers/cfn_api_provider.py b/samcli/lib/providers/cfn_api_provider.py
index 8acbb67fab..babf15d9ae 100644
--- a/samcli/lib/providers/cfn_api_provider.py
+++ b/samcli/lib/providers/cfn_api_provider.py
@@ -9,29 +9,32 @@
 from samcli.lib.providers.cfn_base_api_provider import CfnBaseApiProvider
 from samcli.lib.providers.api_collector import ApiCollector
+from samcli.lib.utils.resources import (
+ AWS_APIGATEWAY_METHOD,
+ AWS_APIGATEWAY_RESOURCE,
+ AWS_APIGATEWAY_RESTAPI,
+ AWS_APIGATEWAY_STAGE,
+ AWS_APIGATEWAY_V2_API,
+ AWS_APIGATEWAY_V2_INTEGRATION,
+ AWS_APIGATEWAY_V2_ROUTE,
+ AWS_APIGATEWAY_V2_STAGE,
+)
+
 LOG = logging.getLogger(__name__)
 
 
 class CfnApiProvider(CfnBaseApiProvider):
- APIGATEWAY_RESTAPI = "AWS::ApiGateway::RestApi"
- APIGATEWAY_STAGE = "AWS::ApiGateway::Stage"
- APIGATEWAY_RESOURCE = "AWS::ApiGateway::Resource"
- APIGATEWAY_METHOD = "AWS::ApiGateway::Method"
- APIGATEWAY_V2_API = "AWS::ApiGatewayV2::Api"
- APIGATEWAY_V2_INTEGRATION = "AWS::ApiGatewayV2::Integration"
- APIGATEWAY_V2_ROUTE = "AWS::ApiGatewayV2::Route"
- APIGATEWAY_V2_STAGE = "AWS::ApiGatewayV2::Stage"
 METHOD_BINARY_TYPE = "CONVERT_TO_BINARY"
 HTTP_API_PROTOCOL_TYPE = "HTTP"
 TYPES = [
APIGATEWAY_RESTAPI, - APIGATEWAY_STAGE, - APIGATEWAY_RESOURCE, - APIGATEWAY_METHOD, - APIGATEWAY_V2_API, - APIGATEWAY_V2_INTEGRATION, - APIGATEWAY_V2_ROUTE, - APIGATEWAY_V2_STAGE, + AWS_APIGATEWAY_RESTAPI, + AWS_APIGATEWAY_STAGE, + AWS_APIGATEWAY_RESOURCE, + AWS_APIGATEWAY_METHOD, + AWS_APIGATEWAY_V2_API, + AWS_APIGATEWAY_V2_INTEGRATION, + AWS_APIGATEWAY_V2_ROUTE, + AWS_APIGATEWAY_V2_STAGE, ] def extract_resources(self, stacks: List[Stack], collector: ApiCollector, cwd: Optional[str] = None) -> None: @@ -54,22 +57,22 @@ def extract_resources(self, stacks: List[Stack], collector: ApiCollector, cwd: O resources = stack.resources for logical_id, resource in resources.items(): resource_type = resource.get(CfnBaseApiProvider.RESOURCE_TYPE) - if resource_type == CfnApiProvider.APIGATEWAY_RESTAPI: + if resource_type == AWS_APIGATEWAY_RESTAPI: self._extract_cloud_formation_route(stack.stack_path, logical_id, resource, collector, cwd=cwd) - if resource_type == CfnApiProvider.APIGATEWAY_STAGE: + if resource_type == AWS_APIGATEWAY_STAGE: self._extract_cloud_formation_stage(resources, resource, collector) - if resource_type == CfnApiProvider.APIGATEWAY_METHOD: + if resource_type == AWS_APIGATEWAY_METHOD: self._extract_cloud_formation_method(stack.stack_path, resources, logical_id, resource, collector) - if resource_type == CfnApiProvider.APIGATEWAY_V2_API: + if resource_type == AWS_APIGATEWAY_V2_API: self._extract_cfn_gateway_v2_api(stack.stack_path, logical_id, resource, collector, cwd=cwd) - if resource_type == CfnApiProvider.APIGATEWAY_V2_ROUTE: + if resource_type == AWS_APIGATEWAY_V2_ROUTE: self._extract_cfn_gateway_v2_route(stack.stack_path, resources, logical_id, resource, collector) - if resource_type == CfnApiProvider.APIGATEWAY_V2_STAGE: + if resource_type == AWS_APIGATEWAY_V2_STAGE: self._extract_cfn_gateway_v2_stage(resources, resource, collector) @staticmethod @@ -136,7 +139,7 @@ def _extract_cloud_formation_stage( if not logical_id: raise InvalidSamTemplateException("The AWS::ApiGateway::Stage must have a RestApiId property") rest_api_resource_type = resources.get(logical_id, {}).get("Type") - if rest_api_resource_type != CfnApiProvider.APIGATEWAY_RESTAPI: + if rest_api_resource_type != AWS_APIGATEWAY_RESTAPI: raise InvalidSamTemplateException( "The AWS::ApiGateway::Stage must have a valid RestApiId that points to RestApi resource {}".format( logical_id @@ -387,7 +390,7 @@ def _extract_cfn_gateway_v2_stage( if not api_id: raise InvalidSamTemplateException("The AWS::ApiGatewayV2::Stage must have a ApiId property") api_resource_type = resources.get(api_id, {}).get("Type") - if api_resource_type != CfnApiProvider.APIGATEWAY_V2_API: + if api_resource_type != AWS_APIGATEWAY_V2_API: raise InvalidSamTemplateException( "The AWS::ApiGatewayV2::Stag must have a valid ApiId that points to Api resource {}".format(api_id) ) @@ -449,7 +452,7 @@ def _get_route_function_name( integration_resource = resources.get(integration_id, {}) resource_type = integration_resource.get("Type") - if resource_type == CfnApiProvider.APIGATEWAY_V2_INTEGRATION: + if resource_type == AWS_APIGATEWAY_V2_INTEGRATION: properties = integration_resource.get("Properties", {}) integration_uri = properties.get("IntegrationUri") payload_format_version = properties.get("PayloadFormatVersion") diff --git a/samcli/lib/providers/exceptions.py b/samcli/lib/providers/exceptions.py index 60328f781f..f589706964 100644 --- a/samcli/lib/providers/exceptions.py +++ b/samcli/lib/providers/exceptions.py @@ -2,6 +2,12 @@ Exceptions used by 
providers
 """
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING: # pragma: no cover
+ from samcli.lib.providers.provider import ResourceIdentifier
+
 
 class InvalidLayerReference(Exception):
 """
@@ -16,3 +22,36 @@ def __init__(self) -> None:
 
 class RemoteStackLocationNotSupported(Exception):
 pass
+
+
+class MissingCodeUri(Exception):
+ """Exception when Function or Lambda resources do not have CodeUri specified"""
+
+
+class MissingLocalDefinition(Exception):
+ """Exception when a resource does not have a local path in its property"""
+
+ _resource_identifier: "ResourceIdentifier"
+ _property_name: str
+
+ def __init__(self, resource_identifier: "ResourceIdentifier", property_name: str) -> None:
+ """Exception when a resource does not have a local path in its property
+
+ Parameters
+ ----------
+ resource_identifier : ResourceIdentifier
+ Resource Identifier
+ property_name : str
+ Property name that's missing
+ """
+ self._resource_identifier = resource_identifier
+ self._property_name = property_name
+ super().__init__(f"Resource {str(resource_identifier)} does not have {property_name} specified.")
+
+ @property
+ def resource_identifier(self) -> "ResourceIdentifier":
+ return self._resource_identifier
+
+ @property
+ def property_name(self) -> str:
+ return self._property_name
diff --git a/samcli/lib/providers/provider.py b/samcli/lib/providers/provider.py
index afee0467a9..f20827e2b2 100644
--- a/samcli/lib/providers/provider.py
+++ b/samcli/lib/providers/provider.py
@@ -7,7 +7,7 @@
 import os
 import posixpath
 from collections import namedtuple
-from typing import Set, NamedTuple, Optional, List, Dict, Union, cast, Iterator, TYPE_CHECKING
+from typing import Any, Set, NamedTuple, Optional, List, Dict, Tuple, Union, cast, Iterator, TYPE_CHECKING
 
 from samcli.commands.local.cli_common.user_exceptions import (
 InvalidLayerVersionArn,
@@ -17,7 +17,7 @@
 from samcli.lib.providers.sam_base_provider import SamBaseProvider
 from samcli.lib.utils.architecture import X86_64
 
-if TYPE_CHECKING:
+if TYPE_CHECKING: # pragma: no cover
 # avoid circular import, https://docs.python.org/3/library/typing.html#typing.TYPE_CHECKING
 from samcli.local.apigw.local_apigw_service import Route
 
@@ -56,7 +56,7 @@
 # to get credentials to run the container with. This gives a much higher fidelity simulation of cloud Lambda.
 rolearn: Optional[str]
 # List of Layers
- layers: List
+ layers: List["LayerVersion"]
 # Event
 events: Optional[List]
 # Metadata
@@ -493,14 +493,176 @@ def get_output_template_path(self, build_root: str) -> str:
 return os.path.join(build_root, self.stack_path.replace(posixpath.sep, os.path.sep), "template.yaml")
 
 
+class ResourceIdentifier:
+ """Resource identifier for representing a resource with nested stack support"""
+
+ _stack_path: str
+ _logical_id: str
+
+ def __init__(self, resource_identifier_str: str):
+ """
+ Parameters
+ ----------
+ resource_identifier_str : str
+ Resource identifier in the format of:
+ Stack1/Stack2/ResourceID
+ """
+ parts = resource_identifier_str.rsplit(posixpath.sep, 1)
+ if len(parts) == 1:
+ self._stack_path = ""
+ self._logical_id = parts[0]
+ else:
+ self._stack_path = parts[0]
+ self._logical_id = parts[1]
+
+ @property
+ def stack_path(self) -> str:
+ """
+ Returns
+ -------
+ str
+ Stack path of the resource.
+ This can be an empty string if the resource is in the root stack.
+ """
+ return self._stack_path
+
+ @property
+ def logical_id(self) -> str:
+ """
+ Returns
+ -------
+ str
+ Logical ID of the resource.
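+ For example (hypothetical identifier), ResourceIdentifier("NestedStack/MyFunction")
+ has stack_path "NestedStack" and logical_id "MyFunction".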
+ """ + return self._logical_id + + def __str__(self) -> str: + return self.stack_path + posixpath.sep + self.logical_id if self.stack_path else self.logical_id + + def __eq__(self, other: object) -> bool: + return str(self) == str(other) if isinstance(other, ResourceIdentifier) else False + + def __hash__(self) -> int: + return hash(str(self)) + + def get_full_path(stack_path: str, logical_id: str) -> str: """ Return the unique posix path-like identifier while will used for identify a resource from resources in a multi-stack situation """ + if not stack_path: + return logical_id return posixpath.join(stack_path, logical_id) +def get_resource_by_id( + stacks: List[Stack], identifier: ResourceIdentifier, explicit_nested: bool = False +) -> Optional[Dict[str, Any]]: + """Seach resource in stacks based on identifier + + Parameters + ---------- + stacks : List[Stack] + List of stacks to be searched + identifier : ResourceIdentifier + Resource identifier for the resource to be returned + explicit_nested : bool, optional + Set to True to only search in root stack if stack_path does not exist. + Otherwise, all stacks will be searched in order to find matching logical ID. + If stack_path does exist in identifier, this option will be ignored and behave as if it is True + + Returns + ------- + Dict + Resource dict + """ + search_all_stacks = not identifier.stack_path and not explicit_nested + for stack in stacks: + if stack.stack_path == identifier.stack_path or search_all_stacks: + resource = stack.resources.get(identifier.logical_id) + if resource: + return cast(Dict[str, Any], resource) + return None + + +def get_resource_ids_by_type(stacks: List[Stack], resource_type: str) -> List[ResourceIdentifier]: + """Return list of resource IDs + + Parameters + ---------- + stacks : List[Stack] + List of stacks + resource_type : str + Resource type to be used for searching related resources. 
+ + Returns + ------- + List[ResourceIdentifier] + List of ResourceIdentifiers with the type provided + """ + resource_ids: List[ResourceIdentifier] = list() + for stack in stacks: + for resource_id, resource in stack.resources.items(): + if resource.get("Type", "") == resource_type: + resource_ids.append(ResourceIdentifier(get_full_path(stack.stack_path, resource_id))) + return resource_ids + + +def get_all_resource_ids(stacks: List[Stack]) -> List[ResourceIdentifier]: + """Return all resource IDs in stacks + + Parameters + ---------- + stacks : List[Stack] + List of stacks + + Returns + ------- + List[ResourceIdentifier] + List of ResourceIdentifiers + """ + resource_ids: List[ResourceIdentifier] = list() + for stack in stacks: + for resource_id, _ in stack.resources.items(): + resource_ids.append(ResourceIdentifier(get_full_path(stack.stack_path, resource_id))) + return resource_ids + + +def get_unique_resource_ids( + stacks: List[Stack], + resource_ids: Optional[Union[List[str], Tuple[str]]], + resource_types: Optional[Union[List[str], Tuple[str]]], +) -> Set[ResourceIdentifier]: + """Get unique resource IDs for resource_ids and resource_types + + Parameters + ---------- + stacks : List[Stack] + Stacks + resource_ids : Optional[Union[List[str], Tuple[str]]] + Resource ID strings + resource_types : Optional[Union[List[str], Tuple[str]]] + Resource types + + Returns + ------- + Set[ResourceIdentifier] + Set of ResourceIdentifier either in resource_ids or has the type in resource_types + """ + output_resource_ids: Set[ResourceIdentifier] = set() + if resource_ids: + for resources_id in resource_ids: + output_resource_ids.add(ResourceIdentifier(resources_id)) + + if resource_types: + for resource_type in resource_types: + resource_type_ids = get_resource_ids_by_type(stacks, resource_type) + for resource_id in resource_type_ids: + output_resource_ids.add(resource_id) + return output_resource_ids + + def _get_build_dir(resource: Union[Function, LayerVersion], build_root: str) -> str: """ Return the build directory to place build artifact diff --git a/samcli/lib/providers/sam_api_provider.py b/samcli/lib/providers/sam_api_provider.py index 0ad44eb7b8..1fd48dfde5 100644 --- a/samcli/lib/providers/sam_api_provider.py +++ b/samcli/lib/providers/sam_api_provider.py @@ -8,15 +8,13 @@ from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException from samcli.lib.providers.provider import Stack from samcli.local.apigw.local_apigw_service import Route +from samcli.lib.utils.resources import AWS_SERVERLESS_FUNCTION, AWS_SERVERLESS_API, AWS_SERVERLESS_HTTPAPI LOG = logging.getLogger(__name__) class SamApiProvider(CfnBaseApiProvider): - SERVERLESS_FUNCTION = "AWS::Serverless::Function" - SERVERLESS_API = "AWS::Serverless::Api" - SERVERLESS_HTTP_API = "AWS::Serverless::HttpApi" - TYPES = [SERVERLESS_FUNCTION, SERVERLESS_API, SERVERLESS_HTTP_API] + TYPES = [AWS_SERVERLESS_FUNCTION, AWS_SERVERLESS_API, AWS_SERVERLESS_HTTPAPI] _EVENT_TYPE_API = "Api" _EVENT_TYPE_HTTP_API = "HttpApi" _FUNCTION_EVENT = "Events" @@ -46,11 +44,11 @@ def extract_resources(self, stacks: List[Stack], collector: ApiCollector, cwd: O for stack in stacks: for logical_id, resource in stack.resources.items(): resource_type = resource.get(CfnBaseApiProvider.RESOURCE_TYPE) - if resource_type == SamApiProvider.SERVERLESS_FUNCTION: + if resource_type == AWS_SERVERLESS_FUNCTION: self._extract_routes_from_function(stack.stack_path, logical_id, resource, collector) - if resource_type == SamApiProvider.SERVERLESS_API: + if 
resource_type == AWS_SERVERLESS_API: self._extract_from_serverless_api(stack.stack_path, logical_id, resource, collector, cwd=cwd) - if resource_type == SamApiProvider.SERVERLESS_HTTP_API: + if resource_type == AWS_SERVERLESS_HTTPAPI: self._extract_from_serverless_http(stack.stack_path, logical_id, resource, collector, cwd=cwd) collector.routes = self.merge_routes(collector) @@ -156,7 +154,7 @@ def _extract_routes_from_function( Path of the stack the resource is located logical_id : str - Logical ID of the resourc + Logical ID of the resource function_resource : dict Contents of the function resource including its properties diff --git a/samcli/lib/providers/sam_base_provider.py b/samcli/lib/providers/sam_base_provider.py index ffa2e7eb0a..60edbc3ed6 100644 --- a/samcli/lib/providers/sam_base_provider.py +++ b/samcli/lib/providers/sam_base_provider.py @@ -5,7 +5,12 @@ import logging from typing import Any, Dict, Optional, cast, Iterable, Union -from samcli.commands._utils.resources import AWS_SERVERLESS_APPLICATION, AWS_CLOUDFORMATION_STACK +from samcli.lib.utils.resources import ( + AWS_LAMBDA_FUNCTION, + AWS_SERVERLESS_FUNCTION, + AWS_LAMBDA_LAYERVERSION, + AWS_SERVERLESS_LAYERVERSION, +) from samcli.lib.iac.plugins_interfaces import Stack from samcli.lib.intrinsic_resolver.intrinsic_property_resolver import IntrinsicResolver from samcli.lib.intrinsic_resolver.intrinsics_symbol_table import IntrinsicsSymbolTable @@ -22,24 +27,18 @@ class SamBaseProvider: Base class for SAM Template providers """ - SERVERLESS_FUNCTION = "AWS::Serverless::Function" - LAMBDA_FUNCTION = "AWS::Lambda::Function" - SERVERLESS_LAYER = "AWS::Serverless::LayerVersion" - LAMBDA_LAYER = "AWS::Lambda::LayerVersion" - SERVERLESS_APPLICATION = AWS_SERVERLESS_APPLICATION - CLOUDFORMATION_STACK = AWS_CLOUDFORMATION_STACK DEFAULT_CODEURI = "." 
CODE_PROPERTY_KEYS = { - LAMBDA_FUNCTION: "Code", - SERVERLESS_FUNCTION: "CodeUri", - LAMBDA_LAYER: "Content", - SERVERLESS_LAYER: "ContentUri", + AWS_LAMBDA_FUNCTION: "Code", + AWS_SERVERLESS_FUNCTION: "CodeUri", + AWS_LAMBDA_LAYERVERSION: "Content", + AWS_SERVERLESS_LAYERVERSION: "ContentUri", } IMAGE_PROPERTY_KEYS = { - LAMBDA_FUNCTION: "Code", - SERVERLESS_FUNCTION: "ImageUri", + AWS_LAMBDA_FUNCTION: "Code", + AWS_SERVERLESS_FUNCTION: "ImageUri", } def get(self, name: str) -> Optional[Any]: diff --git a/samcli/lib/providers/sam_function_provider.py b/samcli/lib/providers/sam_function_provider.py index 03057bff4f..e634751ff2 100644 --- a/samcli/lib/providers/sam_function_provider.py +++ b/samcli/lib/providers/sam_function_provider.py @@ -4,6 +4,12 @@ import logging from typing import Dict, List, Optional, cast, Iterator, Any +from samcli.lib.utils.resources import ( + AWS_LAMBDA_FUNCTION, + AWS_LAMBDA_LAYERVERSION, + AWS_SERVERLESS_FUNCTION, + AWS_SERVERLESS_LAYERVERSION, +) from samcli.commands.local.cli_common.user_exceptions import InvalidLayerVersionArn from samcli.lib.providers.exceptions import InvalidLayerReference from samcli.lib.utils.colors import Colored @@ -129,7 +135,7 @@ def _extract_functions( if resource_metadata: resource_properties["Metadata"] = resource_metadata - if resource_type in [SamFunctionProvider.SERVERLESS_FUNCTION, SamFunctionProvider.LAMBDA_FUNCTION]: + if resource_type in [AWS_SERVERLESS_FUNCTION, AWS_LAMBDA_FUNCTION]: resource_package_type = resource_properties.get("PackageType", ZIP) code_property_key = SamBaseProvider.CODE_PROPERTY_KEYS[resource_type] @@ -156,7 +162,7 @@ def _extract_functions( SamFunctionProvider._warn_imageuri_extraction(resource_type, name, image_property_key) continue - if resource_type == SamFunctionProvider.SERVERLESS_FUNCTION: + if resource_type == AWS_SERVERLESS_FUNCTION: layers = SamFunctionProvider._parse_layer_info( stack, resource_properties.get("Layers", []), @@ -172,7 +178,7 @@ def _extract_functions( ) result[function.full_path] = function - elif resource_type == SamFunctionProvider.LAMBDA_FUNCTION: + elif resource_type == AWS_LAMBDA_FUNCTION: layers = SamFunctionProvider._parse_layer_info( stack, resource_properties.get("Layers", []), @@ -429,8 +435,8 @@ def _locate_layer_from_ref( layer_logical_id = cast(str, layer.get("Ref")) layer_resource = stack.resources.get(layer_logical_id) if not layer_resource or layer_resource.get("Type", "") not in ( - SamFunctionProvider.SERVERLESS_LAYER, - SamFunctionProvider.LAMBDA_LAYER, + AWS_SERVERLESS_LAYERVERSION, + AWS_LAMBDA_LAYERVERSION, ): raise InvalidLayerReference() @@ -439,7 +445,7 @@ def _locate_layer_from_ref( compatible_runtimes = layer_properties.get("CompatibleRuntimes") codeuri: Optional[str] = None - if resource_type in [SamFunctionProvider.LAMBDA_LAYER, SamFunctionProvider.SERVERLESS_LAYER]: + if resource_type in [AWS_LAMBDA_LAYERVERSION, AWS_SERVERLESS_LAYERVERSION]: code_property_key = SamBaseProvider.CODE_PROPERTY_KEYS[resource_type] if SamBaseProvider._is_s3_location(layer_properties.get(code_property_key)): # Content can be a dictionary of S3 Bucket/Key or a S3 URI, neither of which are supported diff --git a/samcli/lib/providers/sam_layer_provider.py b/samcli/lib/providers/sam_layer_provider.py index fc7f27b2f5..a8ac1e4124 100644 --- a/samcli/lib/providers/sam_layer_provider.py +++ b/samcli/lib/providers/sam_layer_provider.py @@ -5,6 +5,7 @@ import posixpath from typing import List, Dict, Optional +from samcli.lib.utils.resources import AWS_LAMBDA_LAYERVERSION, 
AWS_SERVERLESS_LAYERVERSION from .provider import LayerVersion, Stack from .sam_base_provider import SamBaseProvider from .sam_stack_provider import SamLocalStackProvider @@ -86,7 +87,7 @@ def _extract_layers(self) -> List[LayerVersion]: resource_type = resource.get("Type") resource_properties = resource.get("Properties", {}) - if resource_type in [SamBaseProvider.LAMBDA_LAYER, SamBaseProvider.SERVERLESS_LAYER]: + if resource_type in [AWS_LAMBDA_LAYERVERSION, AWS_SERVERLESS_LAYERVERSION]: code_property_key = SamBaseProvider.CODE_PROPERTY_KEYS[resource_type] if SamBaseProvider._is_s3_location(resource_properties.get(code_property_key)): # Content can be a dictionary of S3 Bucket/Key or a S3 URI, neither of which are supported diff --git a/samcli/lib/providers/sam_stack_provider.py b/samcli/lib/providers/sam_stack_provider.py index 758f028121..d95258f0c5 100644 --- a/samcli/lib/providers/sam_stack_provider.py +++ b/samcli/lib/providers/sam_stack_provider.py @@ -10,6 +10,7 @@ from samcli.lib.providers.exceptions import RemoteStackLocationNotSupported from samcli.lib.providers.provider import Stack, get_full_path from samcli.lib.providers.sam_base_provider import SamBaseProvider +from samcli.lib.utils.resources import AWS_CLOUDFORMATION_STACK, AWS_SERVERLESS_APPLICATION LOG = logging.getLogger(__name__) @@ -107,11 +108,11 @@ def _extract_stacks(self) -> None: stack: Optional[Stack] = None try: - if resource_type == SamLocalStackProvider.SERVERLESS_APPLICATION: + if resource_type == AWS_SERVERLESS_APPLICATION: stack = SamLocalStackProvider._convert_sam_application_resource( self._template_file, self._stack_path, name, resource_properties ) - if resource_type == SamLocalStackProvider.CLOUDFORMATION_STACK: + if resource_type == AWS_CLOUDFORMATION_STACK: stack = SamLocalStackProvider._convert_cfn_stack_resource( self._template_file, self._stack_path, name, resource_properties ) diff --git a/samcli/lib/sync/__init__.py b/samcli/lib/sync/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/samcli/lib/sync/continuous_sync_flow_executor.py b/samcli/lib/sync/continuous_sync_flow_executor.py new file mode 100644 index 0000000000..2adf2e1f21 --- /dev/null +++ b/samcli/lib/sync/continuous_sync_flow_executor.py @@ -0,0 +1,130 @@ +"""SyncFlowExecutor that will run continuously until stop is called.""" +import time +import logging + +from typing import Callable, Optional +from concurrent.futures.thread import ThreadPoolExecutor + +from dataclasses import dataclass + +from samcli.lib.sync.exceptions import SyncFlowException +from samcli.lib.sync.sync_flow import SyncFlow +from samcli.lib.sync.sync_flow_executor import SyncFlowExecutor, SyncFlowFuture, SyncFlowTask, default_exception_handler + +LOG = logging.getLogger(__name__) + + +@dataclass(frozen=True, eq=True) +class DelayedSyncFlowTask(SyncFlowTask): + """Data struct for individual SyncFlow execution tasks""" + + # Time in seconds of when the task was initially queued + queue_time: float + + # Number of seconds this task should stay in queue before being executed + wait_time: float + + +class ContinuousSyncFlowExecutor(SyncFlowExecutor): + """SyncFlowExecutor that continuously runs and executes SyncFlows. 
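+
+ Minimal usage sketch (illustrative; a SyncFlow instance would come from elsewhere,
+ e.g. SyncFlowFactory)::
+
+ executor = ContinuousSyncFlowExecutor()
+ executor.add_delayed_sync_flow(sync_flow, wait_time=1)
+ executor.execute() # blocking loop
+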
+ Call stop() to stop the executor""" + + # Flag for whether the executor should be stopped at the next available time + _stop_flag: bool + + def __init__(self) -> None: + super().__init__() + self._stop_flag = False + + def stop(self, should_stop=True) -> None: + """Stop executor after all current SyncFlows are finished.""" + with self._flow_queue_lock: + self._stop_flag = should_stop + if should_stop: + self._flow_queue.queue.clear() + + def should_stop(self) -> bool: + """ + Returns + ------- + bool + Should executor stop execution on the next available time. + """ + return self._stop_flag + + def _can_exit(self): + return self.should_stop() and super()._can_exit() + + def _submit_sync_flow_task( + self, executor: ThreadPoolExecutor, sync_flow_task: SyncFlowTask + ) -> Optional[SyncFlowFuture]: + """Submit SyncFlowTask to be executed by ThreadPoolExecutor + and return its future + Adds additional time checks for DelayedSyncFlowTask + + Parameters + ---------- + executor : ThreadPoolExecutor + THreadPoolExecutor to be used for execution + sync_flow_task : SyncFlowTask + SyncFlowTask to be executed. + + Returns + ------- + Optional[SyncFlowFuture] + Returns SyncFlowFuture generated by the SyncFlowTask. + Can be None if the task cannot be executed yet. + """ + if ( + isinstance(sync_flow_task, DelayedSyncFlowTask) + and sync_flow_task.wait_time + sync_flow_task.queue_time > time.time() + ): + return None + + return super()._submit_sync_flow_task(executor, sync_flow_task) + + def _add_sync_flow_task(self, task: SyncFlowTask) -> None: + """Add SyncFlowTask to the queue + Skips if the executor is in the state of being shut down. + + Parameters + ---------- + task : SyncFlowTask + SyncFlowTask to be added. + """ + if self.should_stop(): + LOG.debug( + "%s is skipped from queueing as executor is in the process of stopping.", task.sync_flow.log_prefix + ) + return + + super()._add_sync_flow_task(task) + + def add_delayed_sync_flow(self, sync_flow: SyncFlow, dedup: bool = True, wait_time: float = 0) -> None: + """Add a SyncFlow to queue to be executed + Locks will be set with LockDistributor + + Parameters + ---------- + sync_flow : SyncFlow + SyncFlow to be executed + dedup : bool + SyncFlow will not be added if this flag is True and has a duplicate in the queue + wait_time : float + Minimum number of seconds before SyncFlow executes + """ + self._add_sync_flow_task(DelayedSyncFlowTask(sync_flow, dedup, time.time(), wait_time)) + + def execute( + self, exception_handler: Optional[Callable[[SyncFlowException], None]] = default_exception_handler + ) -> None: + """Blocking continuous execution of the SyncFlows + + Parameters + ---------- + exception_handler : Optional[Callable[[Exception], None]], optional + Function to be called if an exception is raised during the execution of a SyncFlow, + by default default_exception_handler.__func__ + """ + super().execute(exception_handler=exception_handler) + self.stop(should_stop=False) diff --git a/samcli/lib/sync/exceptions.py b/samcli/lib/sync/exceptions.py new file mode 100644 index 0000000000..3df096877b --- /dev/null +++ b/samcli/lib/sync/exceptions.py @@ -0,0 +1,173 @@ +"""Exceptions related to sync functionalities""" +from typing import Dict, Optional, TYPE_CHECKING + +if TYPE_CHECKING: # pragma: no cover + from samcli.lib.sync.sync_flow import SyncFlow + + +class SyncFlowException(Exception): + """Exception wrapper for exceptions raised in SyncFlows""" + + _sync_flow: "SyncFlow" + _exception: Exception + + def __init__(self, sync_flow: "SyncFlow", 
exception: Exception): + """ + Parameters + ---------- + sync_flow : SyncFlow + SyncFlow that raised the exception + exception : Exception + exception raised + """ + super().__init__(f"SyncFlow Exception for {sync_flow.log_name}") + self._sync_flow = sync_flow + self._exception = exception + + @property + def sync_flow(self) -> "SyncFlow": + return self._sync_flow + + @property + def exception(self) -> Exception: + return self._exception + + +class InfraSyncRequiredError(Exception): + """Exception used if SyncFlow cannot handle the sync and an infra sync is required""" + + _resource_identifier: Optional[str] + _reason: Optional[str] + + def __init__(self, resource_identifier: Optional[str] = None, reason: Optional[str] = ""): + """ + Parameters + ---------- + resource_identifier : str + Logical resource identifier + reason : str + Reason for requiring infra sync + """ + super().__init__(f"{resource_identifier} cannot be code synced.") + self._resource_identifier = resource_identifier + self._reason = reason + + @property + def resource_identifier(self) -> Optional[str]: + """ + Returns + ------- + str + Resource identifier of the resource that does not have a remote/physical counterpart + """ + return self._resource_identifier + + @property + def reason(self) -> Optional[str]: + """ + Returns + ------- + str + Reason to why the SyncFlow cannot sync the resource + """ + return self._reason + + +class MissingPhysicalResourceError(Exception): + """Exception used for not having a remote/physical counterpart for a local stack resource""" + + _resource_identifier: Optional[str] + _physical_resource_mapping: Optional[Dict[str, str]] + + def __init__( + self, resource_identifier: Optional[str] = None, physical_resource_mapping: Optional[Dict[str, str]] = None + ): + """ + Parameters + ---------- + resource_identifier : str + Logical resource identifier + physical_resource_mapping: Dict[str, str] + Current mapping between logical and physical IDs + """ + super().__init__(f"{resource_identifier} is not found in remote.") + self._resource_identifier = resource_identifier + self._physical_resource_mapping = physical_resource_mapping + + @property + def resource_identifier(self) -> Optional[str]: + """ + Returns + ------- + str + Resource identifier of the resource that does not have a remote/physical counterpart + """ + return self._resource_identifier + + @property + def physical_resource_mapping(self) -> Optional[Dict[str, str]]: + """ + Returns + ------- + Optional[Dict[str, str]] + Physical ID mapping for resources when the excecption was raised + """ + return self._physical_resource_mapping + + +class NoLayerVersionsFoundError(Exception): + """This is used when we try to list all versions for layer, but we found none""" + + _layer_name_arn: str + + def __init__(self, layer_name_arn: str): + """ + Parameters + ---------- + layer_name_arn : str + Layer ARN without version info at the end of it + """ + super().__init__(f"{layer_name_arn} doesn't have any versions in remote.") + self._layer_name_arn = layer_name_arn + + @property + def layer_name_arn(self) -> str: + """ + Returns + ------- + str + Layer ARN without version info at the end of it + """ + return self._layer_name_arn + + +class MissingLockException(Exception): + """Exception for not having an associated lock to be used.""" + + +class MissingFunctionBuildDefinition(Exception): + """This is used when no build definition found for particular function""" + + _function_logical_id: str + + def __init__(self, function_logical_id: str): + 
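"""
+ Parameters
+ ----------
+ function_logical_id : str
+ Logical ID of the function whose build definition could not be found
+ """
+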
super().__init__(f"Build definition for {function_logical_id} can't be found") + self._function_logical_id = function_logical_id + + @property + def function_logical_id(self) -> str: + return self._function_logical_id + + +class InvalidRuntimeDefinitionForFunction(Exception): + """This is used when no Runtime information is defined for a function resource""" + + _function_logical_id: str + + def __init__(self, function_logical_id): + super().__init__(f"Invalid Runtime definition for {function_logical_id}") + self._function_logical_id = function_logical_id + + @property + def function_logical_id(self): + return self._function_logical_id diff --git a/samcli/lib/sync/flows/__init__.py b/samcli/lib/sync/flows/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/samcli/lib/sync/flows/alias_version_sync_flow.py b/samcli/lib/sync/flows/alias_version_sync_flow.py new file mode 100644 index 0000000000..f82ecabe62 --- /dev/null +++ b/samcli/lib/sync/flows/alias_version_sync_flow.py @@ -0,0 +1,89 @@ +"""SyncFlow for Lambda Function Alias and Version""" +import logging +from typing import Any, Dict, List, Optional, TYPE_CHECKING, cast + +from boto3.session import Session + +from samcli.lib.providers.provider import Stack +from samcli.lib.sync.sync_flow import SyncFlow, ResourceAPICall + +if TYPE_CHECKING: # pragma: no cover + from samcli.commands.deploy.deploy_context import DeployContext + from samcli.commands.build.build_context import BuildContext + +LOG = logging.getLogger(__name__) + + +class AliasVersionSyncFlow(SyncFlow): + """This SyncFlow is used for updating Lambda Function version and its associating Alias. + Currently, this is created after a FunctionSyncFlow is finished. + """ + + _function_identifier: str + _alias_name: str + _lambda_client: Any + + def __init__( + self, + function_identifier: str, + alias_name: str, + build_context: "BuildContext", + deploy_context: "DeployContext", + physical_id_mapping: Dict[str, str], + stacks: Optional[List[Stack]] = None, + ): + """ + Parameters + ---------- + function_identifier : str + Function resource identifier that need to have associated Alias and Version updated. 
+ alias_name : str + Alias name for the function + build_context : BuildContext + BuildContext + deploy_context : DeployContext + DeployContext + physical_id_mapping : Dict[str, str] + Physical ID Mapping + stacks : Optional[List[Stack]] + Stacks + """ + super().__init__( + build_context, + deploy_context, + physical_id_mapping, + log_name=f"Alias {alias_name} and Version of {function_identifier}", + stacks=stacks, + ) + self._function_identifier = function_identifier + self._alias_name = alias_name + self._lambda_client = None + + def set_up(self) -> None: + super().set_up() + self._lambda_client = cast(Session, self._session).client("lambda") + + def gather_resources(self) -> None: + pass + + def compare_remote(self) -> bool: + return False + + def sync(self) -> None: + function_physical_id = self.get_physical_id(self._function_identifier) + version = self._lambda_client.publish_version(FunctionName=function_physical_id).get("Version") + LOG.debug("%sCreated new function version: %s", self.log_prefix, version) + if version: + self._lambda_client.update_alias( + FunctionName=function_physical_id, Name=self._alias_name, FunctionVersion=version + ) + + def gather_dependencies(self) -> List[SyncFlow]: + return [] + + def _get_resource_api_calls(self) -> List[ResourceAPICall]: + return [] + + def _equality_keys(self) -> Any: + """Combination of function identifier and alias name can used to identify each unique SyncFlow""" + return self._function_identifier, self._alias_name diff --git a/samcli/lib/sync/flows/auto_dependency_layer_sync_flow.py b/samcli/lib/sync/flows/auto_dependency_layer_sync_flow.py new file mode 100644 index 0000000000..04d58b3867 --- /dev/null +++ b/samcli/lib/sync/flows/auto_dependency_layer_sync_flow.py @@ -0,0 +1,137 @@ +""" +Contains sync flow implementation for Auto Dependency Layer +""" +import hashlib +import logging +import os +import tempfile +import uuid +from typing import List, TYPE_CHECKING, Dict, cast, Optional + +from samcli.lib.bootstrap.nested_stack.nested_stack_builder import NestedStackBuilder +from samcli.lib.bootstrap.nested_stack.nested_stack_manager import NestedStackManager +from samcli.lib.build.build_graph import BuildGraph +from samcli.lib.package.utils import make_zip +from samcli.lib.providers.provider import Function, Stack +from samcli.lib.providers.sam_function_provider import SamFunctionProvider +from samcli.lib.sync.exceptions import ( + MissingFunctionBuildDefinition, + InvalidRuntimeDefinitionForFunction, + NoLayerVersionsFoundError, +) +from samcli.lib.sync.flows.layer_sync_flow import AbstractLayerSyncFlow +from samcli.lib.sync.flows.zip_function_sync_flow import ZipFunctionSyncFlow +from samcli.lib.sync.sync_flow import SyncFlow +from samcli.lib.utils.hash import file_checksum + +if TYPE_CHECKING: # pragma: no cover + from samcli.commands.deploy.deploy_context import DeployContext + from samcli.commands.build.build_context import BuildContext + +LOG = logging.getLogger(__name__) + + +class AutoDependencyLayerSyncFlow(AbstractLayerSyncFlow): + """ + Auto Dependency Layer, Layer Sync flow. 
+ It creates auto dependency layer files out of function dependencies, and syncs layer code and then updates + the function configuration with new layer version + + This flow is not instantiated from factory method, please see AutoDependencyLayerParentSyncFlow + """ + + _function_identifier: str + _build_graph: Optional[BuildGraph] + + def __init__( + self, + function_identifier: str, + build_graph: BuildGraph, + build_context: "BuildContext", + deploy_context: "DeployContext", + physical_id_mapping: Dict[str, str], + stacks: List[Stack], + ): + super().__init__( + NestedStackBuilder.get_layer_logical_id(function_identifier), + build_context, + deploy_context, + physical_id_mapping, + stacks, + ) + self._function_identifier = function_identifier + self._build_graph = build_graph + + def set_up(self) -> None: + super().set_up() + + # find layer's physical id + layer_name = NestedStackBuilder.get_layer_name(self._deploy_context.stack_name, self._function_identifier) + layer_versions = self._lambda_client.list_layer_versions(LayerName=layer_name).get("LayerVersions", []) + if not layer_versions: + raise NoLayerVersionsFoundError(layer_name) + self._layer_arn = layer_versions[0].get("LayerVersionArn").rsplit(":", 1)[0] + + def gather_resources(self) -> None: + function_build_definitions = cast(BuildGraph, self._build_graph).get_function_build_definitions() + if not function_build_definitions: + raise MissingFunctionBuildDefinition(self._function_identifier) + + self._artifact_folder = NestedStackManager.update_layer_folder( + self._build_context.build_dir, + function_build_definitions[0].dependencies_dir, + self._layer_identifier, + self._function_identifier, + self._get_compatible_runtimes()[0], + ) + zip_file_path = os.path.join(tempfile.gettempdir(), "data-" + uuid.uuid4().hex) + self._zip_file = make_zip(zip_file_path, self._artifact_folder) + self._local_sha = file_checksum(cast(str, self._zip_file), hashlib.sha256()) + + def _get_dependent_functions(self) -> List[Function]: + function = SamFunctionProvider(cast(List[Stack], self._stacks)).get(self._function_identifier) + return [function] if function else [] + + def _get_compatible_runtimes(self) -> List[str]: + function = SamFunctionProvider(cast(List[Stack], self._stacks)).get(self._function_identifier) + if not function or not function.runtime: + raise InvalidRuntimeDefinitionForFunction(self._function_identifier) + return [function.runtime] + + +class AutoDependencyLayerParentSyncFlow(ZipFunctionSyncFlow): + """ + Parent sync flow for auto dependency layer + + It builds function with regular ZipFunctionSyncFlow, and then adds _AutoDependencyLayerSyncFlow to start syncing + dependency layer. 
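+
+ Illustrative flow chain (resource name is hypothetical)::
+
+ AutoDependencyLayerParentSyncFlow(HelloWorldFunction) # syncs the function code first
+ -> AutoDependencyLayerSyncFlow(HelloWorldFunction) # queued only when dependencies changed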
+ """ + + def gather_dependencies(self) -> List[SyncFlow]: + """ + Return auto dependency layer sync flow along with parent dependencies + """ + parent_dependencies = super().gather_dependencies() + + function_build_definitions = cast(BuildGraph, self._build_graph).get_function_build_definitions() + if not function_build_definitions: + raise MissingFunctionBuildDefinition(self._function.name) + + # don't queue up auto dependency layer, if dependencies are not changes + need_dependency_layer_sync = function_build_definitions[0].download_dependencies + if need_dependency_layer_sync: + parent_dependencies.append( + AutoDependencyLayerSyncFlow( + self._function_identifier, + cast(BuildGraph, self._build_graph), + self._build_context, + self._deploy_context, + self._physical_id_mapping, + cast(List[Stack], self._stacks), + ) + ) + return parent_dependencies + + @staticmethod + def _combine_dependencies() -> bool: + return False diff --git a/samcli/lib/sync/flows/function_sync_flow.py b/samcli/lib/sync/flows/function_sync_flow.py new file mode 100644 index 0000000000..31c114f6ed --- /dev/null +++ b/samcli/lib/sync/flows/function_sync_flow.py @@ -0,0 +1,104 @@ +"""Base SyncFlow for Lambda Function""" +import logging +from typing import Any, Dict, List, TYPE_CHECKING, cast + +from boto3.session import Session + +from samcli.lib.providers.sam_function_provider import SamFunctionProvider +from samcli.lib.sync.flows.alias_version_sync_flow import AliasVersionSyncFlow +from samcli.lib.providers.provider import Function, Stack +from samcli.local.lambdafn.exceptions import FunctionNotFound + +from samcli.lib.sync.sync_flow import SyncFlow + +if TYPE_CHECKING: # pragma: no cover + from samcli.commands.deploy.deploy_context import DeployContext + from samcli.commands.build.build_context import BuildContext + +LOG = logging.getLogger(__name__) + + +class FunctionSyncFlow(SyncFlow): + _function_identifier: str + _function_provider: SamFunctionProvider + _function: Function + _lambda_client: Any + _lambda_waiter: Any + _lambda_waiter_config: Dict[str, Any] + + def __init__( + self, + function_identifier: str, + build_context: "BuildContext", + deploy_context: "DeployContext", + physical_id_mapping: Dict[str, str], + stacks: List[Stack], + ): + """ + Parameters + ---------- + function_identifier : str + Function resource identifier that need to be synced. + build_context : BuildContext + BuildContext + deploy_context : DeployContext + DeployContext + physical_id_mapping : Dict[str, str] + Physical ID Mapping + stacks : Optional[List[Stack]] + Stacks + """ + super().__init__( + build_context, + deploy_context, + physical_id_mapping, + log_name="Lambda Function " + function_identifier, + stacks=stacks, + ) + self._function_identifier = function_identifier + self._function_provider = self._build_context.function_provider + self._function = cast(Function, self._function_provider.functions.get(self._function_identifier)) + self._lambda_client = None + self._lambda_waiter = None + self._lambda_waiter_config = {"Delay": 1, "MaxAttempts": 60} + + def set_up(self) -> None: + super().set_up() + self._lambda_client = cast(Session, self._session).client("lambda") + self._lambda_waiter = self._lambda_client.get_waiter("function_updated") + + def gather_dependencies(self) -> List[SyncFlow]: + """Gathers alias and versions related to a function. + Currently only handles serverless function AutoPublishAlias field + since a manually created function version resource behaves statically in a stack. 
+ Redeploying a version resource through CFN will not create a new version. + """ + LOG.debug("%sWaiting on Remote Function Update", self.log_prefix) + self._lambda_waiter.wait( + FunctionName=self.get_physical_id(self._function_identifier), WaiterConfig=self._lambda_waiter_config + ) + LOG.debug("%sRemote Function Updated", self.log_prefix) + sync_flows: List[SyncFlow] = list() + + function_resource = self._get_resource(self._function_identifier) + if not function_resource: + raise FunctionNotFound(f"Unable to find function {self._function_identifier}") + + auto_publish_alias_name = function_resource.get("Properties", dict()).get("AutoPublishAlias", None) + if auto_publish_alias_name: + sync_flows.append( + AliasVersionSyncFlow( + self._function_identifier, + auto_publish_alias_name, + self._build_context, + self._deploy_context, + self._physical_id_mapping, + self._stacks, + ) + ) + LOG.debug("%sCreated Alias and Version SyncFlow", self.log_prefix) + + return sync_flows + + def _equality_keys(self): + return self._function_identifier diff --git a/samcli/lib/sync/flows/generic_api_sync_flow.py b/samcli/lib/sync/flows/generic_api_sync_flow.py new file mode 100644 index 0000000000..1eadcc3e60 --- /dev/null +++ b/samcli/lib/sync/flows/generic_api_sync_flow.py @@ -0,0 +1,89 @@ +"""SyncFlow interface for HttpApi and RestApi""" +import logging +from typing import Any, Dict, List, Optional, TYPE_CHECKING, cast + +from samcli.lib.sync.sync_flow import SyncFlow, ResourceAPICall +from samcli.lib.providers.provider import Stack, get_resource_by_id, ResourceIdentifier + +# BuildContext and DeployContext will only be imported for type checking to improve performance +# since no istances of contexts will be instantiated in this class +if TYPE_CHECKING: # pragma: no cover + from samcli.commands.build.build_context import BuildContext + from samcli.commands.deploy.deploy_context import DeployContext + +LOG = logging.getLogger(__name__) + + +class GenericApiSyncFlow(SyncFlow): + """SyncFlow interface for HttpApi and RestApi""" + + _api_client: Any + _api_identifier: str + _definition_uri: Optional[str] + _stacks: List[Stack] + _swagger_body: Optional[bytes] + + def __init__( + self, + api_identifier: str, + build_context: "BuildContext", + deploy_context: "DeployContext", + physical_id_mapping: Dict[str, str], + log_name: str, + stacks: List[Stack], + ): + """ + Parameters + ---------- + api_identifier : str + HttpApi resource identifier that needs to have associated Api updated. 
+ build_context : BuildContext + BuildContext used for build related parameters + deploy_context : BuildContext + DeployContext used for this deploy related parameters + physical_id_mapping : Dict[str, str] + Mapping between resource logical identifier and physical identifier + log_name: str + Log name passed from subclasses, HttpApi or RestApi + stacks : List[Stack], optional + List of stacks containing a root stack and optional nested stacks + """ + super().__init__( + build_context, + deploy_context, + physical_id_mapping, + log_name=log_name, + stacks=stacks, + ) + self._api_identifier = api_identifier + + def gather_resources(self) -> None: + self._definition_uri = self._get_definition_file(self._api_identifier) + self._swagger_body = self._process_definition_file() + + def _process_definition_file(self) -> Optional[bytes]: + if self._definition_uri is None: + return None + with open(self._definition_uri, "rb") as swagger_file: + swagger_body = swagger_file.read() + return swagger_body + + def _get_definition_file(self, api_identifier: str) -> Optional[str]: + api_resource = get_resource_by_id(self._stacks, ResourceIdentifier(api_identifier)) + if api_resource is None: + return None + properties = api_resource.get("Properties", {}) + definition_file = properties.get("DefinitionUri") + return cast(Optional[str], definition_file) + + def compare_remote(self) -> bool: + return False + + def gather_dependencies(self) -> List[SyncFlow]: + return [] + + def _get_resource_api_calls(self) -> List[ResourceAPICall]: + return [] + + def _equality_keys(self) -> Any: + return self._api_identifier diff --git a/samcli/lib/sync/flows/http_api_sync_flow.py b/samcli/lib/sync/flows/http_api_sync_flow.py new file mode 100644 index 0000000000..0d7b9d614f --- /dev/null +++ b/samcli/lib/sync/flows/http_api_sync_flow.py @@ -0,0 +1,64 @@ +"""SyncFlow for HttpApi""" +import logging +from typing import Dict, List, TYPE_CHECKING, cast + +from boto3.session import Session + +from samcli.lib.sync.flows.generic_api_sync_flow import GenericApiSyncFlow +from samcli.lib.providers.provider import ResourceIdentifier, Stack +from samcli.lib.providers.exceptions import MissingLocalDefinition + +# BuildContext and DeployContext will only be imported for type checking to improve performance +# since no instances of contexts will be instantiated in this class +if TYPE_CHECKING: # pragma: no cover + from samcli.commands.build.build_context import BuildContext + from samcli.commands.deploy.deploy_context import DeployContext + +LOG = logging.getLogger(__name__) + + +class HttpApiSyncFlow(GenericApiSyncFlow): + """SyncFlow for HttpApi's""" + + def __init__( + self, + api_identifier: str, + build_context: "BuildContext", + deploy_context: "DeployContext", + physical_id_mapping: Dict[str, str], + stacks: List[Stack], + ): + """ + Parameters + ---------- + api_identifier : str + HttpApi resource identifier that needs to have associated HttpApi updated. 
+        build_context : BuildContext
+            BuildContext used for build related parameters
+        deploy_context : DeployContext
+            DeployContext used for deploy related parameters
+        physical_id_mapping : Dict[str, str]
+            Mapping between resource logical identifier and physical identifier
+        stacks : List[Stack]
+            List of stacks containing a root stack and optional nested stacks
+        """
+        super().__init__(
+            api_identifier,
+            build_context,
+            deploy_context,
+            physical_id_mapping,
+            log_name="HttpApi " + api_identifier,
+            stacks=stacks,
+        )
+
+    def set_up(self) -> None:
+        super().set_up()
+        self._api_client = cast(Session, self._session).client("apigatewayv2")
+
+    def sync(self) -> None:
+        api_physical_id = self.get_physical_id(self._api_identifier)
+        if self._definition_uri is None:
+            raise MissingLocalDefinition(ResourceIdentifier(self._api_identifier), "DefinitionUri")
+        LOG.debug("%sTrying to import HttpApi through client", self.log_prefix)
+        response = self._api_client.reimport_api(ApiId=api_physical_id, Body=self._swagger_body)
+        LOG.debug("%sImport HttpApi Result: %s", self.log_prefix, response)
diff --git a/samcli/lib/sync/flows/image_function_sync_flow.py b/samcli/lib/sync/flows/image_function_sync_flow.py
new file mode 100644
index 0000000000..ae8ebc9b6f
--- /dev/null
+++ b/samcli/lib/sync/flows/image_function_sync_flow.py
@@ -0,0 +1,110 @@
+"""SyncFlow for Image based Lambda Functions"""
+import logging
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, cast
+
+import docker
+from boto3.session import Session
+from docker.client import DockerClient
+
+from samcli.lib.providers.provider import Stack
+from samcli.lib.sync.flows.function_sync_flow import FunctionSyncFlow
+from samcli.lib.package.ecr_uploader import ECRUploader
+
+from samcli.lib.build.app_builder import ApplicationBuilder
+from samcli.lib.sync.sync_flow import ResourceAPICall
+
+if TYPE_CHECKING:  # pragma: no cover
+    from samcli.commands.deploy.deploy_context import DeployContext
+    from samcli.commands.build.build_context import BuildContext
+
+LOG = logging.getLogger(__name__)
+
+
+class ImageFunctionSyncFlow(FunctionSyncFlow):
+    """SyncFlow for Image based Lambda Functions"""
+
+    _ecr_client: Any
+    _docker_client: Optional[DockerClient]
+    _image_name: Optional[str]
+
+    def __init__(
+        self,
+        function_identifier: str,
+        build_context: "BuildContext",
+        deploy_context: "DeployContext",
+        physical_id_mapping: Dict[str, str],
+        stacks: List[Stack],
+        docker_client: Optional[DockerClient] = None,
+    ):
+        """
+        Parameters
+        ----------
+        function_identifier : str
+            Image function resource identifier that needs to be synced.
+        build_context : BuildContext
+            BuildContext
+        deploy_context : DeployContext
+            DeployContext
+        physical_id_mapping : Dict[str, str]
+            Physical ID Mapping
+        stacks : List[Stack]
+            List of stacks containing a root stack and optional nested stacks
+        docker_client : Optional[DockerClient]
+            Docker client to be used for building and uploading images.
+            Defaults to docker.from_env() if None is provided.
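+
+        A minimal lifecycle sketch (illustrative only; the resource name and the
+        context objects below are assumed, and in practice the flow is normally
+        driven by SyncFlowExecutor through execute() rather than called directly):
+
+            flow = ImageFunctionSyncFlow(
+                "HelloWorldFunction", build_context, deploy_context, physical_ids, stacks
+            )
+            flow.set_up()            # creates the boto3 clients and the docker client
+            flow.gather_resources()  # builds the container image locally
+            if not flow.compare_remote():
+                flow.sync()          # pushes the image to ECR and updates the function code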
+ """ + super().__init__(function_identifier, build_context, deploy_context, physical_id_mapping, stacks) + self._ecr_client = None + self._image_name = None + self._docker_client = docker_client + + def set_up(self) -> None: + super().set_up() + self._ecr_client = cast(Session, self._session).client("ecr") + if not self._docker_client: + self._docker_client = docker.from_env() + + def gather_resources(self) -> None: + """Build function image and save it in self._image_name""" + builder = ApplicationBuilder( + self._build_context.collect_build_resources(self._function_identifier), + self._build_context.build_dir, + self._build_context.base_dir, + self._build_context.cache_dir, + cached=False, + is_building_specific_resource=True, + manifest_path_override=self._build_context.manifest_path_override, + container_manager=self._build_context.container_manager, + mode=self._build_context.mode, + ) + self._image_name = builder.build().artifacts.get(self._function_identifier) + + def compare_remote(self) -> bool: + return False + + def sync(self) -> None: + if not self._image_name: + LOG.debug("%sSkipping sync. Image name is None.", self.log_prefix) + return + function_physical_id = self.get_physical_id(self._function_identifier) + # Load ECR Repo from --image-repository + ecr_repo = self._deploy_context.image_repository + + # Load ECR Repo from --image-repositories + if ( + not ecr_repo + and self._deploy_context.image_repositories + and isinstance(self._deploy_context.image_repositories, dict) + ): + ecr_repo = self._deploy_context.image_repositories.get(self._function_identifier) + + # Load ECR Repo directly from remote function + if not ecr_repo: + LOG.debug("%sGetting ECR Repo from Remote Function", self.log_prefix) + function_result = self._lambda_client.get_function(FunctionName=function_physical_id) + ecr_repo = function_result.get("Code", dict()).get("ImageUri", "").split(":")[0] + ecr_uploader = ECRUploader(self._docker_client, self._ecr_client, ecr_repo, None) + image_uri = ecr_uploader.upload(self._image_name, self._function_identifier) + + self._lambda_client.update_function_code(FunctionName=function_physical_id, ImageUri=image_uri) + + def _get_resource_api_calls(self) -> List[ResourceAPICall]: + return [] diff --git a/samcli/lib/sync/flows/layer_sync_flow.py b/samcli/lib/sync/flows/layer_sync_flow.py new file mode 100644 index 0000000000..6197c4072b --- /dev/null +++ b/samcli/lib/sync/flows/layer_sync_flow.py @@ -0,0 +1,348 @@ +"""SyncFlow for Layers""" +import base64 +import hashlib +import logging +import os +import re +import tempfile +import uuid +from abc import ABC, abstractmethod +from typing import Any, TYPE_CHECKING, cast, Dict, List, Optional + +from boto3.session import Session +from samcli.lib.build.app_builder import ApplicationBuilder +from samcli.lib.package.utils import make_zip +from samcli.lib.providers.provider import ResourceIdentifier, Stack, get_resource_by_id, Function +from samcli.lib.providers.sam_function_provider import SamFunctionProvider +from samcli.lib.sync.exceptions import MissingPhysicalResourceError, NoLayerVersionsFoundError +from samcli.lib.sync.sync_flow import SyncFlow, ResourceAPICall +from samcli.lib.sync.sync_flow_executor import HELP_TEXT_FOR_SYNC_INFRA +from samcli.lib.utils.hash import file_checksum + +if TYPE_CHECKING: # pragma: no cover + from samcli.commands.build.build_context import BuildContext + from samcli.commands.deploy.deploy_context import DeployContext + +LOG = logging.getLogger(__name__) + + +class 
AbstractLayerSyncFlow(SyncFlow, ABC):
+    """
+    AbstractLayerSyncFlow contains common operations for a Layer sync.
+    """
+
+    _lambda_client: Any
+    _layer_arn: Optional[str]
+    _old_layer_version: Optional[int]
+    _new_layer_version: Optional[int]
+    _layer_identifier: str
+    _artifact_folder: Optional[str]
+    _zip_file: Optional[str]
+    _local_sha: Optional[str]
+
+    def __init__(
+        self,
+        layer_identifier: str,
+        build_context: "BuildContext",
+        deploy_context: "DeployContext",
+        physical_id_mapping: Dict[str, str],
+        stacks: List[Stack],
+    ):
+        super().__init__(build_context, deploy_context, physical_id_mapping, f"Layer {layer_identifier}", stacks)
+        self._layer_identifier = layer_identifier
+        self._layer_arn = None
+        self._old_layer_version = None
+        self._new_layer_version = None
+        self._zip_file = None
+        self._artifact_folder = None
+
+    def set_up(self) -> None:
+        super().set_up()
+        self._lambda_client = cast(Session, self._session).client("lambda")
+
+    def compare_remote(self) -> bool:
+        """
+        Compare the SHA256 of the deployed layer code with the one just built.
+        Return True if they are the same, False otherwise.
+        """
+        self._old_layer_version = self._get_latest_layer_version()
+        old_layer_info = self._lambda_client.get_layer_version(
+            LayerName=self._layer_arn,
+            VersionNumber=self._old_layer_version,
+        )
+        remote_sha = base64.b64decode(old_layer_info.get("Content", {}).get("CodeSha256", "")).hex()
+        LOG.debug("%sLocal SHA: %s Remote SHA: %s", self.log_prefix, self._local_sha, remote_sha)
+
+        return self._local_sha == remote_sha
+
+    def _get_latest_layer_version(self):
+        """Fetches all layer versions from remote and returns the latest one"""
+        # ListLayerVersions returns versions newest first, so the first entry is the latest
+        layer_versions = self._lambda_client.list_layer_versions(LayerName=self._layer_arn).get("LayerVersions", [])
+        if not layer_versions:
+            raise NoLayerVersionsFoundError(self._layer_arn)
+        return layer_versions[0].get("Version")
+
+    def sync(self) -> None:
+        """
+        Publish a new layer version, then delete the existing (old) one
+        """
+        LOG.debug("%sPublishing new Layer Version", self.log_prefix)
+        self._new_layer_version = self._publish_new_layer_version()
+        self._delete_old_layer_version()
+
+    def gather_dependencies(self) -> List[SyncFlow]:
+        if self._zip_file and os.path.exists(self._zip_file):
+            os.remove(self._zip_file)
+
+        dependencies: List[SyncFlow] = list()
+        dependent_functions = self._get_dependent_functions()
+        if self._stacks:
+            for function in dependent_functions:
+                dependencies.append(
+                    FunctionLayerReferenceSync(
+                        function.full_path,
+                        cast(str, self._layer_arn),
+                        cast(int, self._new_layer_version),
+                        self._build_context,
+                        self._deploy_context,
+                        self._physical_id_mapping,
+                        self._stacks,
+                    )
+                )
+        return dependencies
+
+    def _get_resource_api_calls(self) -> List[ResourceAPICall]:
+        return [ResourceAPICall(self._layer_identifier, ["Build"])]
+
+    def _equality_keys(self) -> Any:
+        return self._layer_identifier
+
+    def _publish_new_layer_version(self) -> int:
+        """
+        Publish the new layer version and return its version number, so that dependent functions can be updated
+        """
+        compatible_runtimes = self._get_compatible_runtimes()
+        with open(cast(str, self._zip_file), "rb") as zip_file:
+            data = zip_file.read()
+        layer_publish_result = self._lambda_client.publish_layer_version(
+            LayerName=self._layer_arn, Content={"ZipFile": data}, CompatibleRuntimes=compatible_runtimes
+        )
+        LOG.debug("%sPublish Layer Version Result %s", self.log_prefix, layer_publish_result)
+        return int(layer_publish_result.get("Version"))
+
+    def _delete_old_layer_version(self) -> None:
""" + Delete old layer version for not hitting the layer version limit + """ + LOG.debug( + "%sDeleting old Layer Version %s:%s", self.log_prefix, self._old_layer_version, self._old_layer_version + ) + delete_layer_version_result = self._lambda_client.delete_layer_version( + LayerName=self._layer_arn, + VersionNumber=self._old_layer_version, + ) + LOG.debug("%sDelete Layer Version Result %s", self.log_prefix, delete_layer_version_result) + + @abstractmethod + def _get_compatible_runtimes(self) -> List[str]: + """ + Returns compatible runtimes of the Layer instance that is going to be synced + + Returns + ------- + List[str] + List of strings which identifies the compatible runtimes for this layer + """ + raise NotImplementedError("_get_compatible_runtimes not implemented") + + @abstractmethod + def _get_dependent_functions(self) -> List[Function]: + """ + Returns list of Function instances, which is depending on this Layer. This information is used to setup + dependency sync flows, which will update each function's configuration with new layer version. + + Returns + ------- + List[Function] + List of Function instances which uses this Layer + """ + raise NotImplementedError("_get_dependent_functions not implemented") + + +class LayerSyncFlow(AbstractLayerSyncFlow): + """SyncFlow for Lambda Layers""" + + _new_layer_version: Optional[int] + + def set_up(self) -> None: + super().set_up() + + # if layer is a serverless layer, its physical id contains hashes, try to find layer resource + if self._layer_identifier not in self._physical_id_mapping: + expression = re.compile(f"^{self._layer_identifier}[0-9a-z]{{10}}$") + for logical_id, _ in self._physical_id_mapping.items(): + # Skip over resources that do exist in the template as generated LayerVersion should not be in there + if get_resource_by_id(cast(List[Stack], self._stacks), ResourceIdentifier(logical_id), True): + continue + # Check if logical ID starts with serverless layer and has 10 characters behind + if not expression.match(logical_id): + continue + + self._layer_arn = self.get_physical_id(logical_id).rsplit(":", 1)[0] + LOG.debug("%sLayer physical name has been set to %s", self.log_prefix, self._layer_identifier) + break + else: + raise MissingPhysicalResourceError( + self._layer_identifier, + self._physical_id_mapping, + ) + else: + self._layer_arn = self.get_physical_id(self._layer_identifier).rsplit(":", 1)[0] + LOG.debug("%sLayer physical name has been set to %s", self.log_prefix, self._layer_identifier) + + def gather_resources(self) -> None: + """Build layer and ZIP it into a temp file in self._zip_file""" + with self._get_lock_chain(): + builder = ApplicationBuilder( + self._build_context.collect_build_resources(self._layer_identifier), + self._build_context.build_dir, + self._build_context.base_dir, + self._build_context.cache_dir, + cached=True, + is_building_specific_resource=True, + manifest_path_override=self._build_context.manifest_path_override, + container_manager=self._build_context.container_manager, + mode=self._build_context.mode, + ) + LOG.debug("%sBuilding Layer", self.log_prefix) + self._artifact_folder = builder.build().artifacts.get(self._layer_identifier) + + zip_file_path = os.path.join(tempfile.gettempdir(), f"data-{uuid.uuid4().hex}") + self._zip_file = make_zip(zip_file_path, self._artifact_folder) + LOG.debug("%sCreated artifact ZIP file: %s", self.log_prefix, self._zip_file) + self._local_sha = file_checksum(cast(str, self._zip_file), hashlib.sha256()) + + def _get_compatible_runtimes(self): + 
layer_resource = cast(Dict[str, Any], self._get_resource(self._layer_identifier))
+        return layer_resource.get("Properties", {}).get("CompatibleRuntimes", [])
+
+    def _get_dependent_functions(self) -> List[Function]:
+        function_provider = SamFunctionProvider(cast(List[Stack], self._stacks))
+
+        dependent_functions = []
+        for function in function_provider.get_all():
+            if self._layer_identifier in [layer.full_path for layer in function.layers]:
+                LOG.debug(
+                    "%sAdding function %s for updating its Layers with this new version",
+                    self.log_prefix,
+                    function.name,
+                )
+                dependent_functions.append(function)
+        return dependent_functions
+
+
+class FunctionLayerReferenceSync(SyncFlow):
+    """
+    Used for updating the new Layer version for the related functions
+    """
+
+    UPDATE_FUNCTION_CONFIGURATION = "UpdateFunctionConfiguration"
+
+    _lambda_client: Any
+
+    _function_identifier: str
+    _layer_arn: str
+    _new_layer_version: int
+
+    def __init__(
+        self,
+        function_identifier: str,
+        layer_arn: str,
+        new_layer_version: int,
+        build_context: "BuildContext",
+        deploy_context: "DeployContext",
+        physical_id_mapping: Dict[str, str],
+        stacks: List[Stack],
+    ):
+        super().__init__(
+            build_context,
+            deploy_context,
+            physical_id_mapping,
+            log_name="Function Layer Reference Sync " + function_identifier,
+            stacks=stacks,
+        )
+        self._function_identifier = function_identifier
+        self._layer_arn = layer_arn
+        self._new_layer_version = new_layer_version
+
+    def set_up(self) -> None:
+        super().set_up()
+        self._lambda_client = cast(Session, self._session).client("lambda")
+
+    def sync(self) -> None:
+        """
+        First read the current Layers property and replace the old layer version ARN with the new one,
+        then call UpdateFunctionConfiguration to update the function with the new layer version ARN
+        """
+        if not self._locks:
+            LOG.warning("%sLocks is None", self.log_prefix)
+            return
+        lock_key = SyncFlow._get_lock_key(
+            self._function_identifier, FunctionLayerReferenceSync.UPDATE_FUNCTION_CONFIGURATION
+        )
+        lock = self._locks.get(lock_key)
+        if not lock:
+            LOG.warning("%s%s lock is None", self.log_prefix, lock_key)
+            return
+
+        with lock:
+            new_layer_arn = f"{self._layer_arn}:{self._new_layer_version}"
+
+            function_physical_id = self.get_physical_id(self._function_identifier)
+            get_function_result = self._lambda_client.get_function(FunctionName=function_physical_id)
+
+            # get the current layer version arns
+            layer_arns = [layer.get("Arn") for layer in get_function_result.get("Configuration", {}).get("Layers", [])]
+
+            # Check whether the layer version is already up to date
+            if new_layer_arn in layer_arns:
+                LOG.warning(
+                    "%sLambda Function (%s) is already up to date with new Layer version (%d).",
+                    self.log_prefix,
+                    self._function_identifier,
+                    self._new_layer_version,
+                )
+                return
+
+            # Check that the function actually uses this layer
+            matching_arns = [layer_arn for layer_arn in layer_arns if layer_arn.startswith(self._layer_arn)]
+            old_layer_arn = matching_arns[0] if len(matching_arns) == 1 else None
+            if not old_layer_arn:
+                LOG.warning(
+                    "%sLambda Function (%s) does not have layer (%s).%s",
+                    self.log_prefix,
+                    self._function_identifier,
+                    self._layer_arn,
+                    HELP_TEXT_FOR_SYNC_INFRA,
+                )
+                return
+
+            # remove the old layer version arn and add the new one
+            layer_arns.remove(old_layer_arn)
+            layer_arns.append(new_layer_arn)
+            self._lambda_client.update_function_configuration(FunctionName=function_physical_id, Layers=layer_arns)
+
+    def _get_resource_api_calls(self) -> List[ResourceAPICall]:
+        return
[ResourceAPICall(self._function_identifier, [FunctionLayerReferenceSync.UPDATE_FUNCTION_CONFIGURATION])]
+
+    def compare_remote(self) -> bool:
+        return False
+
+    def gather_resources(self) -> None:
+        pass
+
+    def gather_dependencies(self) -> List["SyncFlow"]:
+        return []
+
+    def _equality_keys(self) -> Any:
+        return self._function_identifier, self._layer_arn, self._new_layer_version
diff --git a/samcli/lib/sync/flows/rest_api_sync_flow.py b/samcli/lib/sync/flows/rest_api_sync_flow.py
new file mode 100644
index 0000000000..0535f1eb1c
--- /dev/null
+++ b/samcli/lib/sync/flows/rest_api_sync_flow.py
@@ -0,0 +1,64 @@
+"""SyncFlow for RestApi"""
+import logging
+from typing import Dict, List, TYPE_CHECKING, cast
+
+from boto3.session import Session
+
+from samcli.lib.sync.flows.generic_api_sync_flow import GenericApiSyncFlow
+from samcli.lib.providers.provider import ResourceIdentifier, Stack
+from samcli.lib.providers.exceptions import MissingLocalDefinition
+
+# BuildContext and DeployContext will only be imported for type checking to improve performance
+# since no instances of the contexts will be instantiated in this class
+if TYPE_CHECKING:  # pragma: no cover
+    from samcli.commands.build.build_context import BuildContext
+    from samcli.commands.deploy.deploy_context import DeployContext
+
+LOG = logging.getLogger(__name__)
+
+
+class RestApiSyncFlow(GenericApiSyncFlow):
+    """SyncFlow for RestApi resources"""
+
+    def __init__(
+        self,
+        api_identifier: str,
+        build_context: "BuildContext",
+        deploy_context: "DeployContext",
+        physical_id_mapping: Dict[str, str],
+        stacks: List[Stack],
+    ):
+        """
+        Parameters
+        ----------
+        api_identifier : str
+            RestApi resource identifier whose associated API definition needs to be updated.
+        build_context : BuildContext
+            BuildContext used for build related parameters
+        deploy_context : DeployContext
+            DeployContext used for deploy related parameters
+        physical_id_mapping : Dict[str, str]
+            Mapping between resource logical identifier and physical identifier
+        stacks : List[Stack]
+            List of stacks containing a root stack and optional nested stacks
+        """
+        super().__init__(
+            api_identifier,
+            build_context,
+            deploy_context,
+            physical_id_mapping,
+            log_name="RestApi " + api_identifier,
+            stacks=stacks,
+        )
+
+    def set_up(self) -> None:
+        super().set_up()
+        self._api_client = cast(Session, self._session).client("apigateway")
+
+    def sync(self) -> None:
+        api_physical_id = self.get_physical_id(self._api_identifier)
+        if self._definition_uri is None:
+            raise MissingLocalDefinition(ResourceIdentifier(self._api_identifier), "DefinitionUri")
+        LOG.debug("%sTrying to put RestApi through client", self.log_prefix)
+        response = self._api_client.put_rest_api(restApiId=api_physical_id, mode="overwrite", body=self._swagger_body)
+        LOG.debug("%sPut RestApi Result: %s", self.log_prefix, response)
diff --git a/samcli/lib/sync/flows/stepfunctions_sync_flow.py b/samcli/lib/sync/flows/stepfunctions_sync_flow.py
new file mode 100644
index 0000000000..a49949d797
--- /dev/null
+++ b/samcli/lib/sync/flows/stepfunctions_sync_flow.py
@@ -0,0 +1,113 @@
+"""Base SyncFlow for StepFunctions"""
+import logging
+from typing import Any, Dict, List, TYPE_CHECKING, cast, Optional
+
+
+from boto3.session import Session
+
+from samcli.lib.providers.provider import Stack, get_resource_by_id, ResourceIdentifier
+from samcli.lib.sync.sync_flow import SyncFlow, ResourceAPICall
+from samcli.lib.sync.exceptions import InfraSyncRequiredError
+from samcli.lib.providers.exceptions import
MissingLocalDefinition
+
+
+if TYPE_CHECKING:  # pragma: no cover
+    from samcli.commands.deploy.deploy_context import DeployContext
+    from samcli.commands.build.build_context import BuildContext
+
+LOG = logging.getLogger(__name__)
+
+
+class StepFunctionsSyncFlow(SyncFlow):
+    """SyncFlow for Step Functions state machines"""
+
+    _state_machine_identifier: str
+    _stepfunctions_client: Any
+    _stacks: List[Stack]
+    _definition_uri: Optional[str]
+    _states_definition: Optional[str]
+
+    def __init__(
+        self,
+        state_machine_identifier: str,
+        build_context: "BuildContext",
+        deploy_context: "DeployContext",
+        physical_id_mapping: Dict[str, str],
+        stacks: List[Stack],
+    ):
+        """
+        Parameters
+        ----------
+        state_machine_identifier : str
+            State Machine resource identifier that needs to be synced.
+        build_context : BuildContext
+            BuildContext used for build related parameters
+        deploy_context : DeployContext
+            DeployContext used for deploy related parameters
+        physical_id_mapping : Dict[str, str]
+            Mapping between resource logical identifier and physical identifier
+        stacks : List[Stack]
+            List of stacks containing a root stack and optional nested stacks
+        """
+        super().__init__(
+            build_context,
+            deploy_context,
+            physical_id_mapping,
+            log_name="StepFunctions " + state_machine_identifier,
+            stacks=stacks,
+        )
+        self._state_machine_identifier = state_machine_identifier
+        self._resource = get_resource_by_id(self._stacks, ResourceIdentifier(self._state_machine_identifier))
+        self._stepfunctions_client = None
+        self._definition_uri = None
+        self._states_definition = None
+
+    def set_up(self) -> None:
+        super().set_up()
+        self._stepfunctions_client = cast(Session, self._session).client("stepfunctions")
+
+    def gather_resources(self) -> None:
+        if not self._resource:
+            return
+        definition_substitutions = self._resource.get("Properties", dict()).get("DefinitionSubstitutions", None)
+        if definition_substitutions:
+            raise InfraSyncRequiredError(self._state_machine_identifier, "DefinitionSubstitutions field is specified.")
+        self._definition_uri = self._get_definition_file(self._state_machine_identifier)
+        self._states_definition = self._process_definition_file()
+
+    def _process_definition_file(self) -> Optional[str]:
+        if self._definition_uri is None:
+            return None
+        with open(self._definition_uri, "r", encoding="utf-8") as states_file:
+            states_data = states_file.read()
+        return states_data
+
+    def _get_definition_file(self, state_machine_identifier: str) -> Optional[str]:
+        if self._resource is None:
+            return None
+        properties = self._resource.get("Properties", {})
+        definition_file = properties.get("DefinitionUri")
+        return cast(Optional[str], definition_file)
+
+    def compare_remote(self) -> bool:
+        # Not comparing with remote right now, instead only making update api calls
+        # Note: describe state machine has a better rate limit than update state machine,
+        # so if we run into throttling issues, comparing before updating would be desirable
+        return False
+
+    def gather_dependencies(self) -> List[SyncFlow]:
+        return []
+
+    def _get_resource_api_calls(self) -> List[ResourceAPICall]:
+        return []
+
+    def _equality_keys(self):
+        return self._state_machine_identifier
+
+    def sync(self) -> None:
+        state_machine_arn = self.get_physical_id(self._state_machine_identifier)
+        if self._definition_uri is None:
+            raise MissingLocalDefinition(ResourceIdentifier(self._state_machine_identifier), "DefinitionUri")
+        LOG.debug("%sTrying to update State Machine definition", self.log_prefix)
+        response = self._stepfunctions_client.update_state_machine(
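+            # UpdateStateMachine replaces the entire definition document; Step Functions
+            # offers no partial update, so the whole local definition is sent each time.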
stateMachineArn=state_machine_arn, definition=self._states_definition
+        )
+        LOG.debug("%sUpdate State Machine Result: %s", self.log_prefix, response)
diff --git a/samcli/lib/sync/flows/zip_function_sync_flow.py b/samcli/lib/sync/flows/zip_function_sync_flow.py
new file mode 100644
index 0000000000..6510fba610
--- /dev/null
+++ b/samcli/lib/sync/flows/zip_function_sync_flow.py
@@ -0,0 +1,155 @@
+"""SyncFlow for ZIP based Lambda Functions"""
+import hashlib
+import logging
+import os
+import base64
+import tempfile
+import uuid
+
+from contextlib import ExitStack
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, cast
+
+from boto3.session import Session
+
+from samcli.lib.build.build_graph import BuildGraph
+from samcli.lib.providers.provider import Stack
+
+from samcli.lib.sync.flows.function_sync_flow import FunctionSyncFlow
+from samcli.lib.package.s3_uploader import S3Uploader
+from samcli.lib.utils.hash import file_checksum
+from samcli.lib.package.utils import make_zip
+
+from samcli.lib.build.app_builder import ApplicationBuilder
+from samcli.lib.sync.sync_flow import ResourceAPICall
+
+if TYPE_CHECKING:  # pragma: no cover
+    from samcli.commands.deploy.deploy_context import DeployContext
+    from samcli.commands.build.build_context import BuildContext
+
+LOG = logging.getLogger(__name__)
+MAXIMUM_FUNCTION_ZIP_SIZE = 50 * 1024 * 1024  # 50MB limit for Lambda direct ZIP upload
+
+
+class ZipFunctionSyncFlow(FunctionSyncFlow):
+    """SyncFlow for ZIP based functions"""
+
+    _s3_client: Any
+    _artifact_folder: Optional[str]
+    _zip_file: Optional[str]
+    _local_sha: Optional[str]
+    _build_graph: Optional[BuildGraph]
+
+    def __init__(
+        self,
+        function_identifier: str,
+        build_context: "BuildContext",
+        deploy_context: "DeployContext",
+        physical_id_mapping: Dict[str, str],
+        stacks: List[Stack],
+    ):
+        """
+        Parameters
+        ----------
+        function_identifier : str
+            ZIP function resource identifier that needs to be synced.
+        build_context : BuildContext
+            BuildContext
+        deploy_context : DeployContext
+            DeployContext
+        physical_id_mapping : Dict[str, str]
+            Physical ID Mapping
+        stacks : List[Stack]
+            List of stacks containing a root stack and optional nested stacks
+        """
+        super().__init__(function_identifier, build_context, deploy_context, physical_id_mapping, stacks)
+        self._s3_client = None
+        self._artifact_folder = None
+        self._zip_file = None
+        self._local_sha = None
+        self._build_graph = None
+
+    def set_up(self) -> None:
+        super().set_up()
+        self._s3_client = cast(Session, self._session).client("s3")
+
+    def gather_resources(self) -> None:
+        """Build function and ZIP it into a temp file in self._zip_file"""
+        with ExitStack() as exit_stack:
+            if self._function.layers:
+                exit_stack.enter_context(self._get_lock_chain())
+
+            builder = ApplicationBuilder(
+                self._build_context.collect_build_resources(self._function_identifier),
+                self._build_context.build_dir,
+                self._build_context.base_dir,
+                self._build_context.cache_dir,
+                cached=True,
+                is_building_specific_resource=True,
+                manifest_path_override=self._build_context.manifest_path_override,
+                container_manager=self._build_context.container_manager,
+                mode=self._build_context.mode,
+                combine_dependencies=self._combine_dependencies(),
+            )
+            LOG.debug("%sBuilding Function", self.log_prefix)
+            build_result = builder.build()
+            self._build_graph = build_result.build_graph
+            self._artifact_folder = build_result.artifacts.get(self._function_identifier)
+
+        zip_file_path = os.path.join(tempfile.gettempdir(), "data-" + uuid.uuid4().hex)
+        self._zip_file = make_zip(zip_file_path, self._artifact_folder)
+        LOG.debug("%sCreated artifact ZIP file: %s", self.log_prefix, self._zip_file)
+        self._local_sha = file_checksum(cast(str, self._zip_file), hashlib.sha256())
+
+    def compare_remote(self) -> bool:
+        remote_info = self._lambda_client.get_function(FunctionName=self.get_physical_id(self._function_identifier))
+        remote_sha = base64.b64decode(remote_info["Configuration"]["CodeSha256"]).hex()
+        LOG.debug("%sLocal SHA: %s Remote SHA: %s", self.log_prefix, self._local_sha, remote_sha)
+
+        return self._local_sha == remote_sha
+
+    def sync(self) -> None:
+        if not self._zip_file:
+            LOG.debug("%sSkipping Sync. 
ZIP file is None.", self.log_prefix)
+            return
+
+        zip_file_size = os.path.getsize(self._zip_file)
+        if zip_file_size < MAXIMUM_FUNCTION_ZIP_SIZE:
+            # Direct upload through Lambda API
+            LOG.debug("%sUploading Function Directly", self.log_prefix)
+            with open(self._zip_file, "rb") as zip_file:
+                data = zip_file.read()
+            self._lambda_client.update_function_code(
+                FunctionName=self.get_physical_id(self._function_identifier), ZipFile=data
+            )
+        else:
+            # Upload to S3 first for oversized ZIPs
+            LOG.debug("%sUploading Function Through S3", self.log_prefix)
+            uploader = S3Uploader(
+                s3_client=self._s3_client,
+                bucket_name=self._deploy_context.s3_bucket,
+                prefix=self._deploy_context.s3_prefix,
+                kms_key_id=self._deploy_context.kms_key_id,
+                force_upload=True,
+                no_progressbar=True,
+            )
+            s3_url = uploader.upload_with_dedup(self._zip_file)
+            # s3_url has the form "s3://<bucket>/<key>"; strip the scheme and bucket name to get the key
+            s3_key = s3_url[5:].split("/", 1)[1]
+            self._lambda_client.update_function_code(
+                FunctionName=self.get_physical_id(self._function_identifier),
+                S3Bucket=self._deploy_context.s3_bucket,
+                S3Key=s3_key,
+            )
+
+        if os.path.exists(self._zip_file):
+            os.remove(self._zip_file)
+
+    def _get_resource_api_calls(self) -> List[ResourceAPICall]:
+        resource_calls = list()
+        for layer in self._function.layers:
+            resource_calls.append(ResourceAPICall(layer.full_path, ["Build"]))
+        return resource_calls
+
+    @staticmethod
+    def _combine_dependencies() -> bool:
+        return True
diff --git a/samcli/lib/sync/sync_flow.py b/samcli/lib/sync/sync_flow.py
new file mode 100644
index 0000000000..164d24e2c1
--- /dev/null
+++ b/samcli/lib/sync/sync_flow.py
@@ -0,0 +1,295 @@
+"""SyncFlow base class"""
+import logging
+
+from abc import ABC, abstractmethod
+from threading import Lock
+from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING, cast
+from boto3.session import Session
+
+from samcli.lib.providers.provider import ResourceIdentifier, Stack, get_resource_by_id
+from samcli.lib.utils.lock_distributor import LockDistributor, LockChain
+from samcli.lib.sync.exceptions import MissingLockException, MissingPhysicalResourceError
+
+if TYPE_CHECKING:  # pragma: no cover
+    from samcli.commands.deploy.deploy_context import DeployContext
+    from samcli.commands.build.build_context import BuildContext
+
+# Logging with multiple processes is not safe. Use a log queue in the future.
+# https://docs.python.org/3/howto/logging-cookbook.html#:~:text=Although%20logging%20is%20thread%2Dsafe,across%20multiple%20processes%20in%20Python.
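+#
+# A concrete SyncFlow implements the abstract hooks defined below. A minimal
+# hypothetical sketch (the class name and comments are illustrative only):
+#
+#     class MySyncFlow(SyncFlow):
+#         def gather_resources(self):         ...  # build local artifacts
+#         def compare_remote(self):           ...  # returning True skips the sync
+#         def sync(self):                     ...  # push changes to remote
+#         def gather_dependencies(self):      ...  # follow-up SyncFlows to queue
+#         def _get_resource_api_calls(self):  ...  # locks for shared resources
+#         def _equality_keys(self):           ...  # identity used for dedup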
+LOG = logging.getLogger(__name__)
+
+
+class ResourceAPICall(NamedTuple):
+    """Named tuple for a resource and its potential API calls"""
+
+    resource_identifier: str
+    api_calls: List[str]
+
+
+class SyncFlow(ABC):
+    """Base class for a SyncFlow"""
+
+    _log_name: str
+    _build_context: "BuildContext"
+    _deploy_context: "DeployContext"
+    _stacks: Optional[List[Stack]]
+    _session: Optional[Session]
+    _physical_id_mapping: Dict[str, str]
+    _locks: Optional[Dict[str, Lock]]
+
+    def __init__(
+        self,
+        build_context: "BuildContext",
+        deploy_context: "DeployContext",
+        physical_id_mapping: Dict[str, str],
+        log_name: str,
+        stacks: Optional[List[Stack]] = None,
+    ):
+        """
+        Parameters
+        ----------
+        build_context : BuildContext
+            BuildContext used for build related parameters
+        deploy_context : DeployContext
+            DeployContext used for deploy related parameters
+        physical_id_mapping : Dict[str, str]
+            Mapping between resource logical identifier and physical identifier
+        log_name : str
+            Name to be used for logging purposes
+        stacks : List[Stack], optional
+            List of stacks containing a root stack and optional nested stacks
+        """
+        self._build_context = build_context
+        self._deploy_context = deploy_context
+        self._log_name = log_name
+        self._stacks = stacks
+        self._session = None
+        self._physical_id_mapping = physical_id_mapping
+        self._locks = None
+
+    def set_up(self) -> None:
+        """Clients and other expensive setup should be handled here instead of in the constructor"""
+        self._session = Session(profile_name=self._deploy_context.profile, region_name=self._deploy_context.region)
+
+    @abstractmethod
+    def gather_resources(self) -> None:
+        """Local operations that need to be done before comparison and syncing with remote
+        Ex: Building lambda functions
+        """
+        raise NotImplementedError("gather_resources")
+
+    @abstractmethod
+    def compare_remote(self) -> bool:
+        """Comparison between local and remote resources.
+        This can be used for optimization if comparison is a lot faster than sync.
+        If the resources are identical, sync and gather dependencies will be skipped.
+        Simply return False if there is no comparison needed.
+        Ex: Comparing local Lambda function artifact with remote SHA256
+
+        Returns
+        -------
+        bool
+            Return True if local and remote are in sync, skipping the rest of the execution.
+            Return False otherwise.
+        """
+        raise NotImplementedError("compare_remote")
+
+    @abstractmethod
+    def sync(self) -> None:
+        """Step that syncs local resources with remote.
+        Ex: Call UpdateFunctionCode for Lambda Functions
+        """
+        raise NotImplementedError("sync")
+
+    @abstractmethod
+    def gather_dependencies(self) -> List["SyncFlow"]:
+        """Gather a list of SyncFlows that should be executed after the current change.
+        This can be sync flows for other resources that depend on the current one.
+        Ex: Update Lambda functions if a layer sync flow creates a new version.
+
+        Returns
+        -------
+        List[SyncFlow]
+            List of sync flows that need to be executed after the current one finishes.
+        """
+        raise NotImplementedError("gather_dependencies")
+
+    @abstractmethod
+    def _get_resource_api_calls(self) -> List[ResourceAPICall]:
+        """Get resources and their associated API calls. This is used for locking purposes.
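+        For example, ZipFunctionSyncFlow returns one ResourceAPICall per referenced
+        layer with the "Build" call, so concurrent builds of a shared layer are
+        serialized through the lock distributor.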
+
+        Returns
+        -------
+        List[ResourceAPICall]
+            List of ResourceAPICall tuples, each pairing a resource logical ID with
+            the API calls that the resource can make
+        """
+        raise NotImplementedError("_get_resource_api_calls")
+
+    def get_lock_keys(self) -> List[str]:
+        """Get a list of resource + API call keys that can be used as keys for LockDistributor
+
+        Returns
+        -------
+        List[str]
+            List of keys for all resources and their API calls
+        """
+        lock_keys = list()
+        for resource_api_calls in self._get_resource_api_calls():
+            for api_call in resource_api_calls.api_calls:
+                lock_keys.append(SyncFlow._get_lock_key(resource_api_calls.resource_identifier, api_call))
+        return lock_keys
+
+    def set_locks_with_distributor(self, distributor: LockDistributor):
+        """Set locks to be used with a LockDistributor. Keys should be generated using get_lock_keys().
+
+        Parameters
+        ----------
+        distributor : LockDistributor
+            Lock distributor
+        """
+        self.set_locks_with_dict(distributor.get_locks(self.get_lock_keys()))
+
+    def set_locks_with_dict(self, locks: Dict[str, Lock]):
+        """Set locks to be used. Keys should be generated using get_lock_keys().
+
+        Parameters
+        ----------
+        locks : Dict[str, Lock]
+            Dict of locks with keys from get_lock_keys()
+        """
+        self._locks = locks
+
+    @staticmethod
+    def _get_lock_key(logical_id: str, api_call: str) -> str:
+        """Get a single lock key for a pair of resource and API call.
+
+        Parameters
+        ----------
+        logical_id : str
+            Logical ID of a resource.
+        api_call : str
+            API call the resource will use.
+
+        Returns
+        -------
+        str
+            String key created with logical ID and API call name.
+        """
+        return logical_id + "_" + api_call
+
+    def _get_lock_chain(self) -> LockChain:
+        """Return a LockChain object for all the locks
+
+        Returns
+        -------
+        LockChain
+            A LockChain object containing all locks
+
+        Raises
+        ------
+        MissingLockException
+            If locks have not been set with set_locks_with_distributor() or set_locks_with_dict()
+        """
+        if self._locks:
+            return LockChain(self._locks)
+        raise MissingLockException("Missing Locks for LockChain")
+
+    def _get_resource(self, resource_identifier: str) -> Optional[Dict[str, Any]]:
+        """Get a resource dict with resource identifier
+
+        Parameters
+        ----------
+        resource_identifier : str
+            Resource identifier
+
+        Returns
+        -------
+        Optional[Dict[str, Any]]
+            Resource dict containing its template fields.
+        """
+        return get_resource_by_id(self._stacks, ResourceIdentifier(resource_identifier)) if self._stacks else None
+
+    def get_physical_id(self, resource_identifier: str) -> str:
+        """Get the physical ID of a resource using physical_id_mapping. This does not directly check with remote.
+
+        Parameters
+        ----------
+        resource_identifier : str
+            Resource identifier
+
+        Returns
+        -------
+        str
+            Resource physical ID
+
+        Raises
+        ------
+        MissingPhysicalResourceError
+            Resource does not exist in the physical ID mapping.
+            This could mean remote and local templates are not in sync.
+        """
+        physical_id = self._physical_id_mapping.get(resource_identifier)
+        if not physical_id:
+            raise MissingPhysicalResourceError(resource_identifier)
+
+        return physical_id
+
+    @abstractmethod
+    def _equality_keys(self) -> Any:
+        """This method needs to be overridden to distinguish between multiple instances of SyncFlows.
+        If the return values of two instances are the same, then those two instances will be assumed to be equal.
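+        For example, AbstractLayerSyncFlow returns its layer identifier, so two sync
+        flows created for the same layer compare equal and are deduplicated in queues.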
+
+        Returns
+        -------
+        Any
+            Anything that can be hashed and compared with "=="
+        """
+        raise NotImplementedError("_equality_keys is not implemented.")
+
+    def __hash__(self) -> int:
+        return hash((type(self), self._equality_keys()))
+
+    def __eq__(self, o: object) -> bool:
+        if type(o) is not type(self):
+            return False
+        return cast(bool, self._equality_keys() == cast(SyncFlow, o)._equality_keys())
+
+    @property
+    def log_name(self) -> str:
+        """
+        Returns
+        -------
+        str
+            Human readable name/identifier for logging purposes
+        """
+        return self._log_name
+
+    @property
+    def log_prefix(self) -> str:
+        """
+        Returns
+        -------
+        str
+            Log prefix to be used for logging.
+        """
+        return f"SyncFlow [{self.log_name}]: "
+
+    def execute(self) -> List["SyncFlow"]:
+        """Execute the sync flow and return a list of dependent sync flows.
+        Skips sync() and gather_dependencies() if compare_remote() returns True.
+
+        Returns
+        -------
+        List[SyncFlow]
+            A list of dependent sync flows
+        """
+        dependencies: List["SyncFlow"] = list()
+        LOG.debug("%sSetting Up", self.log_prefix)
+        self.set_up()
+        LOG.debug("%sGathering Resources", self.log_prefix)
+        self.gather_resources()
+        LOG.debug("%sComparing with Remote", self.log_prefix)
+        if not self.compare_remote():
+            LOG.debug("%sSyncing", self.log_prefix)
+            self.sync()
+            LOG.debug("%sGathering Dependencies", self.log_prefix)
+            dependencies = self.gather_dependencies()
+        LOG.debug("%sFinished", self.log_prefix)
+        return dependencies
diff --git a/samcli/lib/sync/sync_flow_executor.py b/samcli/lib/sync/sync_flow_executor.py
new file mode 100644
index 0000000000..0cba6305cd
--- /dev/null
+++ b/samcli/lib/sync/sync_flow_executor.py
@@ -0,0 +1,342 @@
+"""Executor for SyncFlows"""
+import logging
+import time
+
+from queue import Queue
+from typing import Callable, List, Optional, Set
+from dataclasses import dataclass
+
+from threading import RLock
+from concurrent.futures import ThreadPoolExecutor, Future
+
+from botocore.exceptions import ClientError
+
+from samcli.lib.utils.colors import Colored
+from samcli.lib.providers.exceptions import MissingLocalDefinition
+from samcli.lib.sync.exceptions import (
+    InfraSyncRequiredError,
+    MissingPhysicalResourceError,
+    NoLayerVersionsFoundError,
+    SyncFlowException,
+    MissingFunctionBuildDefinition,
+    InvalidRuntimeDefinitionForFunction,
+)
+
+from samcli.lib.utils.lock_distributor import LockDistributor, LockDistributorType
+from samcli.lib.sync.sync_flow import SyncFlow
+
+LOG = logging.getLogger(__name__)
+
+HELP_TEXT_FOR_SYNC_INFRA = " Try sam sync without --code or sam deploy."
+
+
+@dataclass(frozen=True, eq=True)
+class SyncFlowTask:
+    """Data struct for individual SyncFlow execution tasks"""
+
+    # SyncFlow to be executed
+    sync_flow: SyncFlow
+
+    # Should this task be ignored if there is a sync flow in the queue that's the same
+    dedup: bool
+
+
+@dataclass(frozen=True, eq=True)
+class SyncFlowResult:
+    """Data struct for SyncFlow results"""
+
+    sync_flow: SyncFlow
+    dependent_sync_flows: List[SyncFlow]
+
+
+@dataclass(frozen=True, eq=True)
+class SyncFlowFuture:
+    """Data struct for SyncFlow futures"""
+
+    sync_flow: SyncFlow
+    future: Future
+
+
+def default_exception_handler(sync_flow_exception: SyncFlowException) -> None:
+    """Default exception handler for SyncFlowExecutor.
+    This will try to log and parse common SyncFlow exceptions.
+
+    Parameters
+    ----------
+    sync_flow_exception : SyncFlowException
+        SyncFlowException containing the exception to be handled and the SyncFlow that raised it
+
+    Raises
+    ------
+    Exception
+        The original exception, re-raised if it is not one of the handled types
+    """
+    exception = sync_flow_exception.exception
+    if isinstance(exception, MissingPhysicalResourceError):
+        LOG.error("Cannot find resource %s in remote.%s", exception.resource_identifier, HELP_TEXT_FOR_SYNC_INFRA)
+    elif isinstance(exception, InfraSyncRequiredError):
+        LOG.error(
+            "Cannot perform code sync for %s due to: %s.%s",
+            exception.resource_identifier,
+            exception.reason,
+            HELP_TEXT_FOR_SYNC_INFRA,
+        )
+    elif (
+        isinstance(exception, ClientError)
+        and exception.response.get("Error", dict()).get("Code", "") == "ResourceNotFoundException"
+    ):
+        LOG.error("Cannot find resource in remote.%s", HELP_TEXT_FOR_SYNC_INFRA)
+        LOG.error(exception.response.get("Error", dict()).get("Message", ""))
+    elif isinstance(exception, NoLayerVersionsFoundError):
+        LOG.error("Cannot find any versions for layer %s.%s", exception.layer_name_arn, HELP_TEXT_FOR_SYNC_INFRA)
+    elif isinstance(exception, MissingFunctionBuildDefinition):
+        LOG.error(
+            "Cannot find build definition for function %s.%s", exception.function_logical_id, HELP_TEXT_FOR_SYNC_INFRA
+        )
+    elif isinstance(exception, InvalidRuntimeDefinitionForFunction):
+        LOG.error("No Runtime information found for function resource named %s", exception.function_logical_id)
+    elif isinstance(exception, MissingLocalDefinition):
+        LOG.error(
+            "Resource %s does not have %s specified. Skipping the sync.%s",
+            exception.resource_identifier,
+            exception.property_name,
+            HELP_TEXT_FOR_SYNC_INFRA,
+        )
+    else:
+        raise exception
+
+
+class SyncFlowExecutor:
+    """Executor for SyncFlows
+    Can be used with ThreadPoolExecutor or ProcessPoolExecutor with/without manager
+    """
+
+    _flow_queue: Queue
+    _flow_queue_lock: RLock
+    _lock_distributor: LockDistributor
+    _running_flag: bool
+    _color: Colored
+    _running_futures: Set[SyncFlowFuture]
+
+    def __init__(
+        self,
+    ) -> None:
+        self._flow_queue = Queue()
+        self._lock_distributor = LockDistributor(LockDistributorType.THREAD)
+        self._running_flag = False
+        self._flow_queue_lock = RLock()
+        self._color = Colored()
+        self._running_futures = set()
+
+    def _add_sync_flow_task(self, task: SyncFlowTask) -> None:
+        """Add SyncFlowTask to the queue
+
+        Parameters
+        ----------
+        task : SyncFlowTask
+            SyncFlowTask to be added.
+        """
+        # Lock flow_queue as checking for duplicates and adding is not atomic
+        with self._flow_queue_lock:
+            if task.dedup and task.sync_flow in [queued_task.sync_flow for queued_task in self._flow_queue.queue]:
+                LOG.debug("Found the same SyncFlow in queue. 
Skip adding.")
+                return
+
+            task.sync_flow.set_locks_with_distributor(self._lock_distributor)
+            self._flow_queue.put(task)
+
+    def add_sync_flow(self, sync_flow: SyncFlow, dedup: bool = True) -> None:
+        """Add a SyncFlow to the queue to be executed.
+        Locks will be set with the LockDistributor.
+
+        Parameters
+        ----------
+        sync_flow : SyncFlow
+            SyncFlow to be executed
+        dedup : bool
+            SyncFlow will not be added if this flag is True and a duplicate is already in the queue
+        """
+        self._add_sync_flow_task(SyncFlowTask(sync_flow, dedup))
+
+    def is_running(self) -> bool:
+        """
+        Returns
+        -------
+        bool
+            Is executor running
+        """
+        return self._running_flag
+
+    def _can_exit(self) -> bool:
+        """
+        Returns
+        -------
+        bool
+            Can executor be safely exited
+        """
+        return not self._running_futures and self._flow_queue.empty()
+
+    def execute(
+        self, exception_handler: Optional[Callable[[SyncFlowException], None]] = default_exception_handler
+    ) -> None:
+        """Blocking execution of the SyncFlows
+
+        Parameters
+        ----------
+        exception_handler : Optional[Callable[[SyncFlowException], None]], optional
+            Function to be called if an exception is raised during the execution of a SyncFlow,
+            by default default_exception_handler
+        """
+        self._running_flag = True
+        with ThreadPoolExecutor() as executor:
+            self._running_futures.clear()
+            while True:
+
+                self._execute_step(executor, exception_handler)
+
+                # Exit execution if there are no running and pending sync flows
+                if self._can_exit():
+                    LOG.debug("No more SyncFlows in executor. Stopping.")
+                    break
+
+                # Sleep for a bit to cut down CPU utilization of this busy wait loop
+                time.sleep(0.1)
+        self._running_flag = False
+
+    def _execute_step(
+        self,
+        executor: ThreadPoolExecutor,
+        exception_handler: Optional[Callable[[SyncFlowException], None]],
+    ) -> None:
+        """A single step in the execution flow
+
+        Parameters
+        ----------
+        executor : ThreadPoolExecutor
+            ThreadPoolExecutor to be used for execution
+        exception_handler : Optional[Callable[[SyncFlowException], None]]
+            Exception handler
+        """
+        # Execute all pending sync flows
+        with self._flow_queue_lock:
+            # Putting non-submitted tasks into this deferred tasks list
+            # to avoid modifying the queue while emptying it
+            deferred_tasks = list()
+
+            # Go through all queued tasks and try to execute them
+            while not self._flow_queue.empty():
+                sync_flow_task: SyncFlowTask = self._flow_queue.get()
+
+                sync_flow_future = self._submit_sync_flow_task(executor, sync_flow_task)
+
+                # sync_flow_future can be None if the task cannot be submitted currently
+                # Put it into deferred_tasks and add all of them at the end to avoid an endless loop
+                if sync_flow_future:
+                    self._running_futures.add(sync_flow_future)
+                    LOG.info(self._color.cyan(f"Syncing {sync_flow_future.sync_flow.log_name}..."))
+                else:
+                    deferred_tasks.append(sync_flow_task)
+
+            # Add back the tasks that cannot be executed yet
+            for task in deferred_tasks:
+                self._add_sync_flow_task(task)
+
+        # Check for finished sync flows
+        for sync_flow_future in set(self._running_futures):
+            if self._handle_result(sync_flow_future, exception_handler):
+                self._running_futures.remove(sync_flow_future)
+
+    def _submit_sync_flow_task(
+        self, executor: ThreadPoolExecutor, sync_flow_task: SyncFlowTask
+    ) -> Optional[SyncFlowFuture]:
+        """Submit a SyncFlowTask to be executed by the ThreadPoolExecutor
+        and return its future
+
+        Parameters
+        ----------
+        executor : ThreadPoolExecutor
+            ThreadPoolExecutor to be used for execution
+        sync_flow_task : SyncFlowTask
+            SyncFlowTask to be executed.
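+            A task whose SyncFlow is already running is not submitted; _execute_step
+            defers it and retries it on a later step.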
+ + Returns + ------- + Optional[SyncFlowFuture] + Returns SyncFlowFuture generated by the SyncFlowTask. + Can be None if the task cannot be executed yet. + """ + sync_flow = sync_flow_task.sync_flow + + # Check whether the same sync flow is already running or not + if sync_flow in [future.sync_flow for future in self._running_futures]: + return None + + sync_flow_future = SyncFlowFuture( + sync_flow=sync_flow, future=executor.submit(SyncFlowExecutor._sync_flow_execute_wrapper, sync_flow) + ) + + return sync_flow_future + + def _handle_result( + self, sync_flow_future: SyncFlowFuture, exception_handler: Optional[Callable[[SyncFlowException], None]] + ) -> bool: + """Checks and handles the result of a SyncFlowFuture + + Parameters + ---------- + sync_flow_future : SyncFlowFuture + The SyncFlowFuture that needs to be handled + exception_handler : Optional[Callable[[SyncFlowException], None]] + Exception handler that will be called if an exception is raised within the SyncFlow + + Returns + ------- + bool + Returns True if the SyncFlowFuture was finished and successfully handled, False otherwise. + """ + future = sync_flow_future.future + + if not future.done(): + return False + + exception = future.exception() + + if exception and isinstance(exception, SyncFlowException) and exception_handler: + # Exception handling + exception_handler(exception) + else: + # Add dependency sync flows to queue + sync_flow_result: SyncFlowResult = future.result() + for dependent_sync_flow in sync_flow_result.dependent_sync_flows: + self.add_sync_flow(dependent_sync_flow) + LOG.info(self._color.green(f"Finished syncing {sync_flow_result.sync_flow.log_name}.")) + return True + + @staticmethod + def _sync_flow_execute_wrapper(sync_flow: SyncFlow) -> SyncFlowResult: + """Simple wrapper method for executing SyncFlow and converting all Exceptions into SyncFlowException + + Parameters + ---------- + sync_flow : SyncFlow + SyncFlow to be executed + + Returns + ------- + SyncFlowResult + SyncFlowResult for the SyncFlow executed + + Raises + ------ + SyncFlowException + """ + dependent_sync_flows = [] + try: + dependent_sync_flows = sync_flow.execute() + except ClientError as e: + if e.response.get("Error", dict()).get("Code", "") == "ResourceNotFoundException": + raise SyncFlowException(sync_flow, MissingPhysicalResourceError()) from e + raise SyncFlowException(sync_flow, e) from e + except Exception as e: + raise SyncFlowException(sync_flow, e) from e + return SyncFlowResult(sync_flow=sync_flow, dependent_sync_flows=dependent_sync_flows) diff --git a/samcli/lib/sync/sync_flow_factory.py b/samcli/lib/sync/sync_flow_factory.py new file mode 100644 index 0000000000..0da6608d1c --- /dev/null +++ b/samcli/lib/sync/sync_flow_factory.py @@ -0,0 +1,181 @@ +"""SyncFlow Factory for creating SyncFlows based on resource types""" +import logging +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, cast + +from samcli.lib.bootstrap.nested_stack.nested_stack_manager import NestedStackManager +from samcli.lib.providers.provider import Stack, get_resource_by_id, ResourceIdentifier +from samcli.lib.sync.flows.auto_dependency_layer_sync_flow import AutoDependencyLayerParentSyncFlow +from samcli.lib.sync.flows.layer_sync_flow import LayerSyncFlow +from samcli.lib.utils.packagetype import ZIP, IMAGE +from samcli.lib.utils.resource_type_based_factory import ResourceTypeBasedFactory + +from samcli.lib.sync.sync_flow import SyncFlow +from samcli.lib.sync.flows.function_sync_flow import FunctionSyncFlow +from 
samcli.lib.sync.flows.zip_function_sync_flow import ZipFunctionSyncFlow
+from samcli.lib.sync.flows.image_function_sync_flow import ImageFunctionSyncFlow
+from samcli.lib.sync.flows.rest_api_sync_flow import RestApiSyncFlow
+from samcli.lib.sync.flows.http_api_sync_flow import HttpApiSyncFlow
+from samcli.lib.sync.flows.stepfunctions_sync_flow import StepFunctionsSyncFlow
+from samcli.lib.utils.boto_utils import get_boto_resource_provider_with_config
+from samcli.lib.utils.cloudformation import get_physical_id_mapping
+from samcli.lib.utils.resources import (
+    AWS_SERVERLESS_FUNCTION,
+    AWS_LAMBDA_FUNCTION,
+    AWS_SERVERLESS_LAYERVERSION,
+    AWS_LAMBDA_LAYERVERSION,
+    AWS_SERVERLESS_API,
+    AWS_APIGATEWAY_RESTAPI,
+    AWS_SERVERLESS_HTTPAPI,
+    AWS_APIGATEWAY_V2_API,
+    AWS_SERVERLESS_STATEMACHINE,
+    AWS_STEPFUNCTIONS_STATEMACHINE,
+)
+
+if TYPE_CHECKING:  # pragma: no cover
+    from samcli.commands.deploy.deploy_context import DeployContext
+    from samcli.commands.build.build_context import BuildContext
+
+LOG = logging.getLogger(__name__)
+
+
+class SyncFlowFactory(ResourceTypeBasedFactory[SyncFlow]):  # pylint: disable=E1136
+    """Factory class for SyncFlow
+    Creates appropriate SyncFlow types based on stack resource types
+    """
+
+    _deploy_context: "DeployContext"
+    _build_context: "BuildContext"
+    _physical_id_mapping: Dict[str, str]
+    _auto_dependency_layer: bool
+
+    def __init__(
+        self,
+        build_context: "BuildContext",
+        deploy_context: "DeployContext",
+        stacks: List[Stack],
+        auto_dependency_layer: bool,
+    ) -> None:
+        """
+        Parameters
+        ----------
+        build_context : BuildContext
+            BuildContext to be passed into each individual SyncFlow
+        deploy_context : DeployContext
+            DeployContext to be passed into each individual SyncFlow
+        stacks : List[Stack]
+            List of stacks containing a root stack and optional nested ones
+        auto_dependency_layer : bool
+            Whether the automatic dependency layer should be used for supported ZIP function runtimes
+        """
+        super().__init__(stacks)
+        self._deploy_context = deploy_context
+        self._build_context = build_context
+        self._auto_dependency_layer = auto_dependency_layer
+        self._physical_id_mapping = dict()
+
+    def load_physical_id_mapping(self) -> None:
+        """Load physical IDs of the stack resources from remote"""
+        LOG.debug("Loading physical ID mapping")
+        self._physical_id_mapping = get_physical_id_mapping(
+            get_boto_resource_provider_with_config(
+                region=self._deploy_context.region,
+                profile=self._deploy_context.profile,
+            ),
+            self._deploy_context.stack_name,
+        )
+
+    def _create_lambda_flow(
+        self, resource_identifier: ResourceIdentifier, resource: Dict[str, Any]
+    ) -> Optional[FunctionSyncFlow]:
+        resource_properties = resource.get("Properties", dict())
+        package_type = resource_properties.get("PackageType", ZIP)
+        runtime = resource_properties.get("Runtime")
+        if package_type == ZIP:
+            # only return the auto dependency layer sync flow if the runtime is supported
+            if self._auto_dependency_layer and NestedStackManager.is_runtime_supported(runtime):
+                return AutoDependencyLayerParentSyncFlow(
+                    str(resource_identifier),
+                    self._build_context,
+                    self._deploy_context,
+                    self._physical_id_mapping,
+                    self._stacks,
+                )
+
+            return ZipFunctionSyncFlow(
+                str(resource_identifier),
+                self._build_context,
+                self._deploy_context,
+                self._physical_id_mapping,
+                self._stacks,
+            )
+        if package_type == IMAGE:
+            return ImageFunctionSyncFlow(
+                str(resource_identifier),
+                self._build_context,
+                self._deploy_context,
+                self._physical_id_mapping,
+                self._stacks,
+            )
+        return None
+
+    def _create_layer_flow(self, resource_identifier: ResourceIdentifier, resource: Dict[str, Any]) -> SyncFlow:
+        return
LayerSyncFlow( + str(resource_identifier), + self._build_context, + self._deploy_context, + self._physical_id_mapping, + self._stacks, + ) + + def _create_rest_api_flow(self, resource_identifier: ResourceIdentifier, resource: Dict[str, Any]) -> SyncFlow: + return RestApiSyncFlow( + str(resource_identifier), + self._build_context, + self._deploy_context, + self._physical_id_mapping, + self._stacks, + ) + + def _create_api_flow(self, resource_identifier: ResourceIdentifier, resource: Dict[str, Any]) -> SyncFlow: + return HttpApiSyncFlow( + str(resource_identifier), + self._build_context, + self._deploy_context, + self._physical_id_mapping, + self._stacks, + ) + + def _create_stepfunctions_flow( + self, resource_identifier: ResourceIdentifier, resource: Dict[str, Any] + ) -> Optional[SyncFlow]: + return StepFunctionsSyncFlow( + str(resource_identifier), + self._build_context, + self._deploy_context, + self._physical_id_mapping, + self._stacks, + ) + + GeneratorFunction = Callable[["SyncFlowFactory", ResourceIdentifier, Dict[str, Any]], Optional[SyncFlow]] + GENERATOR_MAPPING: Dict[str, GeneratorFunction] = { + AWS_LAMBDA_FUNCTION: _create_lambda_flow, + AWS_SERVERLESS_FUNCTION: _create_lambda_flow, + AWS_SERVERLESS_LAYERVERSION: _create_layer_flow, + AWS_LAMBDA_LAYERVERSION: _create_layer_flow, + AWS_SERVERLESS_API: _create_rest_api_flow, + AWS_APIGATEWAY_RESTAPI: _create_rest_api_flow, + AWS_SERVERLESS_HTTPAPI: _create_api_flow, + AWS_APIGATEWAY_V2_API: _create_api_flow, + AWS_SERVERLESS_STATEMACHINE: _create_stepfunctions_flow, + AWS_STEPFUNCTIONS_STATEMACHINE: _create_stepfunctions_flow, + } + + # SyncFlow mapping between resource type and creation function + # Ignoring no-self-use as PyLint has a bug with Generic Abstract Classes + def _get_generator_mapping(self) -> Dict[str, GeneratorFunction]: # pylint: disable=no-self-use + return SyncFlowFactory.GENERATOR_MAPPING + + def create_sync_flow(self, resource_identifier: ResourceIdentifier) -> Optional[SyncFlow]: + resource = get_resource_by_id(self._stacks, resource_identifier) + generator = self._get_generator_function(resource_identifier) + if not generator or not resource: + return None + return cast(SyncFlowFactory.GeneratorFunction, generator)(self, resource_identifier, resource) diff --git a/samcli/lib/sync/watch_manager.py b/samcli/lib/sync/watch_manager.py new file mode 100644 index 0000000000..41d68eb93a --- /dev/null +++ b/samcli/lib/sync/watch_manager.py @@ -0,0 +1,241 @@ +""" +WatchManager for Sync Watch Logic +""" +import logging +import time +import threading + +from typing import List, Optional, TYPE_CHECKING + +from samcli.lib.utils.colors import Colored +from samcli.lib.providers.exceptions import MissingCodeUri, MissingLocalDefinition + +from samcli.lib.providers.provider import ResourceIdentifier, Stack, get_all_resource_ids +from samcli.lib.utils.code_trigger_factory import CodeTriggerFactory +from samcli.lib.providers.sam_stack_provider import SamLocalStackProvider +from samcli.lib.utils.path_observer import HandlerObserver + +from samcli.lib.sync.sync_flow_factory import SyncFlowFactory +from samcli.lib.sync.exceptions import InfraSyncRequiredError, MissingPhysicalResourceError, SyncFlowException +from samcli.lib.utils.resource_trigger import OnChangeCallback, TemplateTrigger +from samcli.lib.sync.continuous_sync_flow_executor import ContinuousSyncFlowExecutor + +if TYPE_CHECKING: # pragma: no cover + from samcli.commands.deploy.deploy_context import DeployContext + from samcli.commands.package.package_context import 
PackageContext
+    from samcli.commands.build.build_context import BuildContext
+
+DEFAULT_WAIT_TIME = 1
+LOG = logging.getLogger(__name__)
+
+
+class WatchManager:
+    _stacks: Optional[List[Stack]]
+    _template: str
+    _build_context: "BuildContext"
+    _package_context: "PackageContext"
+    _deploy_context: "DeployContext"
+    _sync_flow_factory: Optional[SyncFlowFactory]
+    _sync_flow_executor: ContinuousSyncFlowExecutor
+    _executor_thread: Optional[threading.Thread]
+    _observer: HandlerObserver
+    _trigger_factory: Optional[CodeTriggerFactory]
+    _waiting_infra_sync: bool
+    _color: Colored
+    _auto_dependency_layer: bool
+
+    def __init__(
+        self,
+        template: str,
+        build_context: "BuildContext",
+        package_context: "PackageContext",
+        deploy_context: "DeployContext",
+        auto_dependency_layer: bool,
+    ):
+        """Manager for sync watch execution logic.
+        This manager observes the template and its code resources, and
+        automatically executes infra/code syncs when changes are detected.
+
+        Parameters
+        ----------
+        template : str
+            Template file path
+        build_context : BuildContext
+            BuildContext
+        package_context : PackageContext
+            PackageContext
+        deploy_context : DeployContext
+            DeployContext
+        auto_dependency_layer : bool
+            Whether the auto dependency layer feature is enabled for syncs
+        """
+        self._stacks = None
+        self._template = template
+        self._build_context = build_context
+        self._package_context = package_context
+        self._deploy_context = deploy_context
+        self._auto_dependency_layer = auto_dependency_layer
+
+        self._sync_flow_factory = None
+        self._sync_flow_executor = ContinuousSyncFlowExecutor()
+        self._executor_thread = None
+
+        self._observer = HandlerObserver()
+        self._trigger_factory = None
+
+        self._waiting_infra_sync = False
+        self._color = Colored()
+
+    def queue_infra_sync(self) -> None:
+        """Queue up an infrastructure sync.
+        A simple bool flag suffices.
+        """
+        self._waiting_infra_sync = True
+
+    def _update_stacks(self) -> None:
+        """
+        Reloads the template and its stacks.
+        Updates all other members that also depend on the stacks.
+        This should be called whenever there is a change to the template.
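+
+        Illustrative call sequence (a sketch of how _execute_infra_sync uses
+        this method after a successful infra sync; not additional API)::
+
+            self._observer.unschedule_all()
+            self._update_stacks()          # reload stacks, factories, physical IDs
+            self._add_template_trigger()
+            self._add_code_triggers()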
+ """ + self._stacks = SamLocalStackProvider.get_stacks(self._template)[0] + self._sync_flow_factory = SyncFlowFactory( + self._build_context, self._deploy_context, self._stacks, self._auto_dependency_layer + ) + self._sync_flow_factory.load_physical_id_mapping() + self._trigger_factory = CodeTriggerFactory(self._stacks) + + def _add_code_triggers(self) -> None: + """Create CodeResourceTrigger for all resources and add their handlers to observer""" + if not self._stacks or not self._trigger_factory: + return + resource_ids = get_all_resource_ids(self._stacks) + for resource_id in resource_ids: + try: + trigger = self._trigger_factory.create_trigger(resource_id, self._on_code_change_wrapper(resource_id)) + except (MissingCodeUri, MissingLocalDefinition): + LOG.debug("CodeTrigger not created as CodeUri or DefinitionUri is missing for %s.", str(resource_id)) + continue + + if not trigger: + continue + self._observer.schedule_handlers(trigger.get_path_handlers()) + + def _add_template_trigger(self) -> None: + """Create TemplateTrigger and add its handlers to observer""" + template_trigger = TemplateTrigger(self._template, lambda _=None: self.queue_infra_sync()) + self._observer.schedule_handlers(template_trigger.get_path_handlers()) + + def _execute_infra_context(self) -> None: + """Execute infrastructure sync""" + self._build_context.set_up() + self._build_context.run() + self._package_context.run() + self._deploy_context.run() + + def _start_code_sync(self) -> None: + """Start SyncFlowExecutor in a separate thread.""" + if not self._executor_thread or not self._executor_thread.is_alive(): + self._executor_thread = threading.Thread( + target=lambda: self._sync_flow_executor.execute( + exception_handler=self._watch_sync_flow_exception_handler + ) + ) + self._executor_thread.start() + + def _stop_code_sync(self) -> None: + """Blocking call that stops SyncFlowExecutor and waits for it to finish.""" + if self._executor_thread and self._executor_thread.is_alive(): + self._sync_flow_executor.stop() + self._executor_thread.join() + + def start(self) -> None: + """Start WatchManager and watch for changes to the template and its code resources.""" + + # The actual execution is done in _start() + # This is a wrapper for gracefully handling Ctrl+C or other termination cases. + try: + self.queue_infra_sync() + self._start() + except KeyboardInterrupt: + LOG.info(self._color.cyan("Shutting down sync watch...")) + self._observer.stop() + self._stop_code_sync() + LOG.info(self._color.green("Sync watch stopped.")) + + def _start(self) -> None: + """Start WatchManager and watch for changes to the template and its code resources.""" + self._observer.start() + while True: + if self._waiting_infra_sync: + self._execute_infra_sync() + time.sleep(1) + + def _execute_infra_sync(self) -> None: + LOG.info(self._color.cyan("Queued infra sync. Wating for in progress code syncs to complete...")) + self._waiting_infra_sync = False + self._stop_code_sync() + try: + LOG.info(self._color.cyan("Starting infra sync.")) + self._execute_infra_context() + except Exception as e: + LOG.error( + self._color.red("Failed to sync infra. Code sync is paused until template/stack is fixed."), + exc_info=e, + ) + # Unschedule all triggers and only add back the template one as infra sync is incorrect. + self._observer.unschedule_all() + self._add_template_trigger() + else: + # Update stacks and repopulate triggers + # Trigger are not removed until infra sync is finished as there + # can be code changes during infra sync. 
+            self._observer.unschedule_all()
+            self._update_stacks()
+            self._add_template_trigger()
+            self._add_code_triggers()
+            self._start_code_sync()
+        LOG.info(self._color.green("Infra sync completed."))
+
+    def _on_code_change_wrapper(self, resource_id: ResourceIdentifier) -> OnChangeCallback:
+        """Wrapper method that generates a callback for code changes.
+
+        Parameters
+        ----------
+        resource_id : ResourceIdentifier
+            Resource associated with the callback
+
+        Returns
+        -------
+        OnChangeCallback
+            Callback function
+        """
+
+        def on_code_change(_=None):
+            sync_flow = self._sync_flow_factory.create_sync_flow(resource_id)
+            if sync_flow and not self._waiting_infra_sync:
+                self._sync_flow_executor.add_delayed_sync_flow(sync_flow, dedup=True, wait_time=DEFAULT_WAIT_TIME)
+
+        return on_code_change
+
+    def _watch_sync_flow_exception_handler(self, sync_flow_exception: SyncFlowException) -> None:
+        """Exception handler for watch.
+        Simply logs unhandled exceptions instead of failing the entire process.
+
+        Parameters
+        ----------
+        sync_flow_exception : SyncFlowException
+            SyncFlowException
+        """
+        exception = sync_flow_exception.exception
+        if isinstance(exception, MissingPhysicalResourceError):
+            LOG.warning(self._color.yellow("Missing physical resource. Infra sync will be started."))
+            self.queue_infra_sync()
+        elif isinstance(exception, InfraSyncRequiredError):
+            LOG.warning(
+                self._color.yellow(
+                    f"Infra sync is required for {exception.resource_identifier} due to: "
+                    + f"{exception.reason}. Infra sync will be started."
+                )
+            )
+            self.queue_infra_sync()
+        else:
+            LOG.error(self._color.red("Code sync encountered an error."), exc_info=exception)
diff --git a/samcli/lib/telemetry/metric.py b/samcli/lib/telemetry/metric.py
index f97b9f0e01..75076c5f1a 100644
--- a/samcli/lib/telemetry/metric.py
+++ b/samcli/lib/telemetry/metric.py
@@ -17,6 +17,7 @@ from samcli.lib.warnings.sam_cli_warning import TemplateWarningsChecker
 from samcli.commands.exceptions import UserException
 from samcli.lib.telemetry.cicd import CICDDetector, CICDPlatform
+from samcli.commands._utils.experimental import get_all_experimental_statues
 from .telemetry import Telemetry
 
 LOG = logging.getLogger(__name__)
@@ -25,7 +26,7 @@
 """
 Global variables are evil but this is a justified usage.
-This creates a versitile telemetry tracking no matter where in the code. Something like a Logger.
+This creates a versatile telemetry tracking no matter where in the code. Something like a Logger.
 No side effect will result in this as it is write-only for code outside of telemetry.
 Decorators should be used to minimize logic involving telemetry.
 """
@@ -107,7 +108,6 @@ def hello_command():
     def wrapped(*args, **kwargs):
         telemetry = Telemetry()
-        metric = Metric("commandRun")
         exception = None
         return_value = None
@@ -138,10 +138,14 @@ def wrapped(*args, **kwargs):
 
         try:
             ctx = Context.get_current_context()
+            metric_name = "commandRunExperimental" if ctx.experimental else "commandRun"
+            metric = Metric(metric_name)
             metric.add_data("awsProfileProvided", bool(ctx.profile))
             metric.add_data("debugFlagProvided", bool(ctx.debug))
             metric.add_data("region", ctx.region or "")
             metric.add_data("commandName", ctx.command_path)  # Full command path. ex: sam local start-api
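+            # When running with experimental features enabled, attach the
+            # individual experimental feature statuses to the metric as well.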
+            if ctx.experimental:
+                metric.add_data("metricSpecificAttributes", get_all_experimental_statues())
 
             # Metric about command's execution characteristics
             metric.add_data("duration", duration_fn())
             metric.add_data("exitReason", exit_reason)
@@ -238,7 +242,7 @@ def wrapped_func(*args, **kwargs):
 
 def capture_return_value(metric_name, key, as_list=False):
     """
-    Decorator for capturing the reutrn value of the function.
+    Decorator for capturing the return value of the function.
 
     :param metric_name Name of the metric
     :param key Key for storing the captured parameter
diff --git a/samcli/lib/utils/architecture.py b/samcli/lib/utils/architecture.py
index b2b73231f4..d7a05e1761 100644
--- a/samcli/lib/utils/architecture.py
+++ b/samcli/lib/utils/architecture.py
@@ -7,7 +7,7 @@ from samcli.commands.local.lib.exceptions import UnsupportedRuntimeArchitectureError
 from samcli.lib.utils.packagetype import IMAGE
 
-if TYPE_CHECKING:
+if TYPE_CHECKING:  # pragma: no cover
     from samcli.lib.providers.provider import Function
 
 X86_64 = "x86_64"
diff --git a/samcli/lib/utils/boto_utils.py b/samcli/lib/utils/boto_utils.py
new file mode 100644
index 0000000000..6091c8f857
--- /dev/null
+++ b/samcli/lib/utils/boto_utils.py
@@ -0,0 +1,96 @@
+"""
+This module contains utility functions for boto3 library
+"""
+from typing import Any, Optional
+from typing_extensions import Protocol
+
+import boto3
+from botocore.config import Config
+
+from samcli import __version__
+from samcli.cli.global_config import GlobalConfig
+
+
+def get_boto_config_with_user_agent(**kwargs) -> Config:
+    """
+    Automatically add user agent string to boto configs.
+
+    Parameters
+    ----------
+    kwargs :
+        key=value params which will be added to the Config object
+
+    Returns
+    -------
+    Config
+        Returns config instance which contains given parameters in it
+    """
+    gc = GlobalConfig()
+    return Config(
+        user_agent_extra=f"aws-sam-cli/{__version__}/{gc.installation_id}"
+        if gc.telemetry_enabled
+        else f"aws-sam-cli/{__version__}",
+        **kwargs,
+    )
+
+
+# Type definition of following boto providers, which is equal to Callable[[str], Any]
+class BotoProviderType(Protocol):
+    def __call__(self, service_name: str) -> Any:
+        ...
+
+
+def get_boto_client_provider_with_config(
+    region: Optional[str] = None, profile: Optional[str] = None, **kwargs
+) -> BotoProviderType:
+    """
+    Returns a wrapper function for boto client with given configuration. It can be used like:
+
+        client_provider = get_boto_client_provider_with_config(region=region)
+        lambda_client = client_provider("lambda")
+
+    Parameters
+    ----------
+    region: Optional[str]
+        AWS region name
+    profile: Optional[str]
+        Profile name from credentials
+    kwargs :
+        Key-value params that will be passed to get_boto_config_with_user_agent
+
+    Returns
+    -------
+    A callable function which will return a boto client
+    """
+    # ignore typing because mypy tries to assert client_name with a valid service name
+    return lambda client_name: boto3.session.Session(region_name=region, profile_name=profile).client(  # type: ignore
+        client_name, config=get_boto_config_with_user_agent(**kwargs)
+    )
+
+
+def get_boto_resource_provider_with_config(
+    region: Optional[str] = None, profile: Optional[str] = None, **kwargs
+) -> BotoProviderType:
+    """
+    Returns a wrapper function for boto resource with given configuration.
+    It can be used like:
+
+        resource_provider = get_boto_resource_provider_with_config(region=region)
+        cloudformation_resource = resource_provider("cloudformation")
+
+    Parameters
+    ----------
+    region: Optional[str]
+        AWS region name
+    profile: Optional[str]
+        Profile name from credentials
+    kwargs :
+        Key-value params that will be passed to get_boto_config_with_user_agent
+
+    Returns
+    -------
+    A callable function which will return a boto resource
+    """
+    # ignore typing because mypy tries to assert client_name with a valid service name
+    return lambda resource_name: boto3.session.Session(
+        region_name=region, profile_name=profile  # type: ignore
+    ).resource(resource_name, config=get_boto_config_with_user_agent(**kwargs))
diff --git a/samcli/lib/utils/botoconfig.py b/samcli/lib/utils/botoconfig.py
deleted file mode 100644
index 7a7bd6d792..0000000000
--- a/samcli/lib/utils/botoconfig.py
+++ /dev/null
@@ -1,17 +0,0 @@
-"""
-Automatically add user agent string to boto configs.
-"""
-from botocore.config import Config
-
-from samcli import __version__
-from samcli.cli.global_config import GlobalConfig
-
-
-def get_boto_config_with_user_agent(**kwargs):
-    gc = GlobalConfig()
-    return Config(
-        user_agent_extra=f"aws-sam-cli/{__version__}/{gc.installation_id}"
-        if gc.telemetry_enabled
-        else f"aws-sam-cli/{__version__}",
-        **kwargs,
-    )
diff --git a/samcli/lib/utils/cloudformation.py b/samcli/lib/utils/cloudformation.py
new file mode 100644
index 0000000000..b6590550d4
--- /dev/null
+++ b/samcli/lib/utils/cloudformation.py
@@ -0,0 +1,127 @@
+"""
+This utility file contains methods to read information from certain CFN stack
+"""
+import logging
+from typing import List, Dict, NamedTuple, Set, Optional
+
+from botocore.exceptions import ClientError
+
+from samcli.lib.utils.boto_utils import BotoProviderType
+
+LOG = logging.getLogger(__name__)
+
+
+class CloudFormationResourceSummary(NamedTuple):
+    """
+    Keeps information about CFN resource
+    """
+
+    resource_type: str
+    logical_resource_id: str
+    physical_resource_id: str
+
+
+def get_physical_id_mapping(
+    boto_resource_provider: BotoProviderType, stack_name: str, resource_types: Optional[Set[str]] = None
+) -> Dict[str, str]:
+    """
+    Uses the get_resource_summaries method to gather resource summaries and creates a dictionary which contains
+    the logical_id to physical_id mapping
+
+    Parameters
+    ----------
+    boto_resource_provider : BotoProviderType
+        A callable which will return boto3 resource
+    stack_name : str
+        Name of the stack which is deployed to CFN
+    resource_types :
Optional[Set[str]] + List of resource types, which will filter the results + + Returns + ------- + List of CloudFormationResourceSummary which contains information about resources in the given stack + + """ + LOG.debug("Fetching stack (%s) resources", stack_name) + cfn_resource_summaries = boto_resource_provider("cloudformation").Stack(stack_name).resource_summaries.all() + resource_summaries: List[CloudFormationResourceSummary] = [] + + for cfn_resource_summary in cfn_resource_summaries: + resource_summary = CloudFormationResourceSummary( + cfn_resource_summary.resource_type, + cfn_resource_summary.logical_resource_id, + cfn_resource_summary.physical_resource_id, + ) + if resource_types and resource_summary.resource_type not in resource_types: + LOG.debug( + "Skipping resource %s since its type %s is not supported. Supported types %s", + resource_summary.logical_resource_id, + resource_summary.resource_type, + resource_types, + ) + continue + + resource_summaries.append(resource_summary) + + return resource_summaries + + +def get_resource_summary(boto_resource_provider: BotoProviderType, stack_name: str, resource_logical_id: str): + """ + Returns resource summary of given single resource with its logical id + + Parameters + ---------- + boto_resource_provider : BotoProviderType + A callable which will return boto3 resource + stack_name : str + Name of the stack which is deployed to CFN + resource_logical_id : str + Logical ID of the resource that will be returned as resource summary + + Returns + ------- + CloudFormationResourceSummary of the resource which is identified by given logical id + """ + try: + cfn_resource_summary = boto_resource_provider("cloudformation").StackResource(stack_name, resource_logical_id) + + return CloudFormationResourceSummary( + cfn_resource_summary.resource_type, + cfn_resource_summary.logical_resource_id, + cfn_resource_summary.physical_resource_id, + ) + except ClientError as e: + LOG.error( + "Failed to pull resource (%s) information from stack (%s)", resource_logical_id, stack_name, exc_info=e + ) + return None diff --git a/samcli/lib/utils/code_trigger_factory.py b/samcli/lib/utils/code_trigger_factory.py new file mode 100644 index 0000000000..f577ace803 --- /dev/null +++ b/samcli/lib/utils/code_trigger_factory.py @@ -0,0 +1,112 @@ +""" +Factory for creating CodeResourceTriggers +""" +import logging +from typing import Any, Callable, Dict, List, Optional, cast + +from samcli.lib.providers.provider import ResourceIdentifier, Stack, get_resource_by_id +from samcli.lib.utils.packagetype import IMAGE, ZIP +from samcli.lib.utils.resource_trigger import ( + CodeResourceTrigger, + DefinitionCodeTrigger, + LambdaImageCodeTrigger, + LambdaLayerCodeTrigger, + LambdaZipCodeTrigger, +) +from samcli.lib.utils.resource_type_based_factory import ResourceTypeBasedFactory +from samcli.lib.utils.resources import ( + AWS_APIGATEWAY_RESTAPI, + AWS_APIGATEWAY_V2_API, + AWS_LAMBDA_FUNCTION, + AWS_LAMBDA_LAYERVERSION, + AWS_SERVERLESS_API, + AWS_SERVERLESS_FUNCTION, + AWS_SERVERLESS_HTTPAPI, + AWS_SERVERLESS_LAYERVERSION, + AWS_SERVERLESS_STATEMACHINE, + AWS_STEPFUNCTIONS_STATEMACHINE, +) + +LOG = logging.getLogger(__name__) + + +class CodeTriggerFactory(ResourceTypeBasedFactory[CodeResourceTrigger]): # pylint: disable=E1136 + _stacks: List[Stack] + + def _create_lambda_trigger( + self, + resource_identifier: ResourceIdentifier, + resource_type: str, + resource: Dict[str, Any], + on_code_change: Callable, + ): + package_type = resource.get("Properties", 
dict()).get("PackageType", ZIP) + if package_type == ZIP: + return LambdaZipCodeTrigger(resource_identifier, self._stacks, on_code_change) + if package_type == IMAGE: + return LambdaImageCodeTrigger(resource_identifier, self._stacks, on_code_change) + return None + + def _create_layer_trigger( + self, + resource_identifier: ResourceIdentifier, + resource_type: str, + resource: Dict[str, Any], + on_code_change: Callable, + ): + return LambdaLayerCodeTrigger(resource_identifier, self._stacks, on_code_change) + + def _create_definition_code_trigger( + self, + resource_identifier: ResourceIdentifier, + resource_type: str, + resource: Dict[str, Any], + on_code_change: Callable, + ): + return DefinitionCodeTrigger(resource_identifier, resource_type, self._stacks, on_code_change) + + GeneratorFunction = Callable[ + ["CodeTriggerFactory", ResourceIdentifier, str, Dict[str, Any], Callable], Optional[CodeResourceTrigger] + ] + GENERATOR_MAPPING: Dict[str, GeneratorFunction] = { + AWS_LAMBDA_FUNCTION: _create_lambda_trigger, + AWS_SERVERLESS_FUNCTION: _create_lambda_trigger, + AWS_SERVERLESS_LAYERVERSION: _create_layer_trigger, + AWS_LAMBDA_LAYERVERSION: _create_layer_trigger, + AWS_SERVERLESS_API: _create_definition_code_trigger, + AWS_APIGATEWAY_RESTAPI: _create_definition_code_trigger, + AWS_SERVERLESS_HTTPAPI: _create_definition_code_trigger, + AWS_APIGATEWAY_V2_API: _create_definition_code_trigger, + AWS_SERVERLESS_STATEMACHINE: _create_definition_code_trigger, + AWS_STEPFUNCTIONS_STATEMACHINE: _create_definition_code_trigger, + } + + # Ignoring no-self-use as PyLint has a bug with Generic Abstract Classes + def _get_generator_mapping(self) -> Dict[str, GeneratorFunction]: # pylint: disable=no-self-use + return CodeTriggerFactory.GENERATOR_MAPPING + + def create_trigger( + self, resource_identifier: ResourceIdentifier, on_code_change: Callable + ) -> Optional[CodeResourceTrigger]: + """Create Trigger for the resource type + + Parameters + ---------- + resource_identifier : ResourceIdentifier + Resource associated with the trigger + on_code_change : Callable + Callback for code change + + Returns + ------- + Optional[CodeResourceTrigger] + CodeResourceTrigger for the resource + """ + resource = get_resource_by_id(self._stacks, resource_identifier) + generator = self._get_generator_function(resource_identifier) + resource_type = self._get_resource_type(resource_identifier) + if not generator or not resource or not resource_type: + return None + return cast(CodeTriggerFactory.GeneratorFunction, generator)( + self, resource_identifier, resource_type, resource, on_code_change + ) diff --git a/samcli/lib/utils/colors.py b/samcli/lib/utils/colors.py index 84767f0fec..e9b49b28d0 100644 --- a/samcli/lib/utils/colors.py +++ b/samcli/lib/utils/colors.py @@ -2,8 +2,18 @@ Wrapper to generated colored messages for printing in Terminal """ +import platform +import os + import click +# Enables ANSI escape codes on Windows +if platform.system().lower() == "windows": + try: + os.system("color") + except Exception: + pass + class Colored: """ diff --git a/samcli/lib/utils/definition_validator.py b/samcli/lib/utils/definition_validator.py new file mode 100644 index 0000000000..54d06f4101 --- /dev/null +++ b/samcli/lib/utils/definition_validator.py @@ -0,0 +1,60 @@ +"""DefinitionValidator for Validating YAML and JSON Files""" +import logging +from pathlib import Path +from typing import Any, Dict, Optional + +import yaml +from samcli.yamlhelper import parse_yaml_file + +LOG = logging.getLogger(__name__) + + +class 
DefinitionValidator:
+    _path: Path
+    _detect_change: bool
+    _data: Optional[Dict[str, Any]]
+
+    def __init__(self, path: Path, detect_change: bool = True, initialize_data: bool = True) -> None:
+        """
+        Validator for JSON and YAML files.
+        Calling validate() will return True if the definition is valid and
+        has changes.
+
+        Parameters
+        ----------
+        path : Path
+            Path to the definition file
+        detect_change : bool, optional
+            Validation will only be successful if there are changes between current and previous data,
+            by default True
+        initialize_data : bool, optional
+            Whether to initialize existing definition data before the first validate, by default True.
+            Used along with detect_change.
+        """
+        super().__init__()
+        self._path = path
+        self._detect_change = detect_change
+        self._data = None
+        if initialize_data:
+            self.validate()
+
+    def validate(self) -> bool:
+        """Validate json or yaml file.
+
+        Returns
+        -------
+        bool
+            True if it is valid, False otherwise.
+            If detect_change is set, False will also be returned if there is
+            no change compared to the previous validation.
+        """
+        if not self._path.exists():
+            return False
+
+        old_data = self._data
+        try:
+            self._data = parse_yaml_file(str(self._path))
+            return old_data != self._data if self._detect_change else True
+        except (ValueError, yaml.YAMLError) as e:
+            LOG.debug("DefinitionValidator failed to validate.", exc_info=e)
+            return False
diff --git a/samcli/lib/utils/hash.py b/samcli/lib/utils/hash.py
index a9cbae1885..0ff7a29cef 100644
--- a/samcli/lib/utils/hash.py
+++ b/samcli/lib/utils/hash.py
@@ -3,42 +3,47 @@
 """
 import os
 import hashlib
-from typing import List, Optional
+from typing import Any, cast, List, Optional
 
 BLOCK_SIZE = 4096
 
 
-def file_checksum(file_name: str) -> str:
+def file_checksum(file_name: str, hash_generator: Any = None) -> str:
     """
 
     Parameters
     ----------
    file_name: file name of the file for which md5 checksum is required.
+    hash_generator: hashlib _Hash object for generating hashes. Defaults to hashlib.md5.
+
    Returns
    -------
-    md5 checksum of the given file.
+    checksum of the given file.
 
     """
+    # Default is created inside the function because mutable default argument values are shared across calls in Python
+    if not hash_generator:
+        hash_generator = hashlib.md5()
     with open(file_name, "rb") as file_handle:
-        md5 = hashlib.md5()
-
         # Save current cursor position and reset cursor to start of file
         curpos = file_handle.tell()
         file_handle.seek(0)
 
         buf = file_handle.read(BLOCK_SIZE)
         while buf:
-            md5.update(buf)
+            hash_generator.update(buf)
             buf = file_handle.read(BLOCK_SIZE)
 
         # Restore file cursor's position
         file_handle.seek(curpos)
 
-        return md5.hexdigest()
+        return cast(str, hash_generator.hexdigest())
 
 
-def dir_checksum(directory: str, followlinks: bool = True, ignore_list: Optional[List[str]] = None) -> str:
+def dir_checksum(
+    directory: str, followlinks: bool = True, ignore_list: Optional[List[str]] = None, hash_generator: Any = None
+) -> str:
     """
 
     Parameters
@@ -46,14 +51,16 @@ def dir_checksum(directory: str, followlinks: bool = True, ignore_list: Optional
     directory : A directory with an absolute path
     followlinks: Follow symbolic links through the given directory
     ignore_list: The list of file/directory names to ignore in checksum
+    hash_generator: The hashing method (hashlib _Hash object) that generates checksum. Defaults to hashlib.md5.
 
     Returns
     -------
-    md5 checksum of the directory.
+    checksum hash of the directory.
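+
+    Example (an illustrative sketch; the directory path and the choice of
+    sha256 are arbitrary, not defaults of this module)::
+
+        import hashlib
+        checksum = dir_checksum("/path/to/dir", hash_generator=hashlib.sha256())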
""" ignore_set = set(ignore_list or []) - md5_dir = hashlib.md5() + if not hash_generator: + hash_generator = hashlib.md5() files = list() # Walk through given directory and find all directories and files. for dirpath, dirnames, filenames in os.walk(directory, followlinks=followlinks): @@ -70,11 +77,11 @@ def dir_checksum(directory: str, followlinks: bool = True, ignore_list: Optional files.sort() for file in files: - md5_dir.update(os.path.relpath(file, directory).encode("utf-8")) + hash_generator.update(os.path.relpath(file, directory).encode("utf-8")) filepath_checksum = file_checksum(file) - md5_dir.update(filepath_checksum.encode("utf-8")) + hash_generator.update(filepath_checksum.encode("utf-8")) - return md5_dir.hexdigest() + return cast(str, hash_generator.hexdigest()) def str_checksum(content: str) -> str: diff --git a/samcli/lib/utils/lock_distributor.py b/samcli/lib/utils/lock_distributor.py new file mode 100644 index 0000000000..80d53edad0 --- /dev/null +++ b/samcli/lib/utils/lock_distributor.py @@ -0,0 +1,142 @@ +"""LockDistributor for creating and managing a set of locks""" +import threading +import multiprocessing +import multiprocessing.managers +from typing import Dict, List, Optional, cast +from enum import Enum, auto + + +class LockChain: + """Wrapper class for acquiring multiple locks in the same order to prevent dead locks + Can be used with `with` statement""" + + def __init__(self, lock_mapping: Dict[str, threading.Lock]): + """ + Parameters + ---------- + lock_mapping : Dict[str, threading.Lock] + Dictionary of locks with keys being used as generating reproduciable order for aquiring and releasing locks. + """ + self._locks = [value for _, value in sorted(lock_mapping.items())] + + def acquire(self) -> None: + """Aquire all locks in the LockChain""" + for lock in self._locks: + lock.acquire() + + def release(self) -> None: + """Release all locks in the LockChain""" + for lock in self._locks: + lock.release() + + def __enter__(self) -> "LockChain": + self.acquire() + return self + + def __exit__(self, exception_type, exception_value, traceback) -> None: + self.release() + + +class LockDistributorType(Enum): + """Types of LockDistributor""" + + THREAD = auto() + PROCESS = auto() + + +class LockDistributor: + """Dynamic lock distributor that supports threads and processes. + In the case of processes, both manager(server process) or shared memory can be used. 
+ """ + + _lock_type: LockDistributorType + _manager: Optional[multiprocessing.managers.SyncManager] + _dict_lock: threading.Lock + _locks: Dict[str, threading.Lock] + + def __init__( + self, + lock_type: LockDistributorType = LockDistributorType.THREAD, + manager: Optional[multiprocessing.managers.SyncManager] = None, + ): + """[summary] + + Parameters + ---------- + lock_type : LockDistributorType, optional + Whether locking with threads or processes, by default LockDistributorType.THREAD + manager : Optional[multiprocessing.managers.SyncManager], optional + Optional process sync mananger for creating proxy locks, by default None + """ + self._lock_type = lock_type + self._manager = manager + self._dict_lock = self._create_new_lock() + self._locks = ( + self._manager.dict() + if self._lock_type == LockDistributorType.PROCESS and self._manager is not None + else dict() + ) + + def _create_new_lock(self) -> threading.Lock: + """Create a new lock based on lock type + + Returns + ------- + threading.Lock + Newly created lock + """ + if self._lock_type == LockDistributorType.THREAD: + return threading.Lock() + + return self._manager.Lock() if self._manager is not None else cast(threading.Lock, multiprocessing.Lock()) + + def get_lock(self, key: str) -> threading.Lock: + """Retrieve a lock associating with the key + If the lock does not exist, a new lock will be created. + + Parameters + ---------- + key : Key for retrieving the lock + + Returns + ------- + threading.Lock + Lock associated with the key + """ + with self._dict_lock: + if key not in self._locks: + self._locks[key] = self._create_new_lock() + return self._locks[key] + + def get_locks(self, keys: List[str]) -> Dict[str, threading.Lock]: + """Retrieve a list of locks associating with keys + + Parameters + ---------- + keys : List[str] + List of keys for retrieving the locks + + Returns + ------- + Dict[str, threading.Lock] + Dictionary mapping keys to locks + """ + lock_mapping = dict() + for key in keys: + lock_mapping[key] = self.get_lock(key) + return lock_mapping + + def get_lock_chain(self, keys: List[str]) -> LockChain: + """Similar to get_locks, but retrieves a LockChain object instead of a dictionary + + Parameters + ---------- + keys : List[str] + List of keys for retrieving the locks + + Returns + ------- + LockChain + LockChain object containing all the locks associated with keys + """ + return LockChain(self.get_locks(keys)) diff --git a/samcli/lib/utils/osutils.py b/samcli/lib/utils/osutils.py index 68b6fa02d1..174b6ffb50 100644 --- a/samcli/lib/utils/osutils.py +++ b/samcli/lib/utils/osutils.py @@ -12,6 +12,10 @@ LOG = logging.getLogger(__name__) +# Build directories need not be world writable. +# This is usually a optimal permission for directories +BUILD_DIR_PERMISSIONS = 0o755 + @contextmanager def mkdir_temp(mode=0o755, ignore_errors=False): diff --git a/samcli/lib/utils/path_observer.py b/samcli/lib/utils/path_observer.py new file mode 100644 index 0000000000..615c8963fa --- /dev/null +++ b/samcli/lib/utils/path_observer.py @@ -0,0 +1,161 @@ +""" +HandlerObserver and its helper classes. 
+""" +import re + +from pathlib import Path +from typing import Callable, List, Optional +from dataclasses import dataclass + +from watchdog.observers import Observer +from watchdog.events import ( + FileSystemEvent, + FileSystemEventHandler, + RegexMatchingEventHandler, +) +from watchdog.observers.api import DEFAULT_OBSERVER_TIMEOUT, ObservedWatch + + +@dataclass +class PathHandler: + """PathHandler is an object that can be passed into + Bundle Observer directly for watching a specific path with + corresponding EventHandler + + Fields: + event_handler : FileSystemEventHandler + Handler for the event + path : Path + Path to the folder to be watched + recursive : bool, optional + True to watch child folders, by default False + static_folder : bool, optional + Should the observed folder name be static, by default False + See StaticFolderWrapper on the use case. + self_create : Optional[Callable[[], None]], optional + Callback when the folder to be observed itself is created, by default None + This will not be called if static_folder is False + self_delete : Optional[Callable[[], None]], optional + Callback when the folder to be observed itself is deleted, by default None + This will not be called if static_folder is False + """ + + event_handler: FileSystemEventHandler + path: Path + recursive: bool = False + static_folder: bool = False + self_create: Optional[Callable[[], None]] = None + self_delete: Optional[Callable[[], None]] = None + + +class StaticFolderWrapper: + """This class is used to alter the behavior of watchdog folder watches. + https://github.com/gorakhargosh/watchdog/issues/415 + By default, if a folder is renamed, the handler will still get triggered for the new folder + Ex: + 1. Create FolderA + 2. Watch FolderA + 3. Rename FolderA to FolderB + 4. Add file to FolderB + 5. Handler will get event for adding the file to FolderB but with event path still as FolderA + This class watches the parent folder and if the folder to be watched gets renamed or deleted, + the watch will be stopped and changes in the renamed folder will not be triggered. + """ + + def __init__(self, observer: "HandlerObserver", initial_watch: ObservedWatch, path_handler: PathHandler): + """[summary] + + Parameters + ---------- + observer : HandlerObserver + HandlerObserver + initial_watch : ObservedWatch + Initial watch for the folder to be watched that gets returned by HandlerObserver + path_handler : PathHandler + PathHandler of the folder to be watched. + """ + self._observer = observer + self._path_handler = path_handler + self._watch = initial_watch + + def _on_parent_change(self, _: FileSystemEvent) -> None: + """Callback for changes detected in the parent folder""" + + # When folder is being watched but the folder does not exist + if self._watch and not self._path_handler.path.exists(): + if self._path_handler.self_delete: + self._path_handler.self_delete() + self._observer.unschedule(self._watch) + self._watch = None + # When folder is not being watched but the folder does exist + elif not self._watch and self._path_handler.path.exists(): + if self._path_handler.self_create: + self._path_handler.self_create() + self._watch = self._observer.schedule_handler(self._path_handler) + + def get_dir_parent_path_handler(self) -> PathHandler: + """Get PathHandler that watches the folder changes from the parent folder. + + Returns + ------- + PathHandler + PathHandler for the parent folder. This should be added back into the HandlerObserver. 
+ """ + dir_path = self._path_handler.path.resolve() + parent_dir_path = dir_path.parent + parent_folder_handler = RegexMatchingEventHandler( + regexes=[f"^{re.escape(str(dir_path))}$"], + ignore_regexes=[], + ignore_directories=False, + case_sensitive=True, + ) + parent_folder_handler.on_any_event = self._on_parent_change + return PathHandler(path=parent_dir_path, event_handler=parent_folder_handler) + + +class HandlerObserver(Observer): # pylint: disable=too-many-ancestors + """ + Extended WatchDog Observer that takes in a single PathHandler object. + """ + + def __init__(self, timeout=DEFAULT_OBSERVER_TIMEOUT): + super().__init__(timeout=timeout) + + def schedule_handlers(self, path_handlers: List[PathHandler]) -> List[ObservedWatch]: + """Schedule a list of PathHandlers + + Parameters + ---------- + path_handlers : List[PathHandler] + List of PathHandlers to be scheduled + + Returns + ------- + List[ObservedWatch] + List of ObservedWatch corresponding to path_handlers in the same order. + """ + watches = list() + for path_handler in path_handlers: + watches.append(self.schedule_handler(path_handler)) + return watches + + def schedule_handler(self, path_handler: PathHandler) -> ObservedWatch: + """Schedule a PathHandler + + Parameters + ---------- + path_handler : PathHandler + PathHandler to be scheduled + + Returns + ------- + ObservedWatch + ObservedWatch corresponding to the PathHandler. + If static_folder is True, the parent folder watch will be returned instead. + """ + watch = self.schedule(path_handler.event_handler, str(path_handler.path), path_handler.recursive) + if path_handler.static_folder: + static_wrapper = StaticFolderWrapper(self, watch, path_handler) + parent_path_handler = static_wrapper.get_dir_parent_path_handler() + watch = self.schedule_handler(parent_path_handler) + return watch diff --git a/samcli/lib/utils/resource_trigger.py b/samcli/lib/utils/resource_trigger.py new file mode 100644 index 0000000000..eb9f7f31a7 --- /dev/null +++ b/samcli/lib/utils/resource_trigger.py @@ -0,0 +1,339 @@ +"""ResourceTrigger Classes for Creating PathHandlers According to a Resource""" +import re +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Dict, List, Optional, cast + +from typing_extensions import Protocol +from watchdog.events import FileSystemEvent, PatternMatchingEventHandler, RegexMatchingEventHandler + +from samcli.lib.providers.exceptions import MissingCodeUri, MissingLocalDefinition +from samcli.lib.providers.provider import Function, LayerVersion, ResourceIdentifier, Stack, get_resource_by_id +from samcli.lib.providers.sam_function_provider import SamFunctionProvider +from samcli.lib.providers.sam_layer_provider import SamLayerProvider +from samcli.lib.utils.definition_validator import DefinitionValidator +from samcli.lib.utils.path_observer import PathHandler +from samcli.local.lambdafn.exceptions import FunctionNotFound, ResourceNotFound +from samcli.lib.utils.resources import RESOURCES_WITH_LOCAL_PATHS + + +class OnChangeCallback(Protocol): + """Callback Type""" + + def __call__(self, event: Optional[FileSystemEvent] = None) -> None: + pass + + +class ResourceTrigger(ABC): + """Abstract class for creating PathHandlers for a resource. 
+    PathHandlers returned by get_path_handlers() can then be used with an observer for
+    detecting file changes associated with the resource."""
+
+    def __init__(self) -> None:
+        pass
+
+    @abstractmethod
+    def get_path_handlers(self) -> List[PathHandler]:
+        """List of PathHandlers that corresponds to a resource
+
+        Returns
+        -------
+        List[PathHandler]
+            List of PathHandlers that corresponds to a resource
+        """
+        raise NotImplementedError("get_path_handlers is not implemented.")
+
+    @staticmethod
+    def get_single_file_path_handler(file_path_str: str) -> PathHandler:
+        """Get PathHandler for watching a single file
+
+        Parameters
+        ----------
+        file_path_str : str
+            File path in string
+
+        Returns
+        -------
+        PathHandler
+            The PathHandler for the file specified
+        """
+        file_path = Path(file_path_str).resolve()
+        folder_path = file_path.parent
+        file_handler = RegexMatchingEventHandler(
+            regexes=[f"^{re.escape(str(file_path))}$"], ignore_regexes=[], ignore_directories=True, case_sensitive=True
+        )
+        return PathHandler(path=folder_path, event_handler=file_handler, recursive=False)
+
+    @staticmethod
+    def get_dir_path_handler(dir_path_str: str) -> PathHandler:
+        """Get PathHandler for watching a single directory
+
+        Parameters
+        ----------
+        dir_path_str : str
+            Folder path in string
+
+        Returns
+        -------
+        PathHandler
+            The PathHandler for the folder specified
+        """
+        dir_path = Path(dir_path_str).resolve()
+        file_handler = PatternMatchingEventHandler(
+            patterns=["*"], ignore_patterns=[], ignore_directories=False, case_sensitive=True
+        )
+        return PathHandler(path=dir_path, event_handler=file_handler, recursive=True, static_folder=True)
+
+
+class TemplateTrigger(ResourceTrigger):
+    _template_file: str
+    _on_template_change: OnChangeCallback
+    _validator: DefinitionValidator
+
+    def __init__(self, template_file: str, on_template_change: OnChangeCallback) -> None:
+        """
+        Parameters
+        ----------
+        template_file : str
+            Template file to be watched
+        on_template_change : OnChangeCallback
+            Callback when template changes
+        """
+        super().__init__()
+        self._template_file = template_file
+        self._on_template_change = on_template_change
+        self._validator = DefinitionValidator(Path(self._template_file))
+
+    def _validator_wrapper(self, event: Optional[FileSystemEvent] = None) -> None:
+        """Wrapper for callback that only executes if the template is valid and non-trivial changes are detected.
+
+        Parameters
+        ----------
+        event : Optional[FileSystemEvent], optional
+        """
+        if self._validator.validate():
+            self._on_template_change(event)
+
+    def get_path_handlers(self) -> List[PathHandler]:
+        file_path_handler = ResourceTrigger.get_single_file_path_handler(self._template_file)
+        file_path_handler.event_handler.on_any_event = self._validator_wrapper
+        return [file_path_handler]
+
+
+class CodeResourceTrigger(ResourceTrigger):
+    """Parent class for ResourceTriggers that are for a single template resource."""
+
+    _resource_identifier: ResourceIdentifier
+    _resource: Dict[str, Any]
+    _on_code_change: OnChangeCallback
+
+    def __init__(self, resource_identifier: ResourceIdentifier, stacks: List[Stack], on_code_change: OnChangeCallback):
+        """
+        Parameters
+        ----------
+        resource_identifier : ResourceIdentifier
+            ResourceIdentifier
+        stacks : List[Stack]
+            List of stacks
+        on_code_change : OnChangeCallback
+            Callback when the resource files are changed.
+
+        Raises
+        ------
+        ResourceNotFound
+            Raised when the resource cannot be found in the stacks.
+ """ + super().__init__() + self._resource_identifier = resource_identifier + resource = get_resource_by_id(stacks, resource_identifier) + if not resource: + raise ResourceNotFound() + self._resource = resource + self._on_code_change = on_code_change + + +class LambdaFunctionCodeTrigger(CodeResourceTrigger): + _function: Function + _code_uri: str + + def __init__(self, function_identifier: ResourceIdentifier, stacks: List[Stack], on_code_change: OnChangeCallback): + """ + Parameters + ---------- + function_identifier : ResourceIdentifier + ResourceIdentifier for the function + stacks : List[Stack] + List of stacks + on_code_change : OnChangeCallback + Callback when function code files are changed. + + Raises + ------ + FunctionNotFound + raised when the function cannot be found in stacks + MissingCodeUri + raised when there is no CodeUri property in the function definition. + """ + super().__init__(function_identifier, stacks, on_code_change) + function = SamFunctionProvider(stacks).get(str(function_identifier)) + if not function: + raise FunctionNotFound() + self._function = function + + code_uri = self._get_code_uri() + if not code_uri: + raise MissingCodeUri() + self._code_uri = code_uri + + @abstractmethod + def _get_code_uri(self) -> Optional[str]: + """ + Returns + ------- + Optional[str] + Path for the folder to be watched. + """ + raise NotImplementedError() + + def get_path_handlers(self) -> List[PathHandler]: + """ + Returns + ------- + List[PathHandler] + PathHandlers for the code folder associated with the function + """ + dir_path_handler = ResourceTrigger.get_dir_path_handler(self._code_uri) + dir_path_handler.self_create = self._on_code_change + dir_path_handler.self_delete = self._on_code_change + dir_path_handler.event_handler.on_any_event = self._on_code_change + return [dir_path_handler] + + +class LambdaZipCodeTrigger(LambdaFunctionCodeTrigger): + def _get_code_uri(self) -> Optional[str]: + return self._function.codeuri + + +class LambdaImageCodeTrigger(LambdaFunctionCodeTrigger): + def _get_code_uri(self) -> Optional[str]: + if not self._function.metadata: + return None + return cast(Optional[str], self._function.metadata.get("DockerContext", None)) + + +class LambdaLayerCodeTrigger(CodeResourceTrigger): + _layer: LayerVersion + _code_uri: str + + def __init__( + self, + layer_identifier: ResourceIdentifier, + stacks: List[Stack], + on_code_change: OnChangeCallback, + ): + """ + Parameters + ---------- + layer_identifier : ResourceIdentifier + ResourceIdentifier for the layer + stacks : List[Stack] + List of stacks + on_code_change : OnChangeCallback + Callback when layer code files are changed. + + Raises + ------ + ResourceNotFound + raised when the layer cannot be found in stacks + MissingCodeUri + raised when there is no CodeUri property in the function definition. 
+ """ + super().__init__(layer_identifier, stacks, on_code_change) + layer = SamLayerProvider(stacks).get(str(layer_identifier)) + if not layer: + raise ResourceNotFound() + self._layer = layer + code_uri = self._layer.codeuri + if not code_uri: + raise MissingCodeUri() + self._code_uri = code_uri + + def get_path_handlers(self) -> List[PathHandler]: + """ + Returns + ------- + List[PathHandler] + PathHandlers for the code folder associated with the layer + """ + dir_path_handler = ResourceTrigger.get_dir_path_handler(self._code_uri) + dir_path_handler.self_create = self._on_code_change + dir_path_handler.self_delete = self._on_code_change + dir_path_handler.event_handler.on_any_event = self._on_code_change + return [dir_path_handler] + + +class DefinitionCodeTrigger(CodeResourceTrigger): + _validator: DefinitionValidator + _definition_file: str + + def __init__( + self, + resource_identifier: ResourceIdentifier, + resource_type: str, + stacks: List[Stack], + on_code_change: OnChangeCallback, + ): + """ + Parameters + ---------- + resource_identifier : ResourceIdentifier + ResourceIdentifier for the Resource + resource_type : str + Resource type + stacks : List[Stack] + List of stacks + on_code_change : OnChangeCallback + Callback when definition file is changed. + """ + super().__init__(resource_identifier, stacks, on_code_change) + self._resource_type = resource_type + self._definition_file = self._get_definition_file() + self._validator = DefinitionValidator(Path(self._definition_file)) + + def _get_definition_file(self) -> str: + """ + Returns + ------- + str + JSON/YAML definition file path + + Raises + ------ + MissingLocalDefinition + raised when resource property related to definition path is not specified. + """ + property_name = RESOURCES_WITH_LOCAL_PATHS[self._resource_type][0] + definition_file = self._resource.get("Properties", {}).get(property_name) + if not definition_file or not isinstance(definition_file, str): + raise MissingLocalDefinition(self._resource_identifier, property_name) + return definition_file + + def _validator_wrapper(self, event: Optional[FileSystemEvent] = None): + """Wrapper for callback that only executes if the definition is valid and non-trivial changes are detected. + + Parameters + ---------- + event : Optional[FileSystemEvent], optional + """ + if self._validator.validate(): + self._on_code_change(event) + + def get_path_handlers(self) -> List[PathHandler]: + """ + Returns + ------- + List[PathHandler] + A single PathHandler for watching the definition file. 
+ """ + file_path_handler = ResourceTrigger.get_single_file_path_handler(self._definition_file) + file_path_handler.event_handler.on_any_event = self._validator_wrapper + return [file_path_handler] diff --git a/samcli/lib/utils/resource_type_based_factory.py b/samcli/lib/utils/resource_type_based_factory.py new file mode 100644 index 0000000000..67a46f08af --- /dev/null +++ b/samcli/lib/utils/resource_type_based_factory.py @@ -0,0 +1,69 @@ +"""Base Factory Abstract Class for Creating Objects Specific to a Resource Type""" +import logging +from abc import ABC, abstractmethod +from typing import Callable, Dict, Generic, List, Optional, TypeVar + +from samcli.lib.providers.provider import ResourceIdentifier, Stack, get_resource_by_id + +LOG = logging.getLogger(__name__) + +T = TypeVar("T") # pylint: disable=invalid-name + + +class ResourceTypeBasedFactory(ABC, Generic[T]): + def __init__(self, stacks: List[Stack]) -> None: + self._stacks = stacks + + @abstractmethod + def _get_generator_mapping(self) -> Dict[str, Callable]: + """ + Returns + ------- + Dict[str, GeneratorFunction] + Mapping between resource type and generator function + """ + raise NotImplementedError() + + def _get_resource_type(self, resource_identifier: ResourceIdentifier) -> Optional[str]: + """Get resource type of the resource + + Parameters + ---------- + resource_identifier : ResourceIdentifier + + Returns + ------- + Optional[str] + Resource type of the resource + """ + resource = get_resource_by_id(self._stacks, resource_identifier) + if not resource: + LOG.debug("Resource %s does not exist.", str(resource_identifier)) + return None + + resource_type = resource.get("Type", None) + if not isinstance(resource_type, str): + LOG.debug("Resource %s has none string property Type.", str(resource_identifier)) + return None + return resource_type + + def _get_generator_function(self, resource_identifier: ResourceIdentifier) -> Optional[Callable]: + """Create an appropriate T object based on stack resource type + + Parameters + ---------- + resource_identifier : ResourceIdentifier + Resource identifier of the resource + + Returns + ------- + Optional[T] + Object T for the resource. Returns None if resource cannot be + found or have no associating T generator function. 
+ """ + resource_type = self._get_resource_type(resource_identifier) + if not resource_type: + LOG.debug("Resource %s has invalid property Type.", str(resource_identifier)) + return None + generator = self._get_generator_mapping().get(resource_type, None) + return generator diff --git a/samcli/commands/_utils/resources.py b/samcli/lib/utils/resources.py similarity index 85% rename from samcli/commands/_utils/resources.py rename to samcli/lib/utils/resources.py index d68c675288..9fc9705711 100644 --- a/samcli/commands/_utils/resources.py +++ b/samcli/lib/utils/resources.py @@ -1,26 +1,49 @@ """ -Enums for Resources and thier Location Properties, along with utility functions +Enums for Resources and their Location Properties, along with utility functions """ from collections import defaultdict -AWS_SERVERLESSREPO_APPLICATION = "AWS::ServerlessRepo::Application" +# Lambda AWS_SERVERLESS_FUNCTION = "AWS::Serverless::Function" +AWS_SERVERLESS_LAYERVERSION = "AWS::Serverless::LayerVersion" + +AWS_LAMBDA_FUNCTION = "AWS::Lambda::Function" +AWS_LAMBDA_LAYERVERSION = "AWS::Lambda::LayerVersion" + +# APIGW AWS_SERVERLESS_API = "AWS::Serverless::Api" AWS_SERVERLESS_HTTPAPI = "AWS::Serverless::HttpApi" + +AWS_APIGATEWAY_RESTAPI = "AWS::ApiGateway::RestApi" +AWS_APIGATEWAY_STAGE = "AWS::ApiGateway::Stage" +AWS_APIGATEWAY_RESOURCE = "AWS::ApiGateway::Resource" +AWS_APIGATEWAY_METHOD = "AWS::ApiGateway::Method" + +AWS_APIGATEWAY_V2_API = "AWS::ApiGatewayV2::Api" +AWS_APIGATEWAY_V2_INTEGRATION = "AWS::ApiGatewayV2::Integration" +AWS_APIGATEWAY_V2_ROUTE = "AWS::ApiGatewayV2::Route" +AWS_APIGATEWAY_V2_STAGE = "AWS::ApiGatewayV2::Stage" + +# SFN +AWS_SERVERLESS_STATEMACHINE = "AWS::Serverless::StateMachine" + +AWS_STEPFUNCTIONS_STATEMACHINE = "AWS::StepFunctions::StateMachine" + +# Others +AWS_SERVERLESS_APPLICATION = "AWS::Serverless::Application" + +AWS_SERVERLESSREPO_APPLICATION = "AWS::ServerlessRepo::Application" AWS_APPSYNC_GRAPHQLSCHEMA = "AWS::AppSync::GraphQLSchema" AWS_APPSYNC_RESOLVER = "AWS::AppSync::Resolver" AWS_APPSYNC_FUNCTIONCONFIGURATION = "AWS::AppSync::FunctionConfiguration" -AWS_LAMBDA_FUNCTION = "AWS::Lambda::Function" -AWS_APIGATEWAY_RESTAPI = "AWS::ApiGateway::RestApi" AWS_ELASTICBEANSTALK_APPLICATIONVERSION = "AWS::ElasticBeanstalk::ApplicationVersion" AWS_CLOUDFORMATION_MODULEVERSION = "AWS::CloudFormation::ModuleVersion" AWS_CLOUDFORMATION_RESOURCEVERSION = "AWS::CloudFormation::ResourceVersion" AWS_CLOUDFORMATION_STACK = "AWS::CloudFormation::Stack" -AWS_SERVERLESS_APPLICATION = "AWS::Serverless::Application" -AWS_LAMBDA_LAYERVERSION = "AWS::Lambda::LayerVersion" -AWS_SERVERLESS_LAYERVERSION = "AWS::Serverless::LayerVersion" AWS_GLUE_JOB = "AWS::Glue::Job" +AWS_SQS_QUEUE = "AWS::SQS::Queue" +AWS_KINESIS_STREAM = "AWS::Kinesis::Stream" AWS_SERVERLESS_STATEMACHINE = "AWS::Serverless::StateMachine" AWS_STEPFUNCTIONS_STATEMACHINE = "AWS::StepFunctions::StateMachine" AWS_ECR_REPOSITORY = "AWS::ECR::Repository" diff --git a/samcli/lib/utils/version_checker.py b/samcli/lib/utils/version_checker.py index 96b83dc54b..130f5f3260 100644 --- a/samcli/lib/utils/version_checker.py +++ b/samcli/lib/utils/version_checker.py @@ -4,7 +4,6 @@ import logging from datetime import datetime, timedelta from functools import wraps -from typing import Optional import click from requests import get @@ -77,7 +76,7 @@ def _inform_newer_version(force_check=False) -> None: LOG.debug("New version check failed", exc_info=e) finally: if need_to_update_last_check_time: - update_last_check_time(global_config) + 
update_last_check_time() def fetch_and_compare_versions() -> None: @@ -95,17 +94,13 @@ def fetch_and_compare_versions() -> None: click.echo(f"To download: {AWS_SAM_CLI_INSTALL_DOCS}", err=True) -def update_last_check_time(global_config: Optional[GlobalConfig]) -> None: +def update_last_check_time() -> None: """ Update last_check_time in GlobalConfig - Parameters - ---------- - global_config: GlobalConfig - GlobalConfig object that have been read """ try: - if global_config: - global_config.last_version_check = datetime.utcnow().timestamp() + gc = GlobalConfig() + gc.last_version_check = datetime.utcnow().timestamp() except Exception as e: LOG.debug("Updating last version check time was failed", exc_info=e) diff --git a/setup.py b/setup.py index 307fd6dec2..1140907370 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,8 @@ def read_version(): return re.search(r"__version__ = \"([^']+)\"", content).group(1) -cmd_name = "sam" +# TODO(wchengru): The cmd name is for beta release only, need to change back to "sam" when GA +cmd_name = "sam-acc" if os.getenv("SAM_CLI_DEV"): # We are installing in a dev environment cmd_name = "samdev" diff --git a/tests/functional/commands/cli/test_global_config.py b/tests/functional/commands/cli/test_global_config.py index d9bb9907b2..458b32d74f 100644 --- a/tests/functional/commands/cli/test_global_config.py +++ b/tests/functional/commands/cli/test_global_config.py @@ -4,7 +4,7 @@ import os from time import time -from unittest.mock import mock_open, patch +from unittest.mock import MagicMock, patch from unittest import TestCase from samcli.cli.global_config import GlobalConfig from pathlib import Path @@ -13,15 +13,20 @@ class TestGlobalConfig(TestCase): def setUp(self): self._cfg_dir = tempfile.mkdtemp() - self._previous_telemetry_environ = os.environ.get("SAM_CLI_TELEMETRY") - os.environ.pop("SAM_CLI_TELEMETRY") + if "SAM_CLI_TELEMETRY" in os.environ: + os.environ.pop("SAM_CLI_TELEMETRY") + self.saved_env_var = dict(os.environ) def tearDown(self): shutil.rmtree(self._cfg_dir) - os.environ["SAM_CLI_TELEMETRY"] = self._previous_telemetry_environ + # Force singleton to recreate after each test + GlobalConfig._Singleton__instance = None + os.environ.clear() + os.environ.update(self.saved_env_var) def test_installation_id_with_side_effect(self): - gc = GlobalConfig(config_dir=self._cfg_dir) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) installation_id = gc.installation_id expected_path = Path(self._cfg_dir, "metadata.json") json_body = json.loads(expected_path.read_text()) @@ -36,7 +41,8 @@ def test_installation_id_on_existing_file(self): with open(str(path), "w") as f: cfg = {"foo": "bar"} f.write(json.dumps(cfg, indent=4) + "\n") - gc = GlobalConfig(config_dir=self._cfg_dir) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) installation_id = gc.installation_id json_body = json.loads(path.read_text()) self.assertEqual(installation_id, json_body["installationId"]) @@ -47,25 +53,24 @@ def test_installation_id_exists(self): with open(str(path), "w") as f: cfg = {"installationId": "stub-uuid"} f.write(json.dumps(cfg, indent=4) + "\n") - gc = GlobalConfig(config_dir=self._cfg_dir) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) installation_id = gc.installation_id self.assertEqual("stub-uuid", installation_id) - def test_init_override(self): - gc = GlobalConfig(installation_id="foo") - installation_id = gc.installation_id - self.assertEqual("foo", installation_id) - def test_invalid_json(self): path = Path(self._cfg_dir, 
"metadata.json") with open(str(path), "w") as f: f.write("NOT JSON, PROBABLY VALID YAML AM I RIGHT!?") - gc = GlobalConfig(config_dir=self._cfg_dir) - self.assertIsNone(gc.installation_id) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) + self.assertIsInstance(gc.installation_id, str) self.assertFalse(gc.telemetry_enabled) def test_telemetry_flag_provided(self): - gc = GlobalConfig(telemetry_enabled=True) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) + gc.telemetry_enabled = True self.assertTrue(gc.telemetry_enabled) def test_telemetry_flag_from_cfg(self): @@ -73,11 +78,13 @@ def test_telemetry_flag_from_cfg(self): with open(str(path), "w") as f: cfg = {"telemetryEnabled": True} f.write(json.dumps(cfg, indent=4) + "\n") - gc = GlobalConfig(config_dir=self._cfg_dir) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) self.assertTrue(gc.telemetry_enabled) def test_telemetry_flag_no_file(self): - gc = GlobalConfig(config_dir=self._cfg_dir) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) self.assertFalse(gc.telemetry_enabled) def test_telemetry_flag_not_in_cfg(self): @@ -85,12 +92,14 @@ def test_telemetry_flag_not_in_cfg(self): with open(str(path), "w") as f: cfg = {"installationId": "stub-uuid"} f.write(json.dumps(cfg, indent=4) + "\n") - gc = GlobalConfig(config_dir=self._cfg_dir) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) self.assertFalse(gc.telemetry_enabled) def test_set_telemetry_flag_no_file(self): path = Path(self._cfg_dir, "metadata.json") - gc = GlobalConfig(config_dir=self._cfg_dir) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) self.assertFalse(gc.telemetry_enabled) # pre-state test gc.telemetry_enabled = True from_gc = gc.telemetry_enabled @@ -104,7 +113,8 @@ def test_set_telemetry_flag_no_key(self): with open(str(path), "w") as f: cfg = {"installationId": "stub-uuid"} f.write(json.dumps(cfg, indent=4) + "\n") - gc = GlobalConfig(config_dir=self._cfg_dir) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) gc.telemetry_enabled = True json_body = json.loads(path.read_text()) self.assertTrue(gc.telemetry_enabled) @@ -115,7 +125,8 @@ def test_set_telemetry_flag_overwrite(self): with open(str(path), "w") as f: cfg = {"telemetryEnabled": True} f.write(json.dumps(cfg, indent=4) + "\n") - gc = GlobalConfig(config_dir=self._cfg_dir) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) self.assertTrue(gc.telemetry_enabled) gc.telemetry_enabled = False json_body = json.loads(path.read_text()) @@ -127,12 +138,16 @@ def test_telemetry_flag_explicit_false(self): with open(str(path), "w") as f: cfg = {"telemetryEnabled": True} f.write(json.dumps(cfg, indent=4) + "\n") - gc = GlobalConfig(config_dir=self._cfg_dir, telemetry_enabled=False) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) + gc.telemetry_enabled = False self.assertFalse(gc.telemetry_enabled) def test_last_version_check_value_provided(self): last_version_check_value = time() - gc = GlobalConfig(last_version_check=last_version_check_value) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) + gc.last_version_check = last_version_check_value self.assertEqual(gc.last_version_check, last_version_check_value) def test_last_version_check_value_cfg(self): @@ -141,11 +156,13 @@ def test_last_version_check_value_cfg(self): with open(str(path), "w") as f: cfg = {"lastVersionCheck": last_version_check_value} f.write(json.dumps(cfg, indent=4) + "\n") - gc = GlobalConfig(config_dir=self._cfg_dir) + gc = GlobalConfig() + gc.config_dir = 
Path(self._cfg_dir) self.assertEqual(gc.last_version_check, last_version_check_value) def test_last_version_check_value_no_file(self): - gc = GlobalConfig(config_dir=self._cfg_dir) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) self.assertIsNone(gc.last_version_check) def test_last_version_check_value_not_in_cfg(self): @@ -153,12 +170,14 @@ def test_last_version_check_value_not_in_cfg(self): with open(str(path), "w") as f: cfg = {"installationId": "stub-uuid"} f.write(json.dumps(cfg, indent=4) + "\n") - gc = GlobalConfig(config_dir=self._cfg_dir) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) self.assertIsNone(gc.last_version_check) def test_set_last_version_check_value_no_file(self): path = Path(self._cfg_dir, "metadata.json") - gc = GlobalConfig(config_dir=self._cfg_dir) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) self.assertIsNone(gc.last_version_check) # pre-state test last_version_check_value = time() @@ -173,7 +192,8 @@ def test_last_version_check_value_no_key(self): with open(str(path), "w") as f: cfg = {"installationId": "stub-uuid"} f.write(json.dumps(cfg, indent=4) + "\n") - gc = GlobalConfig(config_dir=self._cfg_dir) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) last_version_check_value = time() gc.last_version_check = last_version_check_value @@ -187,7 +207,8 @@ def test_set_last_version_check_value_overwrite(self): cfg = {"lastVersionCheck": last_version_check_value} f.write(json.dumps(cfg, indent=4) + "\n") - gc = GlobalConfig(config_dir=self._cfg_dir) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) self.assertEqual(gc.last_version_check, last_version_check_value) last_version_check_new_value = time() @@ -202,25 +223,29 @@ def test_last_version_check_explicit_value(self): with open(str(path), "w") as f: cfg = {"lastVersionCheck": last_version_check_value} f.write(json.dumps(cfg, indent=4) + "\n") - gc = GlobalConfig(config_dir=self._cfg_dir, last_version_check=last_version_check_value_override) + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) + gc.last_version_check = last_version_check_value_override self.assertEqual(gc.last_version_check, last_version_check_value_override) - def test_setter_raises_on_invalid_json(self): + def test_setter_on_invalid_json(self): path = Path(self._cfg_dir, "metadata.json") with open(str(path), "w") as f: f.write("NOT JSON, PROBABLY VALID YAML AM I RIGHT!?") - gc = GlobalConfig(config_dir=self._cfg_dir) - with self.assertRaises(ValueError): - gc.telemetry_enabled = True + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) + gc.telemetry_enabled = True + self.assertTrue(gc.telemetry_enabled) def test_setter_cannot_open_file(self): path = Path(self._cfg_dir, "metadata.json") with open(str(path), "w") as f: cfg = {"telemetryEnabled": True} f.write(json.dumps(cfg, indent=4) + "\n") - m = mock_open() - m.side_effect = IOError("fail") - gc = GlobalConfig(config_dir=self._cfg_dir) - with patch("samcli.cli.global_config.open", m): - with self.assertRaises(IOError): + m = MagicMock() + m.side_effect = OSError("fail") + gc = GlobalConfig() + gc.config_dir = Path(self._cfg_dir) + with patch("samcli.cli.global_config.Path.write_text", m): + with self.assertRaises(OSError): gc.telemetry_enabled = True diff --git a/tests/integration/buildcmd/test_build_cmd.py b/tests/integration/buildcmd/test_build_cmd.py index ae7af9d234..d8c021f379 100644 --- a/tests/integration/buildcmd/test_build_cmd.py +++ b/tests/integration/buildcmd/test_build_cmd.py @@ -1309,6 +1309,36 @@ def 
test_cache_build(self, use_container, code_uri, function1_handler, function2 expected_messages, command_result, self._make_parameter_override_arg(overrides) ) + @skipIf(SKIP_DOCKER_TESTS, SKIP_DOCKER_MESSAGE) + def test_cached_build_with_env_vars(self): + """ + Build twice to verify that the second build hits the cache + """ + overrides = { + "FunctionCodeUri": "Python", + "Function1Handler": "main.first_function_handler", + "Function2Handler": "main.second_function_handler", + "FunctionRuntime": "python3.8", + } + cmdlist = self.get_command_list( + use_container=True, parameter_overrides=overrides, cached=True, container_env_var="FOO=BAR" + ) + + LOG.info("Running Command (cache should be invalid): %s", cmdlist) + command_result = run_command(cmdlist, cwd=self.working_dir) + self.assertTrue( + "Cache is invalid, running build and copying resources to function build definition" + in command_result.stderr.decode("utf-8") + ) + + LOG.info("Re-Running Command (valid cache should exist): %s", cmdlist) + command_result_with_cache = run_command(cmdlist, cwd=self.working_dir) + + self.assertTrue( + "Valid cache found, copying previously built resources from function build definition" + in command_result_with_cache.stderr.decode("utf-8") + ) + @skipIf( ((IS_WINDOWS and RUNNING_ON_CI) and not CI_OVERRIDE), diff --git a/tests/integration/pipeline/test_init_command.py b/tests/integration/pipeline/test_init_command.py index 32706f3fe2..ba138e1588 100644 --- a/tests/integration/pipeline/test_init_command.py +++ b/tests/integration/pipeline/test_init_command.py @@ -7,9 +7,9 @@ from parameterized import parameterized -from samcli.cli.main import global_cfg from samcli.commands.pipeline.bootstrap.cli import PIPELINE_CONFIG_DIR, PIPELINE_CONFIG_FILENAME from samcli.commands.pipeline.init.interactive_init_flow import APP_PIPELINE_TEMPLATES_REPO_LOCAL_NAME +from samcli.cli.global_config import GlobalConfig from tests.integration.pipeline.base import InitIntegBase, BootstrapIntegBase from tests.integration.pipeline.test_bootstrap_command import SKIP_BOOTSTRAP_TESTS, CREDENTIAL_PROFILE from tests.testing_utils import run_command_with_inputs @@ -36,7 +36,7 @@ "prod-ecr", "us-west-2", ] -SHARED_PATH: Path = global_cfg.config_dir +SHARED_PATH: Path = GlobalConfig().config_dir EXPECTED_JENKINS_FILE_PATH = Path( SHARED_PATH, APP_PIPELINE_TEMPLATES_REPO_LOCAL_NAME, "tests", "testfile_jenkins", "expected" ) diff --git a/tests/integration/telemetry/integ_base.py b/tests/integration/telemetry/integ_base.py index 7776a14aae..c9d05a7f12 100644 --- a/tests/integration/telemetry/integ_base.py +++ b/tests/integration/telemetry/integ_base.py @@ -36,7 +36,8 @@ def setUp(self): self.maxDiff = None # Show full JSON Diff self.config_dir = tempfile.mkdtemp() - self._gc = GlobalConfig(config_dir=self.config_dir) + self._gc = GlobalConfig() + self._gc.config_dir = Path(self.config_dir) def tearDown(self): self.config_dir and shutil.rmtree(self.config_dir) @@ -49,9 +50,9 @@ def base_command(cls): return command - def run_cmd(self, cmd_list=None, stdin_data="", optout_envvar_value=None): # Any command will work for this test suite - cmd_list = [self.cmd, "local", "generate-event", "s3", "put"] + cmd_list = cmd_list or [self.cmd, "local", "generate-event", "s3", "put"] env = os.environ.copy()
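The hunks above switch every test from GlobalConfig(config_dir=...) to the new singleton access pattern. A minimal sketch of that pattern, assuming the name-mangled _Singleton__instance attribute seen throughout these tests is the intended reset hook:

import tempfile
from pathlib import Path
from samcli.cli.global_config import GlobalConfig

GlobalConfig._Singleton__instance = None    # drop any cached instance so the next call recreates it
gc = GlobalConfig()                         # first call builds the shared instance
gc.config_dir = Path(tempfile.mkdtemp())    # redirect reads/writes of metadata.json to a temp dir
assert gc is GlobalConfig()                 # later calls return the same object

Resetting the instance in setUp/tearDown keeps state from leaking between tests, which is why the suites above clear _Singleton__instance rather than constructing fresh objects.

diff --git a/tests/integration/telemetry/test_experimental_metric.py b/tests/integration/telemetry/test_experimental_metric.py new file mode 100644 index 0000000000..980e9208b4 ---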
/dev/null +++ b/tests/integration/telemetry/test_experimental_metric.py @@ -0,0 +1,153 @@ +import os +import platform +import time +from unittest.mock import ANY + +from .integ_base import IntegBase, TelemetryServer +from samcli import __version__ as SAM_CLI_VERSION + + +class TestExperimentalMetric(IntegBase): + """ + Validates the contract the Telemetry module needs to adhere to for experimental (beta) features + """ + + def test_must_send_experimental_metrics_if_experimental_command(self): + """ + A "commandRunExperimental" metric should be sent when an experimental command is run with its beta flag enabled via envvar + """ + # Enable telemetry via the configuration file + self.unset_config() + self.set_config(telemetry_enabled=True) + os.environ["SAM_CLI_BETA_FEATURES"] = "0" + os.environ["SAM_CLI_BETA_ACCELERATE"] = "1" + + with TelemetryServer() as server: + # Run an experimental command. Should publish the experimental metric + process = self.run_cmd(cmd_list=[self.cmd, "traces", "--trace-id", "random-trace"], optout_envvar_value="1") + stdout, stderr = process.communicate() + + self.assertEqual(process.returncode, 1, "Command should fail") + print(stdout) + print(stderr) + all_requests = server.get_all_requests() + self.assertEqual(1, len(all_requests), "Command run metric must be sent") + request = all_requests[0] + self.assertIn("Content-Type", request["headers"]) + self.assertEqual(request["headers"]["Content-Type"], "application/json") + + expected_data = { + "metrics": [ + { + "commandRunExperimental": { + "requestId": ANY, + "installationId": self.get_global_config().installation_id, + "sessionId": ANY, + "executionEnvironment": ANY, + "ci": ANY, + "pyversion": ANY, + "samcliVersion": SAM_CLI_VERSION, + "awsProfileProvided": ANY, + "debugFlagProvided": ANY, + "region": ANY, + "commandName": ANY, + "metricSpecificAttributes": {"experimentalAccelerate": True, "experimentalAll": False}, + "duration": ANY, + "exitReason": ANY, + "exitCode": ANY, + } + } + ] + } + self.assertEqual(request["data"], expected_data) + + def test_must_send_experimental_metrics_if_experimental_option(self): + """ + A "commandRunExperimental" metric should be sent when an experimental option is used and all beta features are enabled via envvar + """ + # Enable telemetry via the configuration file + self.unset_config() + self.set_config(telemetry_enabled=True) + os.environ["SAM_CLI_BETA_FEATURES"] = "1" + + with TelemetryServer() as server: + # Run a command with an experimental option. Should publish the experimental metric + process = self.run_cmd(cmd_list=[self.cmd, "logs", "--include-traces"], optout_envvar_value="1") + process.communicate() + + self.assertEqual(process.returncode, 1, "Command should fail") + all_requests = server.get_all_requests() + self.assertEqual(1, len(all_requests), "Command run metric must be sent") + request = all_requests[0] + self.assertIn("Content-Type", request["headers"]) + self.assertEqual(request["headers"]["Content-Type"], "application/json") + + expected_data = { + "metrics": [ + { + "commandRunExperimental": { + "requestId": ANY, + "installationId": self.get_global_config().installation_id, + "sessionId": ANY, + "executionEnvironment": ANY, + "ci": ANY, + "pyversion": ANY, + "samcliVersion": SAM_CLI_VERSION, + "awsProfileProvided": ANY, + "debugFlagProvided": ANY, + "region": ANY, + "commandName": ANY, + "metricSpecificAttributes": {"experimentalAccelerate": True, "experimentalAll": True}, + "duration": ANY, + "exitReason": ANY, + "exitCode": ANY, + } + } + ] + } + self.assertEqual(request["data"], expected_data)
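Taken together with the non-experimental case below, these expectations pin down how the emitter presumably picks the metric name and attributes. A hypothetical sketch of that decision (the metric and attribute names come from the expected payloads above; the helper itself is invented for illustration and is not the real emitter, which lives in samcli/lib/telemetry/metric.py):

def choose_command_run_metric(accelerate_enabled: bool, all_beta_enabled: bool) -> dict:
    # Any active beta flag switches the metric name and adds metricSpecificAttributes.
    if accelerate_enabled or all_beta_enabled:
        return {
            "commandRunExperimental": {
                "metricSpecificAttributes": {
                    "experimentalAccelerate": accelerate_enabled or all_beta_enabled,
                    "experimentalAll": all_beta_enabled,
                }
            }
        }
    return {"commandRun": {}}  # plain metric: no metricSpecificAttributes key at all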
+ def test_must_send_not_experimental_metrics_if_not_experimental(self): + """ + The plain "commandRun" metric should be sent when no beta feature is enabled + """ + # Enable telemetry via the configuration file + self.unset_config() + self.set_config(telemetry_enabled=True) + os.environ["SAM_CLI_BETA_FEATURES"] = "0" + + with TelemetryServer() as server: + # Run a non-experimental command. Should publish the plain metric + process = self.run_cmd(cmd_list=[self.cmd, "logs"], optout_envvar_value="1") + process.communicate() + + self.assertEqual(process.returncode, 1, "Command should fail") + all_requests = server.get_all_requests() + self.assertEqual(1, len(all_requests), "Command run metric must be sent") + request = all_requests[0] + self.assertIn("Content-Type", request["headers"]) + self.assertEqual(request["headers"]["Content-Type"], "application/json") + + expected_data = { + "metrics": [ + { + "commandRun": { + "requestId": ANY, + "installationId": self.get_global_config().installation_id, + "sessionId": ANY, + "executionEnvironment": ANY, + "ci": ANY, + "pyversion": ANY, + "samcliVersion": SAM_CLI_VERSION, + "awsProfileProvided": ANY, + "debugFlagProvided": ANY, + "region": ANY, + "commandName": ANY, + "duration": ANY, + "exitReason": ANY, + "exitCode": ANY, + } + } + ] + } + self.assertEqual(request["data"], expected_data) diff --git a/tests/unit/cli/test_global_config.py b/tests/unit/cli/test_global_config.py index 5432488303..76bba95cd3 100644 --- a/tests/unit/cli/test_global_config.py +++ b/tests/unit/cli/test_global_config.py @@ -1,124 +1,280 @@ -from unittest.mock import mock_open, patch, Mock +import os +from unittest.mock import ANY, MagicMock, patch from unittest import TestCase -from parameterized import parameterized -from samcli.cli.global_config import GlobalConfig +from samcli.cli.global_config import ConfigEntry, DefaultEntry, GlobalConfig from pathlib import Path class TestGlobalConfig(TestCase): - def test_config_write_error(self): - m = mock_open() - m.side_effect = IOError("fail") - gc = GlobalConfig() - with patch("samcli.cli.global_config.open", m): - installation_id = gc.installation_id - self.assertIsNone(installation_id) - - def test_unable_to_create_dir(self): - m = mock_open() - m.side_effect = OSError("Permission DENIED") - gc = GlobalConfig() - with patch("samcli.cli.global_config.Path.mkdir", m): - installation_id = gc.installation_id - self.assertIsNone(installation_id) - telemetry_enabled = gc.telemetry_enabled - self.assertFalse(telemetry_enabled) - - def test_setter_cannot_open_path(self): - m = mock_open() - m.side_effect = IOError("fail") - gc = GlobalConfig() - with patch("samcli.cli.global_config.open", m): - with self.assertRaises(IOError): - gc.telemetry_enabled = True - - @patch("samcli.cli.global_config.click") - def test_config_dir_default(self, mock_click): - mock_click.get_app_dir.return_value = "mock/folders" - gc = GlobalConfig() - self.assertEqual(Path("mock/folders"), gc.config_dir) - mock_click.get_app_dir.assert_called_once_with("AWS SAM", force_posix=True) - - def test_explicit_installation_id(self): - gc = GlobalConfig(installation_id="foobar") - self.assertEqual("foobar", gc.installation_id) - - @patch("samcli.cli.global_config.uuid") - @patch("samcli.cli.global_config.Path") - @patch("samcli.cli.global_config.click") - def test_setting_installation_id(self, mock_click, mock_path, mock_uuid): - gc = GlobalConfig() - mock_uuid.uuid4.return_value = "SevenLayerDipMock" - path_mock = Mock() - joinpath_mock = Mock() - joinpath_mock.exists.return_value = False - path_mock.joinpath.return_value = joinpath_mock - mock_path.return_value = path_mock - mock_click.get_app_dir.return_value = "mock/folders" -
mock_io = mock_open(Mock()) - with patch("samcli.cli.global_config.open", mock_io): - self.assertEqual("SevenLayerDipMock", gc.installation_id) - - def test_explicit_telemetry_enabled(self): - gc = GlobalConfig(telemetry_enabled=True) - self.assertTrue(gc.telemetry_enabled) - - @patch("samcli.cli.global_config.Path") - @patch("samcli.cli.global_config.click") - @patch("samcli.cli.global_config.os") - def test_missing_telemetry_flag(self, mock_os, mock_click, mock_path): - gc = GlobalConfig() - mock_click.get_app_dir.return_value = "mock/folders" - path_mock = Mock() - joinpath_mock = Mock() - joinpath_mock.exists.return_value = False - path_mock.joinpath.return_value = joinpath_mock - mock_path.return_value = path_mock - mock_os.environ = {} # env var is not set - self.assertIsNone(gc.telemetry_enabled) - - @patch("samcli.cli.global_config.Path") - @patch("samcli.cli.global_config.click") - @patch("samcli.cli.global_config.os") - def test_error_reading_telemetry_flag(self, mock_os, mock_click, mock_path): - gc = GlobalConfig() - mock_click.get_app_dir.return_value = "mock/folders" - path_mock = Mock() - joinpath_mock = Mock() - joinpath_mock.exists.return_value = True - path_mock.joinpath.return_value = joinpath_mock - mock_path.return_value = path_mock - mock_os.environ = {} # env var is not set - - m = mock_open() - m.side_effect = IOError("fail") - with patch("samcli.cli.global_config.open", m): - self.assertFalse(gc.telemetry_enabled) - - @parameterized.expand( - [ - # Only values of '1' and 1 will enable Telemetry. Everything will disable. - (1, True), - ("1", True), - (0, False), - ("0", False), - # words true, True, False, False etc will disable telemetry - ("true", False), - ("True", False), - ("False", False), - ] - ) - @patch("samcli.cli.global_config.os") - @patch("samcli.cli.global_config.click") - def test_set_telemetry_through_env_variable(self, env_value, expected_result, mock_click, mock_os): - gc = GlobalConfig() - - mock_os.environ = {"SAM_CLI_TELEMETRY": env_value} - mock_os.getenv.return_value = env_value - - self.assertEqual(gc.telemetry_enabled, expected_result) - - mock_os.getenv.assert_called_once_with("SAM_CLI_TELEMETRY") - - # When environment variable is set, we shouldn't be reading the real config file at all. 
- mock_click.get_app_dir.assert_not_called() + def setUp(self): + # Force singleton to recreate after each test + GlobalConfig._Singleton__instance = None + + path_write_patch = patch("samcli.cli.global_config.Path.write_text") + self.path_write_mock = path_write_patch.start() + self.addCleanup(path_write_patch.stop) + + path_read_patch = patch("samcli.cli.global_config.Path.read_text") + self.path_read_mock = path_read_patch.start() + self.addCleanup(path_read_patch.stop) + + path_exists_patch = patch("samcli.cli.global_config.Path.exists") + self.path_exists_mock = path_exists_patch.start() + self.path_exists_mock.return_value = True + self.addCleanup(path_exists_patch.stop) + + path_mkdir_patch = patch("samcli.cli.global_config.Path.mkdir") + self.path_mkdir_mock = path_mkdir_patch.start() + self.addCleanup(path_mkdir_patch.stop) + + json_patch = patch("samcli.cli.global_config.json") + self.json_mock = json_patch.start() + self.json_mock.loads.return_value = {} + self.json_mock.dumps.return_value = "{}" + self.addCleanup(json_patch.stop) + + click_patch = patch("samcli.cli.global_config.click") + self.click_mock = click_patch.start() + self.click_mock.get_app_dir.return_value = "app_dir" + self.addCleanup(click_patch.stop) + + threading_patch = patch("samcli.cli.global_config.threading") + self.threading_mock = threading_patch.start() + self.addCleanup(threading_patch.stop) + + self.patch_environ({}) + + def patch_environ(self, values): + environ_patch = patch.dict(os.environ, values, clear=True) + environ_patch.start() + self.addCleanup(environ_patch.stop) + + def tearDown(self): + # Force singleton to recreate after each test + GlobalConfig._Singleton__instance = None + + def test_singleton(self): + gc1 = GlobalConfig() + gc2 = GlobalConfig() + self.assertTrue(gc1 is gc2) + + def test_default_config_dir(self): + self.assertEqual(GlobalConfig().config_dir, Path("app_dir")) + + def test_inject_config_dir(self): + self.patch_environ({"__SAM_CLI_APP_DIR": "inject_dir"}) + self.assertEqual(GlobalConfig().config_dir, Path("inject_dir")) + + @patch("samcli.cli.global_config.Path.is_dir") + def test_set_config_dir(self, is_dir_mock): + is_dir_mock.return_value = True + GlobalConfig().config_dir = Path("new_app_dir") + self.assertEqual(GlobalConfig().config_dir, Path("new_app_dir")) + self.assertIsNone(GlobalConfig()._config_data) + + @patch("samcli.cli.global_config.Path.is_dir") + def test_set_config_dir_not_dir(self, is_dir_mock): + is_dir_mock.return_value = False + with self.assertRaises(ValueError): + GlobalConfig().config_dir = Path("new_app_dir") + self.assertEqual(GlobalConfig().config_dir, Path("app_dir")) + + def test_default_config_filename(self): + self.assertEqual(GlobalConfig().config_filename, "metadata.json") + + def test_set_config_filename(self): + GlobalConfig().config_filename = "new_metadata.json" + self.assertEqual(GlobalConfig().config_filename, "new_metadata.json") + self.assertIsNone(GlobalConfig()._config_data) + + def test_default_config_path(self): + self.assertEqual(GlobalConfig().config_path, Path("app_dir", "metadata.json")) + + def test_get_value_locking(self): + GlobalConfig()._get_value = MagicMock() + GlobalConfig().get_value(MagicMock(), True, object, False, True) + GlobalConfig()._access_lock.__enter__.assert_called_once() + GlobalConfig()._get_value.assert_called_once() + + def test_set_value_locking(self): + GlobalConfig()._set_value = MagicMock() + GlobalConfig().set_value(MagicMock(), MagicMock(), True, True) + 
GlobalConfig()._access_lock.__enter__.assert_called_once() + GlobalConfig()._set_value.assert_called_once() + + def test_get_value_env_var_only(self): + self.patch_environ({"ENV_VAR": "env_var_value"}) + result = GlobalConfig().get_value( + ConfigEntry(None, "ENV_VAR"), default="default", value_type=str, is_flag=False, reload_config=False + ) + self.assertEqual(result, "env_var_value") + + def test_get_value_env_var_and_config_priority(self): + self.patch_environ({"ENV_VAR": "env_var_value"}) + result = GlobalConfig().get_value( + ConfigEntry("config_key", "ENV_VAR"), default="default", value_type=str, is_flag=False, reload_config=False + ) + self.assertEqual(result, "env_var_value") + + def test_get_value_config_only(self): + self.patch_environ({"ENV_VAR": "env_var_value"}) + self.json_mock.loads.return_value = {"config_key": "config_value"} + result = GlobalConfig().get_value( + ConfigEntry("config_key", None), default="default", value_type=str, is_flag=False, reload_config=False + ) + self.assertEqual(result, "config_value") + + def test_get_value_error_default(self): + self.patch_environ({"ENV_VAR": "env_var_value"}) + self.json_mock.loads.side_effect = ValueError() + result = GlobalConfig().get_value( + ConfigEntry("config_key", None), default="default", value_type=str, is_flag=False, reload_config=False + ) + self.assertEqual(result, "default") + + def test_get_value_incorrect_type_default(self): + self.patch_environ({"ENV_VAR": "env_var_value"}) + self.json_mock.loads.return_value = {"config_key": 1} + result = GlobalConfig().get_value( + ConfigEntry("config_key", None), default="default", value_type=str, is_flag=True, reload_config=False + ) + self.assertEqual(result, "default") + + def test_get_value_flag_env_var_True(self): + self.patch_environ({"ENV_VAR": "1"}) + self.json_mock.loads.return_value = {"config_key": False} + result = GlobalConfig().get_value( + ConfigEntry("config_key", "ENV_VAR"), default=False, value_type=bool, is_flag=True, reload_config=False + ) + self.assertTrue(result) + + def test_get_value_flag_env_var_False(self): + self.patch_environ({"ENV_VAR": "0"}) + self.json_mock.loads.return_value = {"config_key": True} + result = GlobalConfig().get_value( + ConfigEntry("config_key", "ENV_VAR"), default=True, value_type=bool, is_flag=True, reload_config=False + ) + self.assertFalse(result) + + def test_get_value_flag_config_True(self): + self.json_mock.loads.return_value = {"config_key": True} + result = GlobalConfig().get_value( + ConfigEntry("config_key", "ENV_VAR"), default=False, value_type=bool, is_flag=True, reload_config=False + ) + self.assertTrue(result) + + def test_set_value(self): + self.patch_environ({"ENV_VAR": "env_var_value"}) + GlobalConfig().set_value(ConfigEntry("config_key", "ENV_VAR"), "value", False, True) + self.assertEqual(os.environ["ENV_VAR"], "value") + self.assertEqual(GlobalConfig()._config_data["config_key"], "value") + self.json_mock.dumps.assert_called_once_with({"config_key": "value"}, indent=ANY) + self.path_write_mock.assert_called_once() + + def test_set_value_non_persistent(self): + self.patch_environ({"ENV_VAR": "env_var_value"}) + GlobalConfig().set_value(ConfigEntry("config_key", "ENV_VAR", False), "value", False, True) + self.assertEqual(os.environ["ENV_VAR"], "value") + self.assertEqual(GlobalConfig()._config_data["config_key"], "value") + self.json_mock.dumps.assert_called_once_with({}, indent=ANY) + self.path_write_mock.assert_called_once() + + def test_set_value_no_flush(self): + self.patch_environ({"ENV_VAR": 
"env_var_value"}) + GlobalConfig().set_value(ConfigEntry("config_key", "ENV_VAR"), "value", False, False) + self.assertEqual(os.environ["ENV_VAR"], "value") + self.assertEqual(GlobalConfig()._config_data["config_key"], "value") + self.json_mock.dumps.assert_not_called() + self.path_write_mock.assert_not_called() + + def test_set_value_flag_true(self): + self.patch_environ({"ENV_VAR": "env_var_value"}) + GlobalConfig().set_value(ConfigEntry("config_key", "ENV_VAR"), True, True, True) + self.assertEqual(os.environ["ENV_VAR"], "1") + self.assertEqual(GlobalConfig()._config_data["config_key"], True) + self.json_mock.dumps.assert_called_once() + self.path_write_mock.assert_called_once() + + def test_set_value_flag_false(self): + self.patch_environ({"ENV_VAR": "env_var_value"}) + GlobalConfig().set_value(ConfigEntry("config_key", "ENV_VAR"), False, True, True) + self.assertEqual(os.environ["ENV_VAR"], "0") + self.assertEqual(GlobalConfig()._config_data["config_key"], False) + self.json_mock.dumps.assert_called_once() + self.path_write_mock.assert_called_once() + + def test_load_config(self): + self.path_exists_mock.return_value = True + self.json_mock.loads.return_value = {"a": "b"} + self.assertIsNone(GlobalConfig()._config_data) + GlobalConfig()._load_config() + self.assertEqual(GlobalConfig()._config_data, {"a": "b"}) + + def test_load_config_file_does_not_exist(self): + self.path_exists_mock.return_value = False + self.json_mock.loads.return_value = {"a": "b"} + self.assertIsNone(GlobalConfig()._config_data) + GlobalConfig()._load_config() + self.assertEqual(GlobalConfig()._config_data, {}) + + def test_load_config_error(self): + self.path_exists_mock.return_value = True + self.json_mock.loads.return_value = {"a": "b"} + self.json_mock.loads.side_effect = ValueError() + self.assertIsNone(GlobalConfig()._config_data) + GlobalConfig()._load_config() + self.assertEqual(GlobalConfig()._config_data, {}) + + def test_write_config(self): + self.path_exists_mock.return_value = False + GlobalConfig()._persistent_fields = ["a"] + GlobalConfig()._config_data = {"a": 1} + GlobalConfig()._write_config() + self.json_mock.dumps.assert_called_once() + self.path_mkdir_mock.assert_called_once() + self.path_write_mock.assert_called_once() + + @patch("samcli.cli.global_config.uuid.uuid4") + def test_get_installation_id_saved(self, uuid_mock): + self.json_mock.loads.return_value = {DefaultEntry.INSTALLATION_ID.config_key: "saved_uuid"} + uuid_mock.return_value = "default_uuid" + result = GlobalConfig().installation_id + self.assertEqual(result, "saved_uuid") + + @patch("samcli.cli.global_config.uuid.uuid4") + def test_get_installation_id_default(self, uuid_mock): + self.json_mock.loads.return_value = {} + uuid_mock.return_value = "default_uuid" + result = GlobalConfig().installation_id + self.assertEqual(result, "default_uuid") + + def test_get_telemetry_enabled(self): + self.patch_environ({DefaultEntry.TELEMETRY.env_var_key: "1"}) + self.json_mock.loads.return_value = {DefaultEntry.TELEMETRY.config_key: True} + result = GlobalConfig().telemetry_enabled + self.assertEqual(result, True) + + def test_get_telemetry_disabled(self): + self.patch_environ({DefaultEntry.TELEMETRY.env_var_key: "0"}) + self.json_mock.loads.return_value = {DefaultEntry.TELEMETRY.config_key: True} + result = GlobalConfig().telemetry_enabled + self.assertEqual(result, False) + + def test_get_telemetry_default(self): + self.patch_environ({"__SAM_CLI_APP_DIR": "inject_dir"}) + result = GlobalConfig().telemetry_enabled + 
self.assertIsNone(result) + + def test_set_telemetry(self): + GlobalConfig().telemetry_enabled = True + self.assertEqual(os.environ[DefaultEntry.TELEMETRY.env_var_key], "1") + self.assertEqual(GlobalConfig()._config_data[DefaultEntry.TELEMETRY.config_key], True) + + def test_get_last_version_check(self): + self.json_mock.loads.return_value = {DefaultEntry.LAST_VERSION_CHECK.config_key: 123.4} + result = GlobalConfig().last_version_check + self.assertEqual(result, 123.4) + + def test_set_last_version_check(self): + GlobalConfig().last_version_check = 123.4 + self.assertEqual(GlobalConfig()._config_data[DefaultEntry.LAST_VERSION_CHECK.config_key], 123.4) diff --git a/tests/unit/cli/test_main.py b/tests/unit/cli/test_main.py index bfa770d88b..12cd40baee 100644 --- a/tests/unit/cli/test_main.py +++ b/tests/unit/cli/test_main.py @@ -12,7 +12,7 @@ def test_cli_base(self): :return: """ mock_cfg = Mock() - with patch("samcli.cli.main.global_cfg", mock_cfg): + with patch("samcli.cli.main.GlobalConfig", mock_cfg): runner = CliRunner() result = runner.invoke(cli, []) self.assertEqual(result.exit_code, 0) @@ -21,14 +21,14 @@ def test_cli_base(self): def test_cli_some_command(self): mock_cfg = Mock() - with patch("samcli.cli.main.global_cfg", mock_cfg): + with patch("samcli.cli.main.GlobalConfig", mock_cfg): runner = CliRunner() result = runner.invoke(cli, ["local", "generate-event", "s3"]) self.assertEqual(result.exit_code, 0) def test_cli_with_debug(self): mock_cfg = Mock() - with patch("samcli.cli.main.global_cfg", mock_cfg): + with patch("samcli.cli.main.GlobalConfig", mock_cfg): runner = CliRunner() result = runner.invoke(cli, ["local", "generate-event", "s3", "put", "--debug"]) self.assertEqual(result.exit_code, 0) diff --git a/tests/unit/commands/_utils/test_experimental.py b/tests/unit/commands/_utils/test_experimental.py new file mode 100644 index 0000000000..52f29a2ad4 --- /dev/null +++ b/tests/unit/commands/_utils/test_experimental.py @@ -0,0 +1,110 @@ +import os +from unittest.mock import MagicMock, call, patch +from unittest import TestCase + +from samcli.commands._utils.experimental import ( + _experimental_option_callback, + disable_all_experimental, + force_experimental_option, + get_all_experimental_statues, + get_all_experimental, + is_experimental_enabled, + prompt_experimental, + set_experimental, +) + + +class TestExperimental(TestCase): + def setUp(self): + + gc_patch = patch("samcli.commands._utils.experimental.GlobalConfig") + self.gc_mock = gc_patch.start() + self.addCleanup(gc_patch.stop) + + self.patch_environ({}) + + def patch_environ(self, values): + environ_patch = patch.dict(os.environ, values, clear=True) + environ_patch.start() + self.addCleanup(environ_patch.stop) + + def tearDown(self): + pass + + def test_is_experimental_enabled(self): + config_entry = MagicMock() + self.gc_mock.return_value.get_value.side_effect = [False, True] + result = is_experimental_enabled(config_entry) + self.assertTrue(result) + + def test_is_experimental_enabled_all(self): + config_entry = MagicMock() + self.gc_mock.return_value.get_value.side_effect = [True, False] + result = is_experimental_enabled(config_entry) + self.assertTrue(result) + + def test_is_experimental_enabled_false(self): + config_entry = MagicMock() + self.gc_mock.return_value.get_value.side_effect = [False, False] + result = is_experimental_enabled(config_entry) + self.assertFalse(result) + + def test_set_experimental(self): + config_entry = MagicMock() + set_experimental(config_entry, False) + 
self.gc_mock.return_value.set_value.assert_called_once_with(config_entry, False, is_flag=True, flush=False) + + def test_get_all_experimental(self): + self.assertEqual(len(get_all_experimental()), 2) + + def test_get_all_experimental_statues(self): + self.assertEqual(len(get_all_experimental_statues()), 2) + + @patch("samcli.commands._utils.experimental.set_experimental") + @patch("samcli.commands._utils.experimental.get_all_experimental") + def test_disable_all_experimental(self, get_all_experimental_mock, set_experimental_mock): + flags = [MagicMock(), MagicMock(), MagicMock()] + get_all_experimental_mock.return_value = flags + disable_all_experimental() + set_experimental_mock.assert_has_calls([call(flags[0], False), call(flags[1], False), call(flags[2], False)]) + + @patch("samcli.commands._utils.experimental.set_experimental") + @patch("samcli.commands._utils.experimental.disable_all_experimental") + def test_experimental_option_callback_true(self, disable_all_experimental_mock, set_experimental_mock): + _experimental_option_callback(MagicMock(), MagicMock(), True) + set_experimental_mock.assert_called_once() + disable_all_experimental_mock.assert_not_called() + + @patch("samcli.commands._utils.experimental.set_experimental") + @patch("samcli.commands._utils.experimental.disable_all_experimental") + def test_experimental_option_callback_false(self, disable_all_experimental_mock, set_experimental_mock): + _experimental_option_callback(MagicMock(), MagicMock(), False) + set_experimental_mock.assert_not_called() + disable_all_experimental_mock.assert_called_once() + + @patch("samcli.commands._utils.experimental.Context") + @patch("samcli.commands._utils.experimental.prompt_experimental") + def test_force_experimental_option_true(self, prompt_experimental_mock, context_mock): + config_entry = MagicMock() + prompt = "abc" + prompt_experimental_mock.return_value = True + + @force_experimental_option("param", config_entry, prompt) + def func(param=None): + self.assertEqual(param, 1) + + func(param=1) + prompt_experimental_mock.assert_called_once_with(config_entry=config_entry, prompt=prompt) + + @patch("samcli.commands._utils.experimental.set_experimental") + @patch("samcli.commands._utils.experimental.click.confirm") + @patch("samcli.commands._utils.experimental.is_experimental_enabled") + def test_prompt_experimental(self, enabled_mock, confirm_mock, set_experimental_mock): + config_entry = MagicMock() + prompt = "abc" + enabled_mock.return_value = False + confirm_mock.return_value = True + prompt_experimental(config_entry, prompt) + set_experimental_mock.assert_called_once_with(config_entry=config_entry, enabled=True) + enabled_mock.assert_called_once_with(config_entry) + confirm_mock.assert_called_once_with(prompt, default=False) diff --git a/tests/unit/commands/_utils/test_options.py b/tests/unit/commands/_utils/test_options.py index ea82e5cdbf..02240c403d 100644 --- a/tests/unit/commands/_utils/test_options.py +++ b/tests/unit/commands/_utils/test_options.py @@ -17,6 +17,7 @@ _TEMPLATE_OPTION_DEFAULT_VALUE, guided_deploy_stack_name, artifact_callback, + parameterized_option, resolve_s3_callback, image_repositories_callback, _space_separated_list_func_type, @@ -463,3 +464,33 @@ class TestSpaceSeparatedListInvalidDataTypes: def test_raise_value_error(self, test_input): with pytest.raises(ValueError): _space_separated_list_func_type(test_input) + + +class TestParameterizedOption(TestCase): + @parameterized_option + def option_dec_with_value(f, value=2): + def wrapper(): + return f(value) 
+ + return wrapper + + @parameterized_option + def option_dec_without_value(f, value=2): + def wrapper(): + return f(value) + + return wrapper + + @option_dec_with_value(5) + def some_function_with_value(value): + return value + 2 + + @option_dec_without_value + def some_function_without_value(value): + return value + 2 + + def test_option_dec_with_value(self): + self.assertEqual(TestParameterizedOption.some_function_with_value(), 7) + + def test_option_dec_without_value(self): + self.assertEqual(TestParameterizedOption.some_function_without_value(), 4) diff --git a/tests/unit/commands/_utils/test_template.py b/tests/unit/commands/_utils/test_template.py index 1de707ec38..c75db92c67 100644 --- a/tests/unit/commands/_utils/test_template.py +++ b/tests/unit/commands/_utils/test_template.py @@ -7,7 +7,7 @@ from botocore.utils import set_value_from_jmespath from parameterized import parameterized, param -from samcli.commands._utils.resources import AWS_SERVERLESS_FUNCTION, AWS_SERVERLESS_API +from samcli.lib.utils.resources import AWS_SERVERLESS_FUNCTION, AWS_SERVERLESS_API from samcli.commands._utils.template import ( get_template_data, METADATA_WITH_LOCAL_PATHS, diff --git a/tests/unit/commands/buildcmd/test_build_context.py b/tests/unit/commands/buildcmd/test_build_context.py index 3ab805a7ee..c30e17cb20 100644 --- a/tests/unit/commands/buildcmd/test_build_context.py +++ b/tests/unit/commands/buildcmd/test_build_context.py @@ -1,12 +1,28 @@ import os +from samcli.lib.build.app_builder import ApplicationBuilder, ApplicationBuildResult from unittest import TestCase -from unittest.mock import patch, Mock, ANY +from unittest.mock import patch, Mock, ANY, call from parameterized import parameterized +from samcli.lib.build.build_graph import DEFAULT_DEPENDENCIES_DIR +from samcli.lib.utils.osutils import BUILD_DIR_PERMISSIONS from samcli.local.lambdafn.exceptions import ResourceNotFound from samcli.commands.build.build_context import BuildContext from samcli.commands.build.exceptions import InvalidBuildDirException, MissingBuildMethodException +from samcli.commands.exceptions import UserException +from samcli.lib.build.app_builder import ( + BuildError, + UnsupportedBuilderLibraryVersionError, + BuildInsideContainerError, + ContainerBuildNotSupported, +) +from samcli.lib.build.workflow_config import UnsupportedRuntimeException +from samcli.local.lambdafn.exceptions import FunctionNotFound + + +class DeepWrap(Exception): + pass class TestBuildContext__enter__(TestCase): @@ -56,6 +72,7 @@ def test_must_setup_context( mode="buildmode", cached=False, cache_dir="cache_dir", + parallel=True, aws_region="any_aws_region", ) setup_build_dir_mock = Mock() @@ -134,6 +151,7 @@ def test_must_fail_with_illegal_identifier( mode="buildmode", cached=False, cache_dir="cache_dir", + parallel=True, ) setup_build_dir_mock = Mock() build_dir_result = setup_build_dir_mock.return_value = "my/new/build/dir" @@ -187,6 +205,7 @@ def test_must_return_only_layer_when_layer_is_build( mode="buildmode", cached=False, cache_dir="cache_dir", + parallel=True, ) setup_build_dir_mock = Mock() build_dir_result = setup_build_dir_mock.return_value = "my/new/build/dir" @@ -242,6 +261,7 @@ def test_must_return_buildable_dependent_layer_when_function_is_build( mode="buildmode", cached=False, cache_dir="cache_dir", + parallel=True, ) setup_build_dir_mock = Mock() build_dir_result = setup_build_dir_mock.return_value = "my/new/build/dir" @@ -297,6 +317,7 @@ def test_must_fail_when_layer_is_build_without_buildmethod( mode="buildmode", 
cached=False, cache_dir="cache_dir", + parallel=True, ) setup_build_dir_mock = Mock() build_dir_result = setup_build_dir_mock.return_value = "my/new/build/dir" @@ -365,6 +386,7 @@ def test_must_return_many_functions_to_build( mode="buildmode", cached=False, cache_dir="cache_dir", + parallel=True, ) setup_build_dir_mock = Mock() build_dir_result = setup_build_dir_mock.return_value = "my/new/build/dir" @@ -430,6 +452,7 @@ def test_must_print_remote_url_warning( mode="buildmode", cached=False, cache_dir="cache_dir", + parallel=True, ) context._setup_build_dir = Mock() @@ -566,6 +589,335 @@ def test_when_build_dir_is_cwd_raises_exception(self, pathlib_patch, os_patch, s pathlib_patch.Path.cwd.assert_called_once() + +class TestBuildContext_setup_cached_and_deps_dir(TestCase): + @parameterized.expand([(True,), (False,)]) + @patch("samcli.commands.build.build_context.pathlib.Path") + @patch("samcli.commands.build.build_context.SamLocalStackProvider") + @patch("samcli.commands.build.build_context.SamFunctionProvider") + @patch("samcli.commands.build.build_context.SamLayerProvider") + def test_cached_dir_and_deps_dir_creation( + self, cached, patched_layer, patched_function, patched_stack, patched_path + ): + patched_stack.get_stacks.return_value = ([], None) + build_context = BuildContext( + resource_identifier="function_identifier", + template_file="template_file", + base_dir="base_dir", + build_dir="build_dir", + cache_dir="cache_dir", + parallel=False, + mode="mode", + cached=cached, + ) + + with patch.object(build_context, "_setup_build_dir"): + build_context.set_up() + + call_assertion = lambda: patched_path.assert_has_calls( + [ + call("cache_dir"), + call().mkdir(exist_ok=True, mode=BUILD_DIR_PERMISSIONS, parents=True), + call(DEFAULT_DEPENDENCIES_DIR), + call().mkdir(exist_ok=True, mode=BUILD_DIR_PERMISSIONS, parents=True), + ], + any_order=True, + ) + + # If caching is enabled, validate that the calls above were made; + # otherwise they were never made and the assertion itself should raise + if cached: + call_assertion() + else: + with self.assertRaises(AssertionError): + call_assertion()
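The parameterized test above fixes the directory bootstrap contract for cached builds: BuildContext.set_up is expected to create both the build cache dir and the shared dependencies dir, and to create neither when caching is off. A minimal sketch of the behavior being asserted, under the assumption that set_up does little more than this (the constants are the ones imported at the top of the test file):

import pathlib
from samcli.lib.build.build_graph import DEFAULT_DEPENDENCIES_DIR
from samcli.lib.utils.osutils import BUILD_DIR_PERMISSIONS

def ensure_cache_dirs(cache_dir: str, cached: bool) -> None:
    # Sketch of what the test implies, not the actual BuildContext.set_up body:
    # both directories are created idempotently, and only when caching is enabled.
    if cached:
        pathlib.Path(cache_dir).mkdir(exist_ok=True, mode=BUILD_DIR_PERMISSIONS, parents=True)
        pathlib.Path(DEFAULT_DEPENDENCIES_DIR).mkdir(exist_ok=True, mode=BUILD_DIR_PERMISSIONS, parents=True)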
+ + +class TestBuildContext_run(TestCase): + @patch("samcli.commands.build.build_context.SamLocalStackProvider.get_stacks") + @patch("samcli.commands.build.build_context.SamFunctionProvider") + @patch("samcli.commands.build.build_context.SamLayerProvider") + @patch("samcli.commands.build.build_context.pathlib") + @patch("samcli.commands.build.build_context.ContainerManager") + @patch("samcli.commands.build.build_context.BuildContext._setup_build_dir") + @patch("samcli.commands.build.build_context.ApplicationBuilder") + @patch("samcli.commands.build.build_context.BuildContext.get_resources_to_build") + @patch("samcli.commands.build.build_context.move_template") + @patch("samcli.commands.build.build_context.os") + def test_run_build_context( + self, + os_mock, + move_template_mock, + resources_mock, + ApplicationBuilderMock, + build_dir_mock, + ContainerManagerMock, + pathlib_mock, + SamLayerProviderMock, + SamFunctionProviderMock, + get_buildable_stacks_mock, + ): + + root_stack = Mock() + root_stack.is_root_stack = True + auto_dependency_layer = False + root_stack.get_output_template_path = Mock(return_value="./build_dir/template.yaml") + child_stack = Mock() + child_stack.get_output_template_path = Mock(return_value="./build_dir/abcd/template.yaml") + stack_output_template_path_by_stack_path = { + root_stack.stack_path: "./build_dir/template.yaml", + child_stack.stack_path: "./build_dir/abcd/template.yaml", + } + resources_mock.return_value = Mock() + + builder_mock = ApplicationBuilderMock.return_value = Mock() + artifacts = "artifacts" + builder_mock.build.return_value = ApplicationBuildResult(Mock(), artifacts) + modified_template_root = "modified template 1" + modified_template_child = "modified template 2" + builder_mock.update_template.side_effect = [modified_template_root, modified_template_child] + + get_buildable_stacks_mock.return_value = ([root_stack, child_stack], []) + layer1 = DummyLayer("layer1", "python3.8") + layer_provider_mock = Mock() + layer_provider_mock.get.return_value = layer1 + layerprovider = SamLayerProviderMock.return_value = layer_provider_mock + func1 = DummyFunction("func1", [layer1]) + func_provider_mock = Mock() + func_provider_mock.get.return_value = func1 + funcprovider = SamFunctionProviderMock.return_value = func_provider_mock + base_dir = pathlib_mock.Path.return_value.resolve.return_value.parent = "basedir" + container_mgr_mock = ContainerManagerMock.return_value = Mock() + build_dir_mock.return_value = "build_dir" + + with BuildContext( + resource_identifier="function_identifier", + template_file="template_file", + base_dir="base_dir", + build_dir="build_dir", + cache_dir="cache_dir", + cached=False, + clean="clean", + use_container=False, + parallel="parallel", + parameter_overrides="parameter_overrides", + manifest_path="manifest_path", + docker_network="docker_network", + skip_pull_image="skip_pull_image", + mode="mode", + container_env_var={}, + container_env_var_file=None, + build_images={}, + create_auto_dependency_layer=auto_dependency_layer, + ) as build_context: + build_context.run() + + ApplicationBuilderMock.assert_called_once_with( + ANY, + build_context.build_dir, + build_context.base_dir, + build_context.cache_dir, + build_context.cached, + build_context.is_building_specific_resource, + manifest_path_override=build_context.manifest_path_override, + container_manager=build_context.container_manager, + mode=build_context.mode, + parallel=build_context._parallel, + container_env_var=build_context._container_env_var, + container_env_var_file=build_context._container_env_var_file, + build_images=build_context._build_images, + combine_dependencies=not auto_dependency_layer, + ) + builder_mock.build.assert_called_once() + builder_mock.update_template.assert_has_calls( + [ + call( + root_stack, + artifacts, + stack_output_template_path_by_stack_path, + ), + call( + child_stack, + artifacts, + stack_output_template_path_by_stack_path, + ), + ] + ) + move_template_mock.assert_has_calls( + [ + call( + root_stack.location, + stack_output_template_path_by_stack_path[root_stack.stack_path], + modified_template_root, + ), + call( + child_stack.location, + stack_output_template_path_by_stack_path[child_stack.stack_path], + modified_template_child, + ), + ] + ) + + @parameterized.expand( + [ + (UnsupportedRuntimeException(), "UnsupportedRuntimeException"), + (BuildInsideContainerError(), "BuildInsideContainerError"), + (BuildError(wrapped_from=DeepWrap().__class__.__name__, msg="Test"), "DeepWrap"), + (ContainerBuildNotSupported(), "ContainerBuildNotSupported"), + ( + UnsupportedBuilderLibraryVersionError(container_name="name", error_msg="msg"), + "UnsupportedBuilderLibraryVersionError", + ), + ] + ) + @patch("samcli.commands.build.build_context.SamLocalStackProvider.get_stacks") + @patch("samcli.commands.build.build_context.SamFunctionProvider") + @patch("samcli.commands.build.build_context.SamLayerProvider") + 
@patch("samcli.commands.build.build_context.pathlib") + @patch("samcli.commands.build.build_context.ContainerManager") + @patch("samcli.commands.build.build_context.BuildContext._setup_build_dir") + @patch("samcli.commands.build.build_context.ApplicationBuilder") + @patch("samcli.commands.build.build_context.BuildContext.get_resources_to_build") + @patch("samcli.commands.build.build_context.move_template") + @patch("samcli.commands.build.build_context.os") + def test_must_catch_known_exceptions( + self, + exception, + wrapped_exception, + os_mock, + move_template_mock, + resources_mock, + ApplicationBuilderMock, + build_dir_mock, + ContainerManagerMock, + pathlib_mock, + SamLayerProviderMock, + SamFunctionProviderMock, + get_buildable_stacks_mock, + ): + + stack = Mock() + resources_mock.return_value = Mock() + + builder_mock = ApplicationBuilderMock.return_value = Mock() + artifacts = builder_mock.build.return_value = "artifacts" + modified_template_root = "modified template 1" + modified_template_child = "modified template 2" + builder_mock.update_template.side_effect = [modified_template_root, modified_template_child] + + get_buildable_stacks_mock.return_value = ([stack], []) + layer1 = DummyLayer("layer1", "python3.8") + layer_provider_mock = Mock() + layer_provider_mock.get.return_value = layer1 + layerprovider = SamLayerProviderMock.return_value = layer_provider_mock + func1 = DummyFunction("func1", [layer1]) + func_provider_mock = Mock() + func_provider_mock.get.return_value = func1 + funcprovider = SamFunctionProviderMock.return_value = func_provider_mock + base_dir = pathlib_mock.Path.return_value.resolve.return_value.parent = "basedir" + container_mgr_mock = ContainerManagerMock.return_value = Mock() + build_dir_mock.return_value = "build_dir" + + builder_mock.build.side_effect = exception + + with self.assertRaises(UserException) as ctx: + with BuildContext( + resource_identifier="function_identifier", + template_file="template_file", + base_dir="base_dir", + build_dir="build_dir", + cache_dir="cache_dir", + cached=False, + clean="clean", + use_container=False, + parallel="parallel", + parameter_overrides="parameter_overrides", + manifest_path="manifest_path", + docker_network="docker_network", + skip_pull_image="skip_pull_image", + mode="mode", + container_env_var={}, + container_env_var_file=None, + build_images={}, + ) as build_context: + build_context.run() + + self.assertEqual(str(ctx.exception), str(exception)) + self.assertEqual(wrapped_exception, ctx.exception.wrapped_from) + + @patch("samcli.commands.build.build_context.SamLocalStackProvider.get_stacks") + @patch("samcli.commands.build.build_context.SamFunctionProvider") + @patch("samcli.commands.build.build_context.SamLayerProvider") + @patch("samcli.commands.build.build_context.pathlib") + @patch("samcli.commands.build.build_context.ContainerManager") + @patch("samcli.commands.build.build_context.BuildContext._setup_build_dir") + @patch("samcli.commands.build.build_context.ApplicationBuilder") + @patch("samcli.commands.build.build_context.BuildContext.get_resources_to_build") + @patch("samcli.commands.build.build_context.move_template") + @patch("samcli.commands.build.build_context.os") + def test_must_catch_function_not_found_exception( + self, + os_mock, + move_template_mock, + resources_mock, + ApplicationBuilderMock, + build_dir_mock, + ContainerManagerMock, + pathlib_mock, + SamLayerProviderMock, + SamFunctionProviderMock, + get_buildable_stacks_mock, + ): + stack = Mock() + resources_mock.return_value = Mock() + 
+ builder_mock = ApplicationBuilderMock.return_value = Mock() + artifacts = builder_mock.build.return_value = "artifacts" + modified_template_root = "modified template 1" + modified_template_child = "modified template 2" + builder_mock.update_template.side_effect = [modified_template_root, modified_template_child] + + get_buildable_stacks_mock.return_value = ([stack], []) + layer1 = DummyLayer("layer1", "python3.8") + layer_provider_mock = Mock() + layer_provider_mock.get.return_value = layer1 + layerprovider = SamLayerProviderMock.return_value = layer_provider_mock + func1 = DummyFunction("func1", [layer1]) + func_provider_mock = Mock() + func_provider_mock.get.return_value = func1 + funcprovider = SamFunctionProviderMock.return_value = func_provider_mock + base_dir = pathlib_mock.Path.return_value.resolve.return_value.parent = "basedir" + container_mgr_mock = ContainerManagerMock.return_value = Mock() + build_dir_mock.return_value = "build_dir" + + ApplicationBuilderMock.side_effect = FunctionNotFound("Function Not Found") + + with self.assertRaises(UserException) as ctx: + with BuildContext( + resource_identifier="function_identifier", + template_file="template_file", + base_dir="base_dir", + build_dir="build_dir", + cache_dir="cache_dir", + cached=False, + clean="clean", + use_container=False, + parallel="parallel", + parameter_overrides="parameter_overrides", + manifest_path="manifest_path", + docker_network="docker_network", + skip_pull_image="skip_pull_image", + mode="mode", + container_env_var={}, + container_env_var_file=None, + build_images={}, + ) as build_context: + build_context.run() + + self.assertEqual(str(ctx.exception), "Function Not Found") + + class DummyLayer: def __init__(self, name, build_method, codeuri="layer_src"): self.name = name diff --git a/tests/unit/commands/buildcmd/test_command.py b/tests/unit/commands/buildcmd/test_command.py index 3d6f296d0a..3cc894d03c 100644 --- a/tests/unit/commands/buildcmd/test_command.py +++ b/tests/unit/commands/buildcmd/test_command.py @@ -2,53 +2,20 @@ import click from unittest import TestCase -from unittest.mock import Mock, patch, call +from unittest.mock import Mock, patch from parameterized import parameterized from samcli.commands.build.command import do_cli, _get_mode_value_from_envvar, _process_env_var, _process_image_options -from samcli.commands.exceptions import UserException -from samcli.lib.build.app_builder import ( - BuildError, - UnsupportedBuilderLibraryVersionError, - BuildInsideContainerError, - ContainerBuildNotSupported, -) -from samcli.lib.build.workflow_config import UnsupportedRuntimeException -from samcli.local.lambdafn.exceptions import FunctionNotFound - - -class DeepWrap(Exception): - pass class TestDoCli(TestCase): + @patch("samcli.commands.build.command.click") @patch("samcli.commands.build.build_context.BuildContext") - @patch("samcli.lib.build.app_builder.ApplicationBuilder") - @patch("samcli.commands._utils.template.move_template") @patch("samcli.commands.build.command.os") - def test_must_succeed_build(self, os_mock, move_template_mock, ApplicationBuilderMock, BuildContextMock): + def test_must_succeed_build(self, os_mock, BuildContextMock, mock_build_click): ctx_mock = Mock() - - # create stack mocks - root_stack = Mock() - root_stack.is_root_stack = True - root_stack.get_output_template_path = Mock(return_value="./build_dir/template.yaml") - child_stack = Mock() - child_stack.get_output_template_path = Mock(return_value="./build_dir/abcd/template.yaml") - ctx_mock.stacks = [root_stack, 
child_stack] - stack_output_template_path_by_stack_path = { - root_stack.stack_path: "./build_dir/template.yaml", - child_stack.stack_path: "./build_dir/abcd/template.yaml", - } - - BuildContextMock.return_value.__enter__ = Mock() BuildContextMock.return_value.__enter__.return_value = ctx_mock - builder_mock = ApplicationBuilderMock.return_value = Mock() - artifacts = builder_mock.build.return_value = "artifacts" - modified_template_root = "modified template 1" - modified_template_child = "modified template 2" - builder_mock.update_template.side_effect = [modified_template_root, modified_template_child] do_cli( ctx_mock, @@ -63,7 +30,7 @@ def test_must_succeed_build(self, os_mock, move_template_mock, ApplicationBuilde "parallel", "manifest_path", "docker_network", - "skip_pull", + "skip_pull_image", "parameter_overrides", "mode", (""), @@ -71,132 +38,28 @@ def test_must_succeed_build(self, os_mock, move_template_mock, ApplicationBuilde (), ) - ApplicationBuilderMock.assert_called_once_with( - ctx_mock.resources_to_build, - ctx_mock.build_dir, - ctx_mock.base_dir, - ctx_mock.cache_dir, - ctx_mock.cached, - ctx_mock.is_building_specific_resource, - manifest_path_override=ctx_mock.manifest_path_override, - container_manager=ctx_mock.container_manager, - mode=ctx_mock.mode, + BuildContextMock.assert_called_with( + "function_identifier", + "template", + "base_dir", + "build_dir", + "cache_dir", + "cached", + clean="clean", + use_container="use_container", parallel="parallel", + parameter_overrides="parameter_overrides", + manifest_path="manifest_path", + docker_network="docker_network", + skip_pull_image="skip_pull_image", + mode="mode", container_env_var={}, container_env_var_file="container_env_var_file", build_images={}, + aws_region=ctx_mock.region, ) - builder_mock.build.assert_called_once() - builder_mock.update_template.assert_has_calls( - [ - call( - root_stack, - artifacts, - stack_output_template_path_by_stack_path, - ) - ], - [ - call( - child_stack, - artifacts, - stack_output_template_path_by_stack_path, - ) - ], - ) - move_template_mock.assert_has_calls( - [ - call( - root_stack.location, - stack_output_template_path_by_stack_path[root_stack.stack_path], - modified_template_root, - ), - call( - child_stack.location, - stack_output_template_path_by_stack_path[child_stack.stack_path], - modified_template_child, - ), - ] - ) - - @parameterized.expand( - [ - (UnsupportedRuntimeException(), "UnsupportedRuntimeException"), - (BuildInsideContainerError(), "BuildInsideContainerError"), - (BuildError(wrapped_from=DeepWrap().__class__.__name__, msg="Test"), "DeepWrap"), - (ContainerBuildNotSupported(), "ContainerBuildNotSupported"), - ( - UnsupportedBuilderLibraryVersionError(container_name="name", error_msg="msg"), - "UnsupportedBuilderLibraryVersionError", - ), - ] - ) - @patch("samcli.commands.build.build_context.BuildContext") - @patch("samcli.lib.build.app_builder.ApplicationBuilder") - def test_must_catch_known_exceptions(self, exception, wrapped_exception, ApplicationBuilderMock, BuildContextMock): - - ctx_mock = Mock() - BuildContextMock.return_value.__enter__ = Mock() - BuildContextMock.return_value.__enter__.return_value = ctx_mock - builder_mock = ApplicationBuilderMock.return_value = Mock() - - builder_mock.build.side_effect = exception - - with self.assertRaises(UserException) as ctx: - do_cli( - ctx_mock, - "function_identifier", - "template", - "base_dir", - "build_dir", - "cache_dir", - "clean", - "use_container", - "cached", - "parallel", - "manifest_path", - 
"docker_network", - "skip_pull", - "parameteroverrides", - "mode", - (""), - "container_env_var_file", - (), - ) - - self.assertEqual(str(ctx.exception), str(exception)) - self.assertEqual(wrapped_exception, ctx.exception.wrapped_from) - - @patch("samcli.commands.build.build_context.BuildContext") - @patch("samcli.lib.build.app_builder.ApplicationBuilder") - def test_must_catch_function_not_found_exception(self, ApplicationBuilderMock, BuildContextMock): - ctx_mock = Mock() - BuildContextMock.return_value.__enter__ = Mock() - BuildContextMock.return_value.__enter__.return_value = ctx_mock - ApplicationBuilderMock.side_effect = FunctionNotFound("Function Not Found") - - with self.assertRaises(UserException) as ctx: - do_cli( - ctx_mock, - "function_identifier", - "template", - "base_dir", - "build_dir", - "cache_dir", - "clean", - "use_container", - "cached", - "parallel", - "manifest_path", - "docker_network", - "skip_pull", - "parameteroverrides", - "mode", - (""), - "container_env_var_file", - (), - ) - - self.assertEqual(str(ctx.exception), "Function Not Found") + ctx_mock.run.assert_called_with() + self.assertEqual(ctx_mock.run.call_count, 1) class TestGetModeValueFromEnvvar(TestCase): diff --git a/tests/unit/commands/deploy/test_command.py b/tests/unit/commands/deploy/test_command.py index 6a9239e228..cf620a844a 100644 --- a/tests/unit/commands/deploy/test_command.py +++ b/tests/unit/commands/deploy/test_command.py @@ -47,6 +47,7 @@ def setUp(self): self.config_env = "mock-default-env" self.config_file = "mock-default-filename" self.signing_profiles = None + self.use_changeset = True self.resolve_image_repos = False self.disable_rollback = False MOCK_SAM_CONFIG.reset_mock() @@ -123,6 +124,7 @@ def test_all_args(self, mock_deploy_context, mock_deploy_click, mock_package_con profile=self.profile, confirm_changeset=self.confirm_changeset, signing_profiles=self.signing_profiles, + use_changeset=self.use_changeset, disable_rollback=self.disable_rollback, ) @@ -328,6 +330,7 @@ def test_all_args_guided( profile=self.profile, confirm_changeset=True, signing_profiles=self.signing_profiles, + use_changeset=self.use_changeset, disable_rollback=True, ) @@ -474,6 +477,7 @@ def test_all_args_guided_no_save_echo_param_to_config( profile=self.profile, confirm_changeset=True, signing_profiles=self.signing_profiles, + use_changeset=self.use_changeset, disable_rollback=True, ) @@ -624,6 +628,7 @@ def test_all_args_guided_no_params_save_config( profile=self.profile, confirm_changeset=True, signing_profiles=self.signing_profiles, + use_changeset=self.use_changeset, disable_rollback=True, ) @@ -759,6 +764,7 @@ def test_all_args_guided_no_params_no_save_config( profile=self.profile, confirm_changeset=True, signing_profiles=self.signing_profiles, + use_changeset=self.use_changeset, disable_rollback=self.disable_rollback, ) @@ -831,6 +837,7 @@ def test_all_args_resolve_s3( profile=self.profile, confirm_changeset=self.confirm_changeset, signing_profiles=self.signing_profiles, + use_changeset=self.use_changeset, disable_rollback=self.disable_rollback, ) @@ -941,6 +948,7 @@ def test_all_args_resolve_image_repos( profile=self.profile, confirm_changeset=self.confirm_changeset, signing_profiles=self.signing_profiles, + use_changeset=True, disable_rollback=self.disable_rollback, ) diff --git a/tests/unit/commands/deploy/test_deploy_context.py b/tests/unit/commands/deploy/test_deploy_context.py index e659c3fc57..a4bb034b1d 100644 --- a/tests/unit/commands/deploy/test_deploy_context.py +++ 
b/tests/unit/commands/deploy/test_deploy_context.py
@@ -31,6 +31,7 @@ def setUp(self):
             profile=None,
             confirm_changeset=False,
             signing_profiles=None,
+            use_changeset=True,
             disable_rollback=False,
         )
@@ -153,3 +154,62 @@ def test_template_valid_execute_changeset_with_parameters(
         patched_get_buildable_stacks.assert_called_once_with(
             ANY, parameter_overrides={"a": "b"}, global_parameter_overrides={"AWS::Region": "any-aws-region"}
         )
+
+    @patch("boto3.Session")
+    @patch("samcli.commands.deploy.deploy_context.auth_per_resource")
+    @patch("samcli.commands.deploy.deploy_context.SamLocalStackProvider.get_stacks")
+    @patch.object(Deployer, "sync", MagicMock())
+    def test_sync(self, patched_get_buildable_stacks, patched_auth_required, patched_boto):
+        sync_context = DeployContext(
+            template_file="template-file",
+            stack_name="stack-name",
+            s3_bucket="s3-bucket",
+            image_repository="image-repo",
+            image_repositories=None,
+            force_upload=True,
+            no_progressbar=False,
+            s3_prefix="s3-prefix",
+            kms_key_id="kms-key-id",
+            parameter_overrides={"a": "b"},
+            capabilities="CAPABILITY_IAM",
+            no_execute_changeset=False,
+            role_arn="role-arn",
+            notification_arns=[],
+            fail_on_empty_changeset=False,
+            tags={"a": "b"},
+            region=None,
+            profile=None,
+            confirm_changeset=False,
+            signing_profiles=None,
+            use_changeset=False,
+            disable_rollback=False,
+        )
+        patched_get_buildable_stacks.return_value = (Mock(), [])
+        patched_auth_required.return_value = [("HelloWorldFunction", False)]
+        with tempfile.NamedTemporaryFile(delete=False) as template_file:
+            template_file.write(b'{"Parameters": {"a":"b","c":"d"}}')
+            template_file.flush()
+            sync_context.template_file = template_file.name
+            sync_context.run()
+
+            self.assertEqual(sync_context.deployer.sync.call_count, 1)
+            self.assertEqual(
+                sync_context.deployer.sync.call_args[1]["stack_name"],
+                "stack-name",
+            )
+            self.assertEqual(
+                sync_context.deployer.sync.call_args[1]["capabilities"],
+                "CAPABILITY_IAM",
+            )
+            self.assertEqual(
+                sync_context.deployer.sync.call_args[1]["cfn_template"],
+                '{"Parameters": {"a":"b","c":"d"}}',
+            )
+            self.assertEqual(
+                sync_context.deployer.sync.call_args[1]["notification_arns"],
+                [],
+            )
+            self.assertEqual(
+                sync_context.deployer.sync.call_args[1]["role_arn"],
+                "role-arn",
+            )
diff --git a/tests/unit/commands/local/lib/test_provider.py b/tests/unit/commands/local/lib/test_provider.py
index 6888768aca..97567f00ba 100644
--- a/tests/unit/commands/local/lib/test_provider.py
+++ b/tests/unit/commands/local/lib/test_provider.py
@@ -1,12 +1,22 @@
 import os
 from unittest import TestCase
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock, patch
 
 from parameterized import parameterized
 
 from samcli.lib.utils.architecture import X86_64, ARM64
-from samcli.lib.providers.provider import LayerVersion, Stack, _get_build_dir, Function
+from samcli.lib.providers.provider import (
+    LayerVersion,
+    ResourceIdentifier,
+    Stack,
+    _get_build_dir,
+    get_all_resource_ids,
+    get_resource_by_id,
+    get_resource_ids_by_type,
+    get_unique_resource_ids,
+    Function,
+)
 from samcli.commands.local.cli_common.user_exceptions import (
     InvalidLayerVersionArn,
     UnsupportedIntrinsic,
@@ -172,3 +182,256 @@ def test_no_layer_build_architecture_returned(self):
             [ARM64],
         )
         self.assertEqual(layer_version.build_architecture, X86_64)
+
+
+class TestResourceIdentifier(TestCase):
+    @parameterized.expand(
+        [
+            ("Function1", "", "Function1"),
+            ("NestedStack1/Function1", "NestedStack1", "Function1"),
("NestedStack1/NestedNestedStack2/Function1", "NestedStack1/NestedNestedStack2", "Function1"), + ("", "", ""), + ] + ) + def test_parser(self, resource_identifier_string, stack_path, logical_id): + resource_identifier = ResourceIdentifier(resource_identifier_string) + self.assertEqual(resource_identifier.stack_path, stack_path) + self.assertEqual(resource_identifier.logical_id, logical_id) + + @parameterized.expand( + [ + ("Function1", "Function1", True), + ("NestedStack1/Function1", "NestedStack1/Function1", True), + ("NestedStack1/NestedNestedStack2/Function1", "NestedStack1/NestedNestedStack2/Function2", False), + ("NestedStack1/NestedNestedStack3/Function1", "NestedStack1/NestedNestedStack2/Function1", False), + ("", "", True), + ] + ) + def test_equal(self, resource_identifier_string_1, resource_identifier_string_2, equal): + resource_identifier_1 = ResourceIdentifier(resource_identifier_string_1) + resource_identifier_2 = ResourceIdentifier(resource_identifier_string_2) + self.assertEqual(resource_identifier_1 == resource_identifier_2, equal) + + @parameterized.expand( + [ + ("Function1"), + ("NestedStack1/Function1"), + ("NestedStack1/NestedNestedStack2/Function1"), + ] + ) + def test_hash(self, resource_identifier_string): + resource_identifier_1 = ResourceIdentifier(resource_identifier_string) + resource_identifier_2 = ResourceIdentifier(resource_identifier_string) + self.assertEqual(hash(resource_identifier_1), hash(resource_identifier_2)) + + @parameterized.expand( + [ + ("Function1"), + ("NestedStack1/Function1"), + ("NestedStack1/NestedNestedStack2/Function1"), + (""), + ] + ) + def test_str(self, resource_identifier_string): + resource_identifier = ResourceIdentifier(resource_identifier_string) + self.assertEqual(str(resource_identifier), resource_identifier_string) + + +class TestGetResourceByID(TestCase): + def setUp(self) -> None: + super().setUp() + self.root_stack = MagicMock() + self.root_stack.stack_path = "" + self.root_stack.resources = {"Function1": "Body1"} + + self.nested_stack = MagicMock() + self.nested_stack.stack_path = "NestedStack1" + self.nested_stack.resources = {"Function1": "Body2"} + + self.nested_nested_stack = MagicMock() + self.nested_nested_stack.stack_path = "NestedStack1/NestedNestedStack1" + self.nested_nested_stack.resources = {"Function2": "Body3"} + + def test_get_resource_by_id_explicit_root( + self, + ): + + resource_identifier = MagicMock() + resource_identifier.stack_path = "" + resource_identifier.logical_id = "Function1" + + result = get_resource_by_id( + [self.root_stack, self.nested_stack, self.nested_nested_stack], resource_identifier, True + ) + self.assertEqual(result, self.root_stack.resources["Function1"]) + + def test_get_resource_by_id_explicit_nested( + self, + ): + + resource_identifier = MagicMock() + resource_identifier.stack_path = "NestedStack1" + resource_identifier.logical_id = "Function1" + + result = get_resource_by_id( + [self.root_stack, self.nested_stack, self.nested_nested_stack], resource_identifier, True + ) + self.assertEqual(result, self.nested_stack.resources["Function1"]) + + def test_get_resource_by_id_explicit_nested_nested( + self, + ): + + resource_identifier = MagicMock() + resource_identifier.stack_path = "NestedStack1/NestedNestedStack1" + resource_identifier.logical_id = "Function2" + + result = get_resource_by_id( + [self.root_stack, self.nested_stack, self.nested_nested_stack], resource_identifier, True + ) + self.assertEqual(result, self.nested_nested_stack.resources["Function2"]) + + def 
test_get_resource_by_id_implicit_root( + self, + ): + + resource_identifier = MagicMock() + resource_identifier.stack_path = "" + resource_identifier.logical_id = "Function1" + + result = get_resource_by_id( + [self.root_stack, self.nested_stack, self.nested_nested_stack], resource_identifier, False + ) + self.assertEqual(result, self.root_stack.resources["Function1"]) + + def test_get_resource_by_id_implicit_nested( + self, + ): + + resource_identifier = MagicMock() + resource_identifier.stack_path = "" + resource_identifier.logical_id = "Function2" + + result = get_resource_by_id( + [self.root_stack, self.nested_stack, self.nested_nested_stack], resource_identifier, False + ) + self.assertEqual(result, self.nested_nested_stack.resources["Function2"]) + + def test_get_resource_by_id_implicit_with_stack_path( + self, + ): + + resource_identifier = MagicMock() + resource_identifier.stack_path = "NestedStack1" + resource_identifier.logical_id = "Function1" + + result = get_resource_by_id( + [self.root_stack, self.nested_stack, self.nested_nested_stack], resource_identifier, False + ) + self.assertEqual(result, self.nested_stack.resources["Function1"]) + + def test_get_resource_by_id_not_found( + self, + ): + + resource_identifier = MagicMock() + resource_identifier.logical_id = "Function3" + + result = get_resource_by_id( + [self.root_stack, self.nested_stack, self.nested_nested_stack], resource_identifier, False + ) + self.assertEqual(result, None) + + +class TestGetResourceIDsByType(TestCase): + def setUp(self) -> None: + super().setUp() + self.root_stack = MagicMock() + self.root_stack.stack_path = "" + self.root_stack.resources = {"Function1": {"Type": "TypeA"}} + + self.nested_stack = MagicMock() + self.nested_stack.stack_path = "NestedStack1" + self.nested_stack.resources = {"Function1": {"Type": "TypeA"}} + + self.nested_nested_stack = MagicMock() + self.nested_nested_stack.stack_path = "NestedStack1/NestedNestedStack1" + self.nested_nested_stack.resources = {"Function2": {"Type": "TypeB"}} + + def test_get_resource_ids_by_type_single_nested( + self, + ): + result = get_resource_ids_by_type([self.root_stack, self.nested_stack, self.nested_nested_stack], "TypeB") + self.assertEqual(result, [ResourceIdentifier("NestedStack1/NestedNestedStack1/Function2")]) + + def test_get_resource_ids_by_type_multiple_nested( + self, + ): + result = get_resource_ids_by_type([self.root_stack, self.nested_stack, self.nested_nested_stack], "TypeA") + self.assertEqual(result, [ResourceIdentifier("Function1"), ResourceIdentifier("NestedStack1/Function1")]) + + +class TestGetAllResourceIDs(TestCase): + def setUp(self) -> None: + super().setUp() + self.root_stack = MagicMock() + self.root_stack.stack_path = "" + self.root_stack.resources = {"Function1": {"Type": "TypeA"}} + + self.nested_stack = MagicMock() + self.nested_stack.stack_path = "NestedStack1" + self.nested_stack.resources = {"Function1": {"Type": "TypeA"}} + + self.nested_nested_stack = MagicMock() + self.nested_nested_stack.stack_path = "NestedStack1/NestedNestedStack1" + self.nested_nested_stack.resources = {"Function2": {"Type": "TypeB"}} + + def test_get_all_resource_ids( + self, + ): + result = get_all_resource_ids([self.root_stack, self.nested_stack, self.nested_nested_stack]) + self.assertEqual( + result, + [ + ResourceIdentifier("Function1"), + ResourceIdentifier("NestedStack1/Function1"), + ResourceIdentifier("NestedStack1/NestedNestedStack1/Function2"), + ], + ) + + +class TestGetUniqueResourceIDs(TestCase): + def setUp(self) -> None: + 
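+        # Reviewer note: a bare MagicMock is enough for the stacks argument here,
+        # since every test in this class patches get_resource_ids_by_type and
+        # get_unique_resource_ids only unions the explicit resource ids with
+        # whatever that patched lookup returns; the stacks are never inspected.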
super().setUp() + self.stacks = MagicMock() + + @patch("samcli.lib.providers.provider.get_resource_ids_by_type") + def test_only_resource_ids(self, get_resource_ids_by_type_mock): + resource_ids = ["Function1", "Function2"] + resource_types = [] + get_resource_ids_by_type_mock.return_value = {} + result = get_unique_resource_ids(self.stacks, resource_ids, resource_types) + get_resource_ids_by_type_mock.assert_not_called() + self.assertEqual(result, {ResourceIdentifier("Function1"), ResourceIdentifier("Function2")}) + + @patch("samcli.lib.providers.provider.get_resource_ids_by_type") + def test_only_resource_types(self, get_resource_ids_by_type_mock): + resource_ids = [] + resource_types = ["Type1", "Type2"] + get_resource_ids_by_type_mock.return_value = {ResourceIdentifier("Function1"), ResourceIdentifier("Function2")} + result = get_unique_resource_ids(self.stacks, resource_ids, resource_types) + get_resource_ids_by_type_mock.assert_any_call(self.stacks, "Type1") + get_resource_ids_by_type_mock.assert_any_call(self.stacks, "Type2") + self.assertEqual(result, {ResourceIdentifier("Function1"), ResourceIdentifier("Function2")}) + + @patch("samcli.lib.providers.provider.get_resource_ids_by_type") + def test_duplicates(self, get_resource_ids_by_type_mock): + resource_ids = ["Function1", "Function2"] + resource_types = ["Type1", "Type2"] + get_resource_ids_by_type_mock.return_value = {ResourceIdentifier("Function2"), ResourceIdentifier("Function3")} + result = get_unique_resource_ids(self.stacks, resource_ids, resource_types) + get_resource_ids_by_type_mock.assert_any_call(self.stacks, "Type1") + get_resource_ids_by_type_mock.assert_any_call(self.stacks, "Type2") + self.assertEqual( + result, {ResourceIdentifier("Function1"), ResourceIdentifier("Function2"), ResourceIdentifier("Function3")} + ) diff --git a/tests/unit/commands/local/lib/test_stack_provider.py b/tests/unit/commands/local/lib/test_stack_provider.py index 71bcec648c..565f6fb50d 100644 --- a/tests/unit/commands/local/lib/test_stack_provider.py +++ b/tests/unit/commands/local/lib/test_stack_provider.py @@ -6,7 +6,7 @@ from parameterized import parameterized -from samcli.commands._utils.resources import AWS_SERVERLESS_APPLICATION, AWS_CLOUDFORMATION_STACK +from samcli.lib.utils.resources import AWS_SERVERLESS_APPLICATION, AWS_CLOUDFORMATION_STACK from samcli.lib.providers.provider import Stack from samcli.lib.providers.sam_stack_provider import SamLocalStackProvider diff --git a/tests/unit/commands/logs/test_command.py b/tests/unit/commands/logs/test_command.py index 3a48600ae0..bdf394ca33 100644 --- a/tests/unit/commands/logs/test_command.py +++ b/tests/unit/commands/logs/test_command.py @@ -1,9 +1,12 @@ from unittest import TestCase -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, call, ANY + +from parameterized import parameterized from samcli.commands.logs.command import do_cli +@patch("samcli.commands._utils.experimental.is_experimental_enabled") class TestLogsCliCommand(TestCase): def setUp(self): @@ -12,47 +15,112 @@ def setUp(self): self.filter_pattern = "filter" self.start_time = "start" self.end_time = "end" + self.output_dir = "output_dir" + self.region = "region" + self.profile = "profile" + + @parameterized.expand( + [ + ( + True, + False, + [], + ), + ( + False, + False, + [], + ), + ( + True, + False, + ["cw_log_group"], + ), + ( + False, + False, + ["cw_log_group", "cw_log_group2"], + ), + ] + ) + @patch("samcli.commands.logs.puller_factory.generate_puller") + 
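+    # Reviewer note: mock.patch decorators are applied bottom-up, so the
+    # parameterized values (tailing, include_tracing, cw_log_group) arrive first
+    # in the test signature, followed by the patched objects in reverse decorator
+    # order, with the class-level is_experimental_enabled patch injected last.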
@patch("samcli.commands.logs.logs_context.ResourcePhysicalIdResolver") + @patch("samcli.commands.logs.logs_context.parse_time") + @patch("samcli.lib.utils.boto_utils.get_boto_client_provider_with_config") + @patch("samcli.lib.utils.boto_utils.get_boto_resource_provider_with_config") + def test_logs_command( + self, + tailing, + include_tracing, + cw_log_group, + patched_boto_resource_provider, + patched_boto_client_provider, + patched_parse_time, + patched_resource_physical_id_resolver, + patched_generate_puller, + patched_is_experimental_enabled, + ): + mocked_start_time = Mock() + mocked_end_time = Mock() + patched_parse_time.side_effect = [mocked_start_time, mocked_end_time] + + mocked_resource_physical_id_resolver = Mock() + mocked_resource_information = Mock() + mocked_resource_physical_id_resolver.get_resource_information.return_value = mocked_resource_information + patched_resource_physical_id_resolver.return_value = mocked_resource_physical_id_resolver - @patch("samcli.commands.logs.logs_context.LogsCommandContext") - def test_without_tail(self, logs_command_context_mock): - tailing = False + mocked_puller = Mock() + patched_generate_puller.return_value = mocked_puller - context_mock = Mock() - logs_command_context_mock.return_value.__enter__.return_value = context_mock + mocked_client_provider = Mock() + patched_boto_client_provider.return_value = mocked_client_provider - do_cli(self.function_name, self.stack_name, self.filter_pattern, tailing, self.start_time, self.end_time) + mocked_resource_provider = Mock() + patched_boto_resource_provider.return_value = mocked_resource_provider - logs_command_context_mock.assert_called_with( + do_cli( self.function_name, - stack_name=self.stack_name, - filter_pattern=self.filter_pattern, - start_time=self.start_time, - end_time=self.end_time, + self.stack_name, + self.filter_pattern, + tailing, + include_tracing, + self.start_time, + self.end_time, + cw_log_group, + self.output_dir, + self.region, + self.profile, ) - context_mock.fetcher.load_time_period.assert_called_with( - filter_pattern=context_mock.filter_pattern, - start_time=context_mock.start_time, - end_time=context_mock.end_time, + patched_parse_time.assert_has_calls( + [ + call(self.start_time, "start-time"), + call(self.end_time, "end-time"), + ] ) - @patch("samcli.commands.logs.logs_context.LogsCommandContext") - def test_with_tailing(self, logs_command_context_mock): - tailing = True + patched_boto_client_provider.assert_called_with(region=self.region, profile=self.profile) + patched_boto_resource_provider.assert_called_with(region=self.region, profile=self.profile) - context_mock = Mock() - logs_command_context_mock.return_value.__enter__.return_value = context_mock + patched_resource_physical_id_resolver.assert_called_with( + mocked_resource_provider, self.stack_name, self.function_name + ) - do_cli(self.function_name, self.stack_name, self.filter_pattern, tailing, self.start_time, self.end_time) + fetch_param = not bool(len(cw_log_group)) + mocked_resource_physical_id_resolver.assert_has_calls([call.get_resource_information(fetch_param)]) - logs_command_context_mock.assert_called_with( - self.function_name, - stack_name=self.stack_name, - filter_pattern=self.filter_pattern, - start_time=self.start_time, - end_time=self.end_time, + patched_generate_puller.assert_called_with( + mocked_client_provider, + mocked_resource_information, + self.filter_pattern, + cw_log_group, + self.output_dir, + False, ) - context_mock.fetcher.tail.assert_called_with( - 
filter_pattern=context_mock.filter_pattern, start_time=context_mock.start_time - ) + if tailing: + mocked_puller.assert_has_calls([call.tail(mocked_start_time, self.filter_pattern)]) + else: + mocked_puller.assert_has_calls( + [call.load_time_period(mocked_start_time, mocked_end_time, self.filter_pattern)] + ) diff --git a/tests/unit/commands/logs/test_console_consumers.py b/tests/unit/commands/logs/test_console_consumers.py index ab824ca769..bfb4a6ba13 100644 --- a/tests/unit/commands/logs/test_console_consumers.py +++ b/tests/unit/commands/logs/test_console_consumers.py @@ -1,15 +1,31 @@ from unittest import TestCase from unittest.mock import patch, Mock +from parameterized import parameterized + from samcli.commands.logs.console_consumers import CWConsoleEventConsumer class TestCWConsoleEventConsumer(TestCase): - def setUp(self): - self.consumer = CWConsoleEventConsumer() + @parameterized.expand( + [ + (True,), + (False,), + ] + ) + @patch("samcli.commands.logs.console_consumers.click") + def test_consumer_with_event(self, add_newline, patched_click): + consumer = CWConsoleEventConsumer(add_newline) + event = Mock() + consumer.consume(event) + + expected_new_line_param = add_newline if add_newline is not None else True + patched_click.echo.assert_called_with(event.message, nl=expected_new_line_param) @patch("samcli.commands.logs.console_consumers.click") - def test_consume_with_event(self, patched_click): + def test_default_consumer_with_event(self, patched_click): + consumer = CWConsoleEventConsumer() event = Mock() - self.consumer.consume(event) + consumer.consume(event) + patched_click.echo.assert_called_with(event.message, nl=False) diff --git a/tests/unit/commands/logs/test_logs_context.py b/tests/unit/commands/logs/test_logs_context.py index abcd792b27..050ae9ef91 100644 --- a/tests/unit/commands/logs/test_logs_context.py +++ b/tests/unit/commands/logs/test_logs_context.py @@ -1,11 +1,14 @@ -from unittest import TestCase -from unittest.mock import Mock, patch, ANY - -import botocore.session -from botocore.stub import Stubber +from unittest import TestCase, mock +from unittest.mock import Mock, patch from samcli.commands.exceptions import UserException -from samcli.commands.logs.logs_context import LogsCommandContext +from samcli.commands.logs.logs_context import parse_time, ResourcePhysicalIdResolver +from samcli.lib.utils.cloudformation import CloudFormationResourceSummary + +AWS_SOME_RESOURCE = "AWS::Some::Resource" +AWS_LAMBDA_FUNCTION = "AWS::Lambda::Function" +AWS_APIGATEWAY_RESTAPI = "AWS::ApiGateway::RestApi" +AWS_APIGATEWAY_HTTPAPI = "AWS::ApiGatewayV2::Api" class TestLogsCommandContext(TestCase): @@ -17,214 +20,110 @@ def setUp(self): self.end_time = "end" self.output_file = "somefile" - self.context = LogsCommandContext( - self.function_name, - stack_name=self.stack_name, - filter_pattern=self.filter_pattern, - start_time=self.start_time, - end_time=self.end_time, - output_file=self.output_file, - ) - - def test_basic_properties(self): - self.assertEqual(self.context.filter_pattern, self.filter_pattern) - self.assertIsNone(self.context.output_file_handle) # before setting context handle will be null - - @patch("samcli.commands.logs.logs_context.Colored") - def test_colored_property(self, ColoredMock): - ColoredMock.return_value = Mock() - - self.assertEqual(self.context.colored, ColoredMock.return_value) - ColoredMock.assert_called_with(colorize=False) - - @patch("samcli.commands.logs.logs_context.Colored") - def test_colored_property_without_output_file(self, 
ColoredMock): - ColoredMock.return_value = Mock() - - # No output file. It means we are printing to Terminal. Hence set the color - ctx = LogsCommandContext( - self.function_name, - stack_name=self.stack_name, - filter_pattern=self.filter_pattern, - start_time=self.start_time, - end_time=self.end_time, - output_file=None, - ) - - self.assertEqual(ctx.colored, ColoredMock.return_value) - ColoredMock.assert_called_with(colorize=True) # Must enable colors - - @patch("samcli.commands.logs.logs_context.LogGroupProvider") - @patch.object(LogsCommandContext, "_get_resource_id_from_stack") - def test_log_group_name_property_with_stack_name(self, get_resource_id_mock, LogGroupProviderMock): - logical_id = "someid" - group = "groupname" - - LogGroupProviderMock.for_lambda_function.return_value = group - get_resource_id_mock.return_value = logical_id - - self.assertEqual(self.context.log_group_name, group) - - LogGroupProviderMock.for_lambda_function.assert_called_with(logical_id) - get_resource_id_mock.assert_called_with(ANY, self.stack_name, self.function_name) - - @patch("samcli.commands.logs.logs_context.LogGroupProvider") - @patch.object(LogsCommandContext, "_get_resource_id_from_stack") - def test_log_group_name_property_without_stack_name(self, get_resource_id_mock, LogGroupProviderMock): - group = "groupname" - - LogGroupProviderMock.for_lambda_function.return_value = group - - ctx = LogsCommandContext( - self.function_name, - stack_name=None, # No Stack Name - filter_pattern=self.filter_pattern, - start_time=self.start_time, - end_time=self.end_time, - output_file=self.output_file, - ) - - self.assertEqual(ctx.log_group_name, group) - - LogGroupProviderMock.for_lambda_function.assert_called_with(self.function_name) - get_resource_id_mock.assert_not_called() - - def test_start_time_property(self): - self.context._parse_time = Mock() - self.context._parse_time.return_value = "foo" - - self.assertEqual(self.context.start_time, "foo") - - def test_end_time_property(self): - self.context._parse_time = Mock() - self.context._parse_time.return_value = "foo" - - self.assertEqual(self.context.end_time, "foo") - @patch("samcli.commands.logs.logs_context.parse_date") @patch("samcli.commands.logs.logs_context.to_utc") def test_parse_time(self, to_utc_mock, parse_date_mock): - input = "some time" + given_input = "some time" parsed_result = "parsed" expected = "bar" parse_date_mock.return_value = parsed_result to_utc_mock.return_value = expected - actual = LogsCommandContext._parse_time(input, "some prop") + actual = parse_time(given_input, "some prop") self.assertEqual(actual, expected) - parse_date_mock.assert_called_with(input) + parse_date_mock.assert_called_with(given_input) to_utc_mock.assert_called_with(parsed_result) @patch("samcli.commands.logs.logs_context.parse_date") def test_parse_time_raises_exception(self, parse_date_mock): - input = "some time" + given_input = "some time" parsed_result = None parse_date_mock.return_value = parsed_result with self.assertRaises(UserException) as ctx: - LogsCommandContext._parse_time(input, "some prop") + parse_time(given_input, "some prop") self.assertEqual(str(ctx.exception), "Unable to parse the time provided by 'some prop'") def test_parse_time_empty_time(self): - result = LogsCommandContext._parse_time(None, "some prop") + result = parse_time(None, "some prop") self.assertIsNone(result) - @patch("samcli.commands.logs.logs_context.open") - def test_setup_output_file(self, open_mock): - - open_mock.return_value = "handle" - result = 
LogsCommandContext._setup_output_file(self.output_file) - - self.assertEqual(result, "handle") - open_mock.assert_called_with(self.output_file, "wb") - - def test_setup_output_file_without_file(self): - self.assertIsNone(LogsCommandContext._setup_output_file(None)) - - @patch.object(LogsCommandContext, "_setup_output_file") - def test_context_manager_with_output_file(self, setup_output_file_mock): - handle = Mock() - setup_output_file_mock.return_value = handle - - with LogsCommandContext( - self.function_name, - stack_name=self.stack_name, - filter_pattern=self.filter_pattern, - start_time=self.start_time, - end_time=self.end_time, - output_file=self.output_file, - ) as context: - self.assertEqual(context._output_file_handle, handle) - - # Context should be reset - self.assertIsNone(self.context._output_file_handle) - handle.close.assert_called_with() - setup_output_file_mock.assert_called_with(self.output_file) - - @patch.object(LogsCommandContext, "_setup_output_file") - def test_context_manager_no_output_file(self, setup_output_file_mock): - setup_output_file_mock.return_value = None - - with LogsCommandContext( - self.function_name, - stack_name=self.stack_name, - filter_pattern=self.filter_pattern, - start_time=self.start_time, - end_time=self.end_time, - output_file=None, - ) as context: - self.assertEqual(context._output_file_handle, None) - - # Context should be reset - setup_output_file_mock.assert_called_with(None) - - -class TestLogsCommandContext_get_resource_id_from_stack(TestCase): - def setUp(self): - - self.real_client = botocore.session.get_session().create_client("cloudformation", region_name="us-east-1") - self.cfn_client_stubber = Stubber(self.real_client) - - self.logical_id = "name" - self.stack_name = "stackname" - self.physical_id = "myid" - - def test_must_get_from_cfn(self): - - expected_params = {"StackName": self.stack_name, "LogicalResourceId": self.logical_id} - - mock_response = { - "StackResourceDetail": { - "PhysicalResourceId": self.physical_id, - "LogicalResourceId": self.logical_id, - "ResourceType": "AWS::Lambda::Function", - "ResourceStatus": "UPDATE_COMPLETE", - "LastUpdatedTimestamp": "2017-07-28T23:34:13.435Z", - } - } - - self.cfn_client_stubber.add_response("describe_stack_resource", mock_response, expected_params) - - with self.cfn_client_stubber: - result = LogsCommandContext._get_resource_id_from_stack(self.real_client, self.stack_name, self.logical_id) - - self.assertEqual(result, self.physical_id) - - def test_must_handle_resource_not_found(self): - errmsg = "Something went wrong" - errcode = "SomeException" - - self.cfn_client_stubber.add_client_error( - "describe_stack_resource", service_error_code=errcode, service_message=errmsg - ) - expected_error_msg = "An error occurred ({}) when calling the DescribeStackResource operation: {}".format( - errcode, errmsg - ) - - with self.cfn_client_stubber: - with self.assertRaises(UserException) as context: - LogsCommandContext._get_resource_id_from_stack(self.real_client, self.stack_name, self.logical_id) - self.assertEqual(expected_error_msg, str(context.exception)) +class TestResourcePhysicalIdResolver(TestCase): + def test_get_resource_information_with_resources(self): + resource_physical_id_resolver = ResourcePhysicalIdResolver(Mock(), "stack_name", ["resource_name"]) + with mock.patch( + "samcli.commands.logs.logs_context.ResourcePhysicalIdResolver._fetch_resources_from_stack" + ) as mocked_fetch: + expected_return = Mock() + mocked_fetch.return_value = expected_return + + actual_return = 
resource_physical_id_resolver.get_resource_information(False) + + mocked_fetch.assert_called_once() + self.assertEqual(actual_return, expected_return) + + def test_get_resource_information_of_all_stack(self): + resource_physical_id_resolver = ResourcePhysicalIdResolver(Mock(), "stack_name", []) + with mock.patch( + "samcli.commands.logs.logs_context.ResourcePhysicalIdResolver._fetch_resources_from_stack" + ) as mocked_fetch: + expected_return = Mock() + mocked_fetch.return_value = expected_return + + actual_return = resource_physical_id_resolver.get_resource_information(True) + + mocked_fetch.assert_called_once() + self.assertEqual(actual_return, expected_return) + + def test_get_no_resource_information(self): + resource_physical_id_resolver = ResourcePhysicalIdResolver(Mock(), "stack_name", None) + actual_return = resource_physical_id_resolver.get_resource_information(False) + self.assertEqual(actual_return, []) + + @patch("samcli.commands.logs.logs_context.get_resource_summaries") + def test_fetch_all_resources(self, patched_get_resources): + resource_physical_id_resolver = ResourcePhysicalIdResolver(Mock(), "stack_name", []) + mocked_return_value = [ + CloudFormationResourceSummary(AWS_LAMBDA_FUNCTION, "logical_id_1", "physical_id_1"), + CloudFormationResourceSummary(AWS_LAMBDA_FUNCTION, "logical_id_2", "physical_id_2"), + CloudFormationResourceSummary(AWS_APIGATEWAY_RESTAPI, "logical_id_3", "physical_id_3"), + CloudFormationResourceSummary(AWS_APIGATEWAY_HTTPAPI, "logical_id_4", "physical_id_4"), + ] + patched_get_resources.return_value = mocked_return_value + + actual_result = resource_physical_id_resolver._fetch_resources_from_stack() + self.assertEqual(len(actual_result), 4) + + expected_results = [ + item + for item in mocked_return_value + if item.resource_type in ResourcePhysicalIdResolver.DEFAULT_SUPPORTED_RESOURCES + ] + self.assertEqual(expected_results, actual_result) + + @patch("samcli.commands.logs.logs_context.get_resource_summaries") + def test_fetch_given_resources(self, patched_get_resources): + given_resources = ["logical_id_1", "logical_id_2", "logical_id_3", "logical_id_5", "logical_id_6"] + resource_physical_id_resolver = ResourcePhysicalIdResolver(Mock(), "stack_name", given_resources) + mocked_return_value = [ + CloudFormationResourceSummary(AWS_LAMBDA_FUNCTION, "logical_id_1", "physical_id_1"), + CloudFormationResourceSummary(AWS_LAMBDA_FUNCTION, "logical_id_2", "physical_id_2"), + CloudFormationResourceSummary(AWS_LAMBDA_FUNCTION, "logical_id_3", "physical_id_3"), + CloudFormationResourceSummary(AWS_APIGATEWAY_RESTAPI, "logical_id_4", "physical_id_4"), + CloudFormationResourceSummary(AWS_APIGATEWAY_HTTPAPI, "logical_id_5", "physical_id_5"), + ] + patched_get_resources.return_value = mocked_return_value + + actual_result = resource_physical_id_resolver._fetch_resources_from_stack(set(given_resources)) + self.assertEqual(len(actual_result), 4) + + expected_results = [ + item + for item in mocked_return_value + if item.resource_type in ResourcePhysicalIdResolver.DEFAULT_SUPPORTED_RESOURCES + and item.logical_resource_id in given_resources + ] + self.assertEqual(expected_results, actual_result) diff --git a/tests/unit/commands/logs/test_puller_factory.py b/tests/unit/commands/logs/test_puller_factory.py new file mode 100644 index 0000000000..bf4f6dd143 --- /dev/null +++ b/tests/unit/commands/logs/test_puller_factory.py @@ -0,0 +1,258 @@ +from unittest import TestCase +from unittest.mock import Mock, patch, call, ANY + +from parameterized import parameterized + +from 
samcli.lib.utils.resources import AWS_LAMBDA_FUNCTION
+from samcli.commands.logs.puller_factory import (
+    generate_puller,
+    generate_unformatted_consumer,
+    generate_console_consumer,
+    NoPullerGeneratedException,
+    generate_consumer,
+)
+
+
+class TestPullerFactory(TestCase):
+    @parameterized.expand(
+        [
+            (None, None, False),
+            ("filter_pattern", None, False),
+            ("filter_pattern", ["cw_log_groups"], False),
+            ("filter_pattern", ["cw_log_groups"], True),
+            (None, ["cw_log_groups"], True),
+            (None, None, True),
+        ]
+    )
+    @patch("samcli.commands.logs.puller_factory.generate_console_consumer")
+    @patch("samcli.commands.logs.puller_factory.generate_unformatted_consumer")
+    @patch("samcli.commands.logs.puller_factory.CWLogPuller")
+    @patch("samcli.commands.logs.puller_factory.generate_trace_puller")
+    @patch("samcli.commands.logs.puller_factory.ObservabilityCombinedPuller")
+    def test_generate_puller(
+        self,
+        param_filter_pattern,
+        param_cw_log_groups,
+        param_unformatted,
+        patched_combined_puller,
+        patched_xray_puller,
+        patched_cw_log_puller,
+        patched_unformatted_consumer,
+        patched_console_consumer,
+    ):
+        mock_logs_client = Mock()
+        mock_xray_client = Mock()
+
+        mock_client_provider = lambda client_name: mock_logs_client if client_name == "logs" else mock_xray_client
+
+        mock_resource_info_list = [
+            Mock(resource_type=AWS_LAMBDA_FUNCTION),
+            Mock(resource_type=AWS_LAMBDA_FUNCTION),
+            Mock(resource_type=AWS_LAMBDA_FUNCTION),
+        ]
+
+        mocked_resource_consumers = [Mock() for _ in mock_resource_info_list]
+        mocked_cw_specific_consumers = [Mock() for _ in (param_cw_log_groups or [])]
+        mocked_consumers = mocked_resource_consumers + mocked_cw_specific_consumers
+
+        # depending on the unformatted param, patch the unformatted consumer or the console consumer
+        if param_unformatted:
+            patched_unformatted_consumer.side_effect = mocked_consumers
+        else:
+            patched_console_consumer.side_effect = mocked_consumers
+
+        mocked_xray_puller = Mock()
+        patched_xray_puller.return_value = mocked_xray_puller
+        mocked_pullers = [Mock() for _ in mocked_consumers]
+        mocked_pullers.append(mocked_xray_puller)  # the combined puller should also receive the x-ray trace puller
+        patched_cw_log_puller.side_effect = mocked_pullers
+
+        mocked_combined_puller = Mock()
+
+        patched_combined_puller.return_value = mocked_combined_puller
+
+        puller = generate_puller(
+            mock_client_provider,
+            mock_resource_info_list,
+            param_filter_pattern,
+            param_cw_log_groups,
+            param_unformatted,
+            True,
+        )
+
+        self.assertEqual(puller, mocked_combined_puller)
+
+        patched_xray_puller.assert_called_once_with(mock_xray_client, param_unformatted)
+
+        patched_cw_log_puller.assert_has_calls(
+            [call(mock_logs_client, consumer, ANY, ANY) for consumer in mocked_resource_consumers]
+        )
+
+        patched_cw_log_puller.assert_has_calls(
+            [call(mock_logs_client, consumer, ANY) for consumer in mocked_cw_specific_consumers]
+        )
+
+        patched_combined_puller.assert_called_with(mocked_pullers)
+
+        # depending on the unformatted param, assert calls on the unformatted consumer or the console consumer
+        if param_unformatted:
+            patched_unformatted_consumer.assert_has_calls([call() for _ in mocked_consumers])
+        else:
+            patched_console_consumer.assert_has_calls([call(param_filter_pattern) for _ in mocked_consumers])
+
+    def test_puller_with_invalid_resource_type(self):
+        mock_logs_client = Mock()
+        mock_resource_information = Mock()
+        mock_resource_information.get_log_group_name.return_value = None
+
+        with self.assertRaises(NoPullerGeneratedException):
+            generate_puller(mock_logs_client, [mock_resource_information])
+
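+    # Reviewer sketch, inferred from the assertions above (the argument comments
+    # are assumptions, not documented API): generate_puller builds one CWLogPuller
+    # per supported resource plus one per extra CloudWatch log group, appends an
+    # X-Ray trace puller when tracing is requested, and wraps them all in a single
+    # ObservabilityCombinedPuller, so a caller would do roughly:
+    #
+    #     puller = generate_puller(
+    #         client_provider,       # callable mapping a service name to a boto3 client
+    #         resource_info_list,    # resources resolved from the CloudFormation stack
+    #         "filter-pattern",      # CloudWatch Logs filter expression, may be None
+    #         ["extra-log-group"],   # additional CloudWatch log groups to tail
+    #         False,                 # unformatted output
+    #         True,                  # include tracing, which appends the X-Ray puller
+    #     )
+    #     puller.tail(start_time, "filter-pattern")
+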
@patch("samcli.commands.logs.puller_factory.generate_console_consumer") + @patch("samcli.commands.logs.puller_factory.CWLogPuller") + @patch("samcli.commands.logs.puller_factory.ObservabilityCombinedPuller") + def test_generate_puller_with_console_with_additional_cw_logs_groups( + self, patched_combined_puller, patched_cw_log_puller, patched_console_consumer + ): + mock_logs_client = Mock() + mock_logs_client_generator = lambda client: mock_logs_client + mock_cw_log_groups = [Mock(), Mock(), Mock()] + + mocked_consumers = [Mock() for _ in mock_cw_log_groups] + patched_console_consumer.side_effect = mocked_consumers + + mocked_pullers = [Mock() for _ in mock_cw_log_groups] + patched_cw_log_puller.side_effect = mocked_pullers + + mocked_combined_puller = Mock() + patched_combined_puller.return_value = mocked_combined_puller + + puller = generate_puller(mock_logs_client_generator, [], additional_cw_log_groups=mock_cw_log_groups) + + self.assertEqual(puller, mocked_combined_puller) + + patched_cw_log_puller.assert_has_calls([call(mock_logs_client, consumer, ANY) for consumer in mocked_consumers]) + + patched_combined_puller.assert_called_with(mocked_pullers) + + patched_console_consumer.assert_has_calls([call(None) for _ in mock_cw_log_groups]) + + @parameterized.expand( + [ + (False,), + (True,), + ] + ) + @patch("samcli.commands.logs.puller_factory.generate_unformatted_consumer") + @patch("samcli.commands.logs.puller_factory.generate_console_consumer") + def test_generate_consumer(self, param_unformatted, patched_console_consumer, patched_unformatted_consumer): + given_filter_pattern = Mock() + given_resource_name = Mock() + + given_console_consumer = Mock() + patched_console_consumer.return_value = given_console_consumer + given_file_consumer = Mock() + patched_unformatted_consumer.return_value = given_file_consumer + + actual_consumer = generate_consumer(given_filter_pattern, param_unformatted, given_resource_name) + + if param_unformatted: + patched_unformatted_consumer.assert_called_with() + self.assertEqual(actual_consumer, given_file_consumer) + else: + patched_console_consumer.assert_called_with(given_filter_pattern) + self.assertEqual(actual_consumer, given_console_consumer) + + @patch("samcli.commands.logs.puller_factory.ObservabilityEventConsumerDecorator") + @patch("samcli.commands.logs.puller_factory.CWLogEventJSONMapper") + @patch("samcli.commands.logs.puller_factory.CWConsoleEventConsumer") + def test_generate_unformatted_consumer( + self, + patched_event_consumer, + patched_json_formatter, + patched_decorated_consumer, + ): + expected_consumer = Mock() + patched_decorated_consumer.return_value = expected_consumer + + expected_event_consumer = Mock() + patched_event_consumer.return_value = expected_event_consumer + + expected_json_formatter = Mock() + patched_json_formatter.return_value = expected_json_formatter + + consumer = generate_unformatted_consumer() + + self.assertEqual(expected_consumer, consumer) + + patched_decorated_consumer.assert_called_with([expected_json_formatter], expected_event_consumer) + patched_event_consumer.assert_called_with(True) + patched_json_formatter.assert_called_once() + + @patch("samcli.commands.logs.puller_factory.Colored") + @patch("samcli.commands.logs.puller_factory.ObservabilityEventConsumerDecorator") + @patch("samcli.commands.logs.puller_factory.CWColorizeErrorsFormatter") + @patch("samcli.commands.logs.puller_factory.CWJsonFormatter") + @patch("samcli.commands.logs.puller_factory.CWKeywordHighlighterFormatter") + 
@patch("samcli.commands.logs.puller_factory.CWPrettyPrintFormatter") + @patch("samcli.commands.logs.puller_factory.CWAddNewLineIfItDoesntExist") + @patch("samcli.commands.logs.puller_factory.CWConsoleEventConsumer") + def test_generate_console_consumer( + self, + patched_event_consumer, + patched_new_line_mapper, + patched_pretty_formatter, + patched_highlighter, + patched_json_formatter, + patched_errors_formatter, + patched_decorated_consumer, + patched_colored, + ): + mock_filter_pattern = Mock() + + expected_colored = Mock() + patched_colored.return_value = expected_colored + + expected_errors_formatter = Mock() + patched_errors_formatter.return_value = expected_errors_formatter + + expected_json_formatter = Mock() + patched_json_formatter.return_value = expected_json_formatter + + expected_highlighter = Mock() + patched_highlighter.return_value = expected_highlighter + + expected_pretty_formatter = Mock() + patched_pretty_formatter.return_value = expected_pretty_formatter + + expected_new_line_mapper = Mock() + patched_new_line_mapper.return_value = expected_new_line_mapper + + expected_event_consumer = Mock() + patched_event_consumer.return_value = expected_event_consumer + + expected_consumer = Mock() + patched_decorated_consumer.return_value = expected_consumer + + consumer = generate_console_consumer(mock_filter_pattern) + + self.assertEqual(expected_consumer, consumer) + + patched_colored.assert_called_once() + patched_event_consumer.assert_called_once() + patched_new_line_mapper.assert_called_once() + patched_pretty_formatter.assert_called_with(expected_colored) + patched_highlighter.assert_called_with(expected_colored, mock_filter_pattern) + patched_json_formatter.assert_called_once() + patched_errors_formatter.assert_called_with(expected_colored) + + patched_decorated_consumer.assert_called_with( + [ + expected_errors_formatter, + expected_json_formatter, + expected_highlighter, + expected_pretty_formatter, + expected_new_line_mapper, + ], + expected_event_consumer, + ) diff --git a/tests/unit/commands/samconfig/test_samconfig.py b/tests/unit/commands/samconfig/test_samconfig.py index 8c0d83b529..1e128d91ae 100644 --- a/tests/unit/commands/samconfig/test_samconfig.py +++ b/tests/unit/commands/samconfig/test_samconfig.py @@ -8,6 +8,7 @@ import tempfile from pathlib import Path from contextlib import contextmanager +from samcli.commands._utils.experimental import ExperimentalFlag, set_experimental from samcli.lib.config.samconfig import SamConfig, DEFAULT_ENV from click.testing import CliRunner @@ -537,10 +538,14 @@ def test_package_with_image_repository_and_image_repositories( self.assertIsNotNone(result.exception) @patch("samcli.lib.cli_validation.image_repository_validation.get_template_artifacts_format") + @patch("samcli.commands._utils.template.get_template_artifacts_format") + @patch("samcli.commands._utils.options.get_template_artifacts_format") @patch("samcli.commands.deploy.command.do_cli") - def test_deploy(self, do_cli_mock, get_template_artifacts_format_mock): + def test_deploy(self, do_cli_mock, template_artifacts_mock1, template_artifacts_mock2, template_artifacts_mock3): - get_template_artifacts_format_mock.return_value = [ZIP] + template_artifacts_mock1.return_value = [ZIP] + template_artifacts_mock2.return_value = [ZIP] + template_artifacts_mock3.return_value = [ZIP] config_values = { "template_file": "mytemplate.yaml", "stack_name": "mystack", @@ -647,10 +652,16 @@ def test_deploy_image_repositories_and_image_repository(self, do_cli_mock): 
self.assertIsNotNone(result.exception) @patch("samcli.lib.cli_validation.image_repository_validation.get_template_artifacts_format") + @patch("samcli.commands._utils.options.get_template_artifacts_format") + @patch("samcli.commands._utils.template.get_template_artifacts_format") @patch("samcli.commands.deploy.command.do_cli") - def test_deploy_different_parameter_override_format(self, do_cli_mock, get_template_artifacts_format_mock): + def test_deploy_different_parameter_override_format( + self, do_cli_mock, template_artifacts_mock1, template_artifacts_mock2, template_artifacts_mock3 + ): - get_template_artifacts_format_mock.return_value = [ZIP] + template_artifacts_mock1.return_value = [ZIP] + template_artifacts_mock2.return_value = [ZIP] + template_artifacts_mock3.return_value = [ZIP] config_values = { "template_file": "mytemplate.yaml", @@ -721,16 +732,20 @@ def test_deploy_different_parameter_override_format(self, do_cli_mock, get_templ True, ) + @patch("samcli.commands._utils.experimental.is_experimental_enabled") @patch("samcli.commands.logs.command.do_cli") - def test_logs(self, do_cli_mock): + def test_logs(self, do_cli_mock, experimental_mock): config_values = { - "name": "myfunction", + "name": ["myfunction"], "stack_name": "mystack", "filter": "myfilter", "tail": True, + "include_traces": False, "start_time": "starttime", "end_time": "endtime", + "region": "myregion", } + experimental_mock.return_value = False with samconfig_parameters(["logs"], self.scratch_dir, **config_values) as config_path: from samcli.commands.logs.command import cli @@ -745,7 +760,61 @@ def test_logs(self, do_cli_mock): LOG.exception("Command failed", exc_info=result.exc_info) self.assertIsNone(result.exception) - do_cli_mock.assert_called_with("myfunction", "mystack", "myfilter", True, "starttime", "endtime") + do_cli_mock.assert_called_with( + ("myfunction",), + "mystack", + "myfilter", + True, + False, + "starttime", + "endtime", + (), + False, + "myregion", + None, + ) + + @patch("samcli.commands._utils.experimental.is_experimental_enabled") + @patch("samcli.commands.logs.command.do_cli") + def test_logs_tail(self, do_cli_mock, experimental_mock): + config_values = { + "name": ["myfunction"], + "stack_name": "mystack", + "filter": "myfilter", + "tail": True, + "include_traces": True, + "start_time": "starttime", + "end_time": "endtime", + "cw_log_group": ["cw_log_group"], + "region": "myregion", + } + experimental_mock.return_value = True + with samconfig_parameters(["logs"], self.scratch_dir, **config_values) as config_path: + from samcli.commands.logs.command import cli + + LOG.debug(Path(config_path).read_text()) + runner = CliRunner() + result = runner.invoke(cli, []) + + LOG.info(result.output) + LOG.info(result.exception) + if result.exception: + LOG.exception("Command failed", exc_info=result.exc_info) + self.assertIsNone(result.exception) + + do_cli_mock.assert_called_with( + ("myfunction",), + "mystack", + "myfilter", + True, + True, + "starttime", + "endtime", + ("cw_log_group",), + False, + "myregion", + None, + ) @patch("samcli.commands.publish.command.do_cli") def test_publish(self, do_cli_mock): @@ -784,6 +853,87 @@ def test_info_must_not_read_from_config(self): info_result = json.loads(result.output) self.assertTrue("version" in info_result) + @patch("samcli.commands._utils.experimental.is_experimental_enabled") + @patch("samcli.lib.cli_validation.image_repository_validation.get_template_function_resource_ids") + 
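+    # Reviewer note: decorators bind bottom-up, so template_artifacts_mock4 below
+    # receives this get_template_function_resource_ids patch, which is why it is
+    # primed with a list of function logical ids rather than artifact formats.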
@patch("samcli.lib.cli_validation.image_repository_validation.get_template_artifacts_format") + @patch("samcli.commands._utils.template.get_template_artifacts_format") + @patch("samcli.commands._utils.options.get_template_artifacts_format") + @patch("samcli.commands.sync.command.do_cli") + def test_sync( + self, + do_cli_mock, + template_artifacts_mock1, + template_artifacts_mock2, + template_artifacts_mock3, + template_artifacts_mock4, + experimental_mock, + ): + + template_artifacts_mock1.return_value = [ZIP] + template_artifacts_mock2.return_value = [ZIP] + template_artifacts_mock3.return_value = [ZIP] + template_artifacts_mock4.return_value = ["HelloWorldFunction"] + experimental_mock.return_value = True + + config_values = { + "template_file": "mytemplate.yaml", + "stack_name": "mystack", + "image_repository": "123456789012.dkr.ecr.us-east-1.amazonaws.com/test1", + "base_dir": "path", + "s3_prefix": "myprefix", + "kms_key_id": "mykms", + "parameter_overrides": 'Key1=Value1 Key2="Multiple spaces in the value"', + "capabilities": "cap1 cap2", + "no_execute_changeset": True, + "role_arn": "arn", + "notification_arns": "notify1 notify2", + "tags": 'a=tag1 b="tag with spaces"', + "metadata": '{"m1": "value1", "m2": "value2"}', + "guided": True, + "confirm_changeset": True, + "region": "myregion", + "signing_profiles": "function=profile:owner", + } + + with samconfig_parameters(["sync"], self.scratch_dir, **config_values) as config_path: + from samcli.commands.sync.command import cli + + LOG.debug(Path(config_path).read_text()) + runner = CliRunner() + result = runner.invoke(cli, []) + + LOG.info(result.output) + LOG.info(result.exception) + if result.exception: + LOG.exception("Command failed", exc_info=result.exc_info) + self.assertIsNone(result.exception) + + do_cli_mock.assert_called_with( + str(Path(os.getcwd(), "mytemplate.yaml")), + False, + False, + (), + (), + True, + "mystack", + "myregion", + None, + "path", + {"Key1": "Value1", "Key2": "Multiple spaces in the value"}, + None, + "123456789012.dkr.ecr.us-east-1.amazonaws.com/test1", + None, + "myprefix", + "mykms", + ["cap1", "cap2"], + "arn", + ["notify1", "notify2"], + {"a": "tag1", "b": "tag with spaces"}, + {"m1": "value1", "m2": "value2"}, + "samconfig.toml", + "default", + ) + class TestSamConfigWithOverrides(TestCase): def setUp(self): diff --git a/tests/unit/commands/sync/__init__.py b/tests/unit/commands/sync/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/commands/sync/test_command.py b/tests/unit/commands/sync/test_command.py new file mode 100644 index 0000000000..b5707249a2 --- /dev/null +++ b/tests/unit/commands/sync/test_command.py @@ -0,0 +1,614 @@ +from unittest import TestCase +from unittest.mock import ANY, MagicMock, Mock, patch +from parameterized import parameterized + +from samcli.commands.sync.command import do_cli, execute_code_sync, execute_watch +from samcli.lib.providers.provider import ResourceIdentifier +from samcli.commands._utils.options import ( + DEFAULT_BUILD_DIR, + DEFAULT_CACHE_DIR, + DEFAULT_BUILD_DIR_WITH_AUTO_DEPENDENCY_LAYER, +) + + +def get_mock_sam_config(): + mock_sam_config = MagicMock() + mock_sam_config.exists = MagicMock(return_value=True) + return mock_sam_config + + +MOCK_SAM_CONFIG = get_mock_sam_config() + + +class TestDoCli(TestCase): + def setUp(self): + + self.template_file = "input-template-file" + self.stack_name = "stack-name" + self.resource_id = [] + self.resource = [] + self.image_repository = 
"123456789012.dkr.ecr.us-east-1.amazonaws.com/test1" + self.image_repositories = None + self.mode = "mode" + self.s3_prefix = "s3-prefix" + self.kms_key_id = "kms-key-id" + self.notification_arns = [] + self.parameter_overrides = {"a": "b"} + self.capabilities = ("CAPABILITY_IAM",) + self.tags = {"c": "d"} + self.role_arn = "role_arn" + self.metadata = {} + self.region = None + self.profile = None + self.base_dir = None + self.clean = True + self.config_env = "mock-default-env" + self.config_file = "mock-default-filename" + MOCK_SAM_CONFIG.reset_mock() + + @parameterized.expand([(False, False, True), (False, False, False)]) + @patch("samcli.commands.sync.command.click") + @patch("samcli.commands.sync.command.execute_code_sync") + @patch("samcli.commands.build.command.click") + @patch("samcli.commands.build.build_context.BuildContext") + @patch("samcli.commands.package.command.click") + @patch("samcli.commands.package.package_context.PackageContext") + @patch("samcli.commands.deploy.command.click") + @patch("samcli.commands.deploy.deploy_context.DeployContext") + @patch("samcli.commands.build.command.os") + @patch("samcli.commands.sync.command.manage_stack") + def test_infra_must_succeed_sync( + self, + code, + watch, + auto_dependency_layer, + manage_stack_mock, + os_mock, + DeployContextMock, + mock_deploy_click, + PackageContextMock, + mock_package_click, + BuildContextMock, + mock_build_click, + execute_code_sync_mock, + click_mock, + ): + + build_context_mock = Mock() + BuildContextMock.return_value.__enter__.return_value = build_context_mock + package_context_mock = Mock() + PackageContextMock.return_value.__enter__.return_value = package_context_mock + deploy_context_mock = Mock() + DeployContextMock.return_value.__enter__.return_value = deploy_context_mock + + do_cli( + self.template_file, + False, + False, + self.resource_id, + self.resource, + auto_dependency_layer, + self.stack_name, + self.region, + self.profile, + self.base_dir, + self.parameter_overrides, + self.mode, + self.image_repository, + self.image_repositories, + self.s3_prefix, + self.kms_key_id, + self.capabilities, + self.role_arn, + self.notification_arns, + self.tags, + self.metadata, + self.config_file, + self.config_env, + ) + + build_dir = DEFAULT_BUILD_DIR_WITH_AUTO_DEPENDENCY_LAYER if auto_dependency_layer else DEFAULT_BUILD_DIR + BuildContextMock.assert_called_with( + resource_identifier=None, + template_file=self.template_file, + base_dir=self.base_dir, + build_dir=build_dir, + cache_dir=DEFAULT_CACHE_DIR, + clean=True, + use_container=False, + parallel=True, + parameter_overrides=self.parameter_overrides, + mode=self.mode, + cached=True, + create_auto_dependency_layer=auto_dependency_layer, + stack_name=self.stack_name, + ) + + PackageContextMock.assert_called_with( + template_file=ANY, + s3_bucket=ANY, + image_repository=self.image_repository, + image_repositories=self.image_repositories, + s3_prefix=self.s3_prefix, + kms_key_id=self.kms_key_id, + output_template_file=ANY, + no_progressbar=True, + metadata=self.metadata, + region=self.region, + profile=self.profile, + use_json=False, + force_upload=True, + ) + + DeployContextMock.assert_called_with( + template_file=ANY, + stack_name=self.stack_name, + s3_bucket=ANY, + image_repository=self.image_repository, + image_repositories=self.image_repositories, + no_progressbar=True, + s3_prefix=self.s3_prefix, + kms_key_id=self.kms_key_id, + parameter_overrides=self.parameter_overrides, + capabilities=self.capabilities, + role_arn=self.role_arn, + 
notification_arns=self.notification_arns, + tags=self.tags, + region=self.region, + profile=self.profile, + no_execute_changeset=True, + fail_on_empty_changeset=True, + confirm_changeset=False, + use_changeset=False, + force_upload=True, + signing_profiles=None, + disable_rollback=False, + ) + package_context_mock.run.assert_called_once_with() + deploy_context_mock.run.assert_called_once_with() + execute_code_sync_mock.assert_not_called() + + @parameterized.expand([(False, True, False)]) + @patch("samcli.commands.sync.command.click") + @patch("samcli.commands.sync.command.execute_watch") + @patch("samcli.commands.build.command.click") + @patch("samcli.commands.build.build_context.BuildContext") + @patch("samcli.commands.package.command.click") + @patch("samcli.commands.package.package_context.PackageContext") + @patch("samcli.commands.deploy.command.click") + @patch("samcli.commands.deploy.deploy_context.DeployContext") + @patch("samcli.commands.build.command.os") + @patch("samcli.commands.sync.command.manage_stack") + def test_watch_must_succeed_sync( + self, + code, + watch, + auto_dependency_layer, + manage_stack_mock, + os_mock, + DeployContextMock, + mock_deploy_click, + PackageContextMock, + mock_package_click, + BuildContextMock, + mock_build_click, + execute_watch_mock, + click_mock, + ): + + build_context_mock = Mock() + BuildContextMock.return_value.__enter__.return_value = build_context_mock + package_context_mock = Mock() + PackageContextMock.return_value.__enter__.return_value = package_context_mock + deploy_context_mock = Mock() + DeployContextMock.return_value.__enter__.return_value = deploy_context_mock + + do_cli( + self.template_file, + False, + True, + self.resource_id, + self.resource, + auto_dependency_layer, + self.stack_name, + self.region, + self.profile, + self.base_dir, + self.parameter_overrides, + self.mode, + self.image_repository, + self.image_repositories, + self.s3_prefix, + self.kms_key_id, + self.capabilities, + self.role_arn, + self.notification_arns, + self.tags, + self.metadata, + self.config_file, + self.config_env, + ) + + BuildContextMock.assert_called_with( + resource_identifier=None, + template_file=self.template_file, + base_dir=self.base_dir, + build_dir=DEFAULT_BUILD_DIR, + cache_dir=DEFAULT_CACHE_DIR, + clean=True, + use_container=False, + parallel=True, + parameter_overrides=self.parameter_overrides, + mode=self.mode, + cached=True, + create_auto_dependency_layer=auto_dependency_layer, + stack_name=self.stack_name, + ) + + PackageContextMock.assert_called_with( + template_file=ANY, + s3_bucket=ANY, + image_repository=self.image_repository, + image_repositories=self.image_repositories, + s3_prefix=self.s3_prefix, + kms_key_id=self.kms_key_id, + output_template_file=ANY, + no_progressbar=True, + metadata=self.metadata, + region=self.region, + profile=self.profile, + use_json=False, + force_upload=True, + ) + + DeployContextMock.assert_called_with( + template_file=ANY, + stack_name=self.stack_name, + s3_bucket=ANY, + image_repository=self.image_repository, + image_repositories=self.image_repositories, + no_progressbar=True, + s3_prefix=self.s3_prefix, + kms_key_id=self.kms_key_id, + parameter_overrides=self.parameter_overrides, + capabilities=self.capabilities, + role_arn=self.role_arn, + notification_arns=self.notification_arns, + tags=self.tags, + region=self.region, + profile=self.profile, + no_execute_changeset=True, + fail_on_empty_changeset=True, + confirm_changeset=False, + use_changeset=False, + force_upload=True, + signing_profiles=None, 
+ disable_rollback=False, + ) + execute_watch_mock.assert_called_once_with( + self.template_file, build_context_mock, package_context_mock, deploy_context_mock, auto_dependency_layer + ) + + @parameterized.expand([(True, False, True)]) + @patch("samcli.commands.sync.command.click") + @patch("samcli.commands.sync.command.execute_code_sync") + @patch("samcli.commands.build.command.click") + @patch("samcli.commands.build.build_context.BuildContext") + @patch("samcli.commands.package.command.click") + @patch("samcli.commands.package.package_context.PackageContext") + @patch("samcli.commands.deploy.command.click") + @patch("samcli.commands.deploy.deploy_context.DeployContext") + @patch("samcli.commands.build.command.os") + @patch("samcli.commands.sync.command.manage_stack") + def test_code_must_succeed_sync( + self, + code, + watch, + auto_dependency_layer, + manage_stack_mock, + os_mock, + DeployContextMock, + mock_deploy_click, + PackageContextMock, + mock_package_click, + BuildContextMock, + mock_build_click, + execute_code_sync_mock, + click_mock, + ): + + build_context_mock = Mock() + BuildContextMock.return_value.__enter__.return_value = build_context_mock + package_context_mock = Mock() + PackageContextMock.return_value.__enter__.return_value = package_context_mock + deploy_context_mock = Mock() + DeployContextMock.return_value.__enter__.return_value = deploy_context_mock + + do_cli( + self.template_file, + True, + False, + self.resource_id, + self.resource, + auto_dependency_layer, + self.stack_name, + self.region, + self.profile, + self.base_dir, + self.parameter_overrides, + self.mode, + self.image_repository, + self.image_repositories, + self.s3_prefix, + self.kms_key_id, + self.capabilities, + self.role_arn, + self.notification_arns, + self.tags, + self.metadata, + self.config_file, + self.config_env, + ) + execute_code_sync_mock.assert_called_once_with( + self.template_file, + build_context_mock, + deploy_context_mock, + self.resource_id, + self.resource, + auto_dependency_layer, + ) + + +class TestSyncCode(TestCase): + def setUp(self) -> None: + self.template_file = "template.yaml" + self.build_context = MagicMock() + self.deploy_context = MagicMock() + + @patch("samcli.commands.sync.command.click") + @patch("samcli.commands.sync.command.SamLocalStackProvider.get_stacks") + @patch("samcli.commands.sync.command.SyncFlowFactory") + @patch("samcli.commands.sync.command.SyncFlowExecutor") + @patch("samcli.commands.sync.command.get_unique_resource_ids") + def test_execute_code_sync_single_resource( + self, + get_unique_resource_ids_mock, + sync_flow_executor_mock, + sync_flow_factory_mock, + get_stacks_mock, + click_mock, + ): + + resource_identifier_strings = ["Function1"] + resource_types = [] + sync_flows = [MagicMock()] + sync_flow_factory_mock.return_value.create_sync_flow.side_effect = sync_flows + get_unique_resource_ids_mock.return_value = { + ResourceIdentifier("Function1"), + } + + execute_code_sync( + self.template_file, + self.build_context, + self.deploy_context, + resource_identifier_strings, + resource_types, + True, + ) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_called_once_with(ResourceIdentifier("Function1")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_called_once_with(sync_flows[0]) + + get_unique_resource_ids_mock.assert_called_once_with( + get_stacks_mock.return_value[0], resource_identifier_strings, [] + ) + + @patch("samcli.commands.sync.command.click") + @patch("samcli.commands.sync.command.SamLocalStackProvider.get_stacks") 
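+    # Reviewer note: SamLocalStackProvider.get_stacks returns a tuple whose first
+    # element is the list of stacks, so the assertions below verify that
+    # execute_code_sync forwards only get_stacks()[0] to get_unique_resource_ids.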
+ @patch("samcli.commands.sync.command.SyncFlowFactory") + @patch("samcli.commands.sync.command.SyncFlowExecutor") + @patch("samcli.commands.sync.command.get_unique_resource_ids") + def test_execute_code_sync_multiple_resource( + self, + get_unique_resource_ids_mock, + sync_flow_executor_mock, + sync_flow_factory_mock, + get_stacks_mock, + click_mock, + ): + + resource_identifier_strings = ["Function1", "Function2"] + resource_types = [] + sync_flows = [MagicMock(), MagicMock()] + sync_flow_factory_mock.return_value.create_sync_flow.side_effect = sync_flows + get_unique_resource_ids_mock.return_value = { + ResourceIdentifier("Function1"), + ResourceIdentifier("Function2"), + } + + execute_code_sync( + self.template_file, + self.build_context, + self.deploy_context, + resource_identifier_strings, + resource_types, + True, + ) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function1")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[0]) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function2")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[1]) + + self.assertEqual(sync_flow_factory_mock.return_value.create_sync_flow.call_count, 2) + self.assertEqual(sync_flow_executor_mock.return_value.add_sync_flow.call_count, 2) + + get_unique_resource_ids_mock.assert_called_once_with( + get_stacks_mock.return_value[0], resource_identifier_strings, [] + ) + + @patch("samcli.commands.sync.command.click") + @patch("samcli.commands.sync.command.SamLocalStackProvider.get_stacks") + @patch("samcli.commands.sync.command.SyncFlowFactory") + @patch("samcli.commands.sync.command.SyncFlowExecutor") + @patch("samcli.commands.sync.command.get_unique_resource_ids") + def test_execute_code_sync_single_type_resource( + self, + get_unique_resource_ids_mock, + sync_flow_executor_mock, + sync_flow_factory_mock, + get_stacks_mock, + click_mock, + ): + + resource_identifier_strings = ["Function1", "Function2"] + resource_types = ["Type1"] + sync_flows = [MagicMock(), MagicMock(), MagicMock()] + sync_flow_factory_mock.return_value.create_sync_flow.side_effect = sync_flows + get_unique_resource_ids_mock.return_value = { + ResourceIdentifier("Function1"), + ResourceIdentifier("Function2"), + ResourceIdentifier("Function3"), + } + execute_code_sync( + self.template_file, + self.build_context, + self.deploy_context, + resource_identifier_strings, + resource_types, + True, + ) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function1")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[0]) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function2")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[1]) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function3")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[2]) + + self.assertEqual(sync_flow_factory_mock.return_value.create_sync_flow.call_count, 3) + self.assertEqual(sync_flow_executor_mock.return_value.add_sync_flow.call_count, 3) + + get_unique_resource_ids_mock.assert_called_once_with( + get_stacks_mock.return_value[0], resource_identifier_strings, ["Type1"] + ) + + @patch("samcli.commands.sync.command.click") + @patch("samcli.commands.sync.command.SamLocalStackProvider.get_stacks") + 
@patch("samcli.commands.sync.command.SyncFlowFactory") + @patch("samcli.commands.sync.command.SyncFlowExecutor") + @patch("samcli.commands.sync.command.get_unique_resource_ids") + def test_execute_code_sync_multiple_type_resource( + self, + get_unique_resource_ids_mock, + sync_flow_executor_mock, + sync_flow_factory_mock, + get_stacks_mock, + click_mock, + ): + resource_identifier_strings = ["Function1", "Function2"] + resource_types = ["Type1", "Type2"] + sync_flows = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] + sync_flow_factory_mock.return_value.create_sync_flow.side_effect = sync_flows + get_unique_resource_ids_mock.return_value = { + ResourceIdentifier("Function1"), + ResourceIdentifier("Function2"), + ResourceIdentifier("Function3"), + ResourceIdentifier("Function4"), + } + execute_code_sync( + self.template_file, + self.build_context, + self.deploy_context, + resource_identifier_strings, + resource_types, + True, + ) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function1")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[0]) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function2")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[1]) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function3")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[2]) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function4")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[3]) + + self.assertEqual(sync_flow_factory_mock.return_value.create_sync_flow.call_count, 4) + self.assertEqual(sync_flow_executor_mock.return_value.add_sync_flow.call_count, 4) + + get_unique_resource_ids_mock.assert_any_call( + get_stacks_mock.return_value[0], resource_identifier_strings, ["Type1", "Type2"] + ) + + @patch("samcli.commands.sync.command.click") + @patch("samcli.commands.sync.command.SamLocalStackProvider.get_stacks") + @patch("samcli.commands.sync.command.SyncFlowFactory") + @patch("samcli.commands.sync.command.SyncFlowExecutor") + @patch("samcli.commands.sync.command.get_all_resource_ids") + def test_execute_code_sync_default_all_resources( + self, + get_all_resource_ids_mock, + sync_flow_executor_mock, + sync_flow_factory_mock, + get_stacks_mock, + click_mock, + ): + sync_flows = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] + sync_flow_factory_mock.return_value.create_sync_flow.side_effect = sync_flows + get_all_resource_ids_mock.return_value = [ + ResourceIdentifier("Function1"), + ResourceIdentifier("Function2"), + ResourceIdentifier("Function3"), + ResourceIdentifier("Function4"), + ] + execute_code_sync(self.template_file, self.build_context, self.deploy_context, "", [], True) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function1")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[0]) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function2")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[1]) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function3")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[2]) + + 
sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function4")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[3]) + + self.assertEqual(sync_flow_factory_mock.return_value.create_sync_flow.call_count, 4) + self.assertEqual(sync_flow_executor_mock.return_value.add_sync_flow.call_count, 4) + + get_all_resource_ids_mock.assert_called_once_with(get_stacks_mock.return_value[0]) + + +class TestWatch(TestCase): + def setUp(self) -> None: + self.template_file = "template.yaml" + self.build_context = MagicMock() + self.package_context = MagicMock() + self.deploy_context = MagicMock() + + @parameterized.expand([(True,), (False,)]) + @patch("samcli.commands.sync.command.click") + @patch("samcli.commands.sync.command.WatchManager") + def test_execute_watch( + self, + auto_dependency_layer, + watch_manager_mock, + click_mock, + ): + execute_watch( + self.template_file, self.build_context, self.package_context, self.deploy_context, auto_dependency_layer + ) + + watch_manager_mock.assert_called_once_with( + self.template_file, self.build_context, self.package_context, self.deploy_context, auto_dependency_layer + ) + watch_manager_mock.return_value.start.assert_called_once_with() diff --git a/tests/unit/commands/traces/test_command.py b/tests/unit/commands/traces/test_command.py new file mode 100644 index 0000000000..69d457eaa3 --- /dev/null +++ b/tests/unit/commands/traces/test_command.py @@ -0,0 +1,69 @@ +from unittest import TestCase +from unittest.mock import patch, call, Mock + +from parameterized import parameterized + +from samcli.commands.traces.command import do_cli + + +class TestTracesCommand(TestCase): + def setUp(self): + self.region = "region" + + @parameterized.expand( + [ + (None, None, None, False, None), + (["trace_id1", "trace_id2"], None, None, False, None), + (None, "start_time", None, False, None), + (None, "start_time", "end_time", False, None), + (None, None, None, True, None), + (None, None, None, True, "output_dir"), + ] + ) + @patch("samcli.commands.logs.logs_context.parse_time") + @patch("samcli.lib.utils.boto_utils.get_boto_config_with_user_agent") + @patch("boto3.client") + @patch("samcli.commands.traces.traces_puller_factory.generate_trace_puller") + def test_traces_command( + self, + trace_ids, + start_time, + end_time, + tail, + output_dir, + patched_generate_puller, + patched_boto3, + patched_get_boto_config_with_user_agent, + patched_parse_time, + ): + given_start_time = Mock() + given_end_time = Mock() + patched_parse_time.side_effect = [given_start_time, given_end_time] + + given_boto_config = Mock() + patched_get_boto_config_with_user_agent.return_value = given_boto_config + + given_xray_client = Mock() + patched_boto3.return_value = given_xray_client + + given_puller = Mock() + patched_generate_puller.return_value = given_puller + + do_cli(trace_ids, start_time, end_time, tail, output_dir, self.region) + + patched_parse_time.assert_has_calls( + [ + call(start_time, "start-time"), + call(end_time, "end-time"), + ] + ) + patched_get_boto_config_with_user_agent.assert_called_with(region_name=self.region) + patched_boto3.assert_called_with("xray", config=given_boto_config) + patched_generate_puller.assert_called_with(given_xray_client, output_dir) + + if trace_ids: + given_puller.load_events.assert_called_with(trace_ids) + elif tail: + given_puller.tail.assert_called_with(given_start_time) + else: + given_puller.load_time_period.assert_called_with(given_start_time, given_end_time) diff --git 
a/tests/unit/commands/traces/test_trace_console_consumers.py b/tests/unit/commands/traces/test_trace_console_consumers.py new file mode 100644 index 0000000000..cb98885239 --- /dev/null +++ b/tests/unit/commands/traces/test_trace_console_consumers.py @@ -0,0 +1,14 @@ +from unittest import TestCase +from unittest.mock import patch, Mock + +from samcli.commands.traces.trace_console_consumers import XRayTraceConsoleConsumer + + +class TestTraceConsoleConsumers(TestCase): + @patch("samcli.commands.traces.trace_console_consumers.click") + def test_console_consumer(self, patched_click): + event = Mock() + consumer = XRayTraceConsoleConsumer() + consumer.consume(event) + + patched_click.echo.assert_called_with(event.message) diff --git a/tests/unit/commands/traces/test_traces_puller_factory.py b/tests/unit/commands/traces/test_traces_puller_factory.py new file mode 100644 index 0000000000..10c0ad6c26 --- /dev/null +++ b/tests/unit/commands/traces/test_traces_puller_factory.py @@ -0,0 +1,87 @@ +from unittest import TestCase +from unittest.mock import patch, Mock + +from parameterized import parameterized + +from samcli.commands.traces.traces_puller_factory import ( + generate_trace_puller, + generate_unformatted_xray_event_consumer, + generate_xray_event_console_consumer, +) + + +class TestGenerateTracePuller(TestCase): + @parameterized.expand( + [ + (False,), + (True,), + ] + ) + @patch("samcli.commands.traces.traces_puller_factory.generate_xray_event_console_consumer") + @patch("samcli.commands.traces.traces_puller_factory.generate_unformatted_xray_event_consumer") + @patch("samcli.commands.traces.traces_puller_factory.XRayTracePuller") + @patch("samcli.commands.traces.traces_puller_factory.XRayServiceGraphPuller") + @patch("samcli.commands.traces.traces_puller_factory.ObservabilityCombinedPuller") + def test_generate_trace_puller( + self, + unformatted, + patched_combine_puller, + patched_xray_service_graph_puller, + patched_xray_trace_puller, + patched_generate_unformatted_consumer, + patched_generate_console_consumer, + ): + given_xray_client = Mock() + given_xray_trace_puller = Mock() + given_xray_service_graph_puller = Mock() + given_combine_puller = Mock() + patched_xray_trace_puller.return_value = given_xray_trace_puller + patched_xray_service_graph_puller.return_value = given_xray_service_graph_puller + patched_combine_puller.return_value = given_combine_puller + + given_console_consumer = Mock() + patched_generate_console_consumer.return_value = given_console_consumer + + given_file_consumer = Mock() + patched_generate_unformatted_consumer.return_value = given_file_consumer + + actual_puller = generate_trace_puller(given_xray_client, unformatted) + self.assertEqual(given_combine_puller, actual_puller) + + if unformatted: + patched_generate_unformatted_consumer.assert_called_with() + patched_xray_trace_puller.assert_called_with(given_xray_client, given_file_consumer) + else: + patched_generate_console_consumer.assert_called_once() + patched_xray_trace_puller.assert_called_with(given_xray_client, given_console_consumer) + + @patch("samcli.commands.traces.traces_puller_factory.ObservabilityEventConsumerDecorator") + @patch("samcli.commands.traces.traces_puller_factory.XRayTraceJSONMapper") + @patch("samcli.commands.traces.traces_puller_factory.XRayTraceConsoleConsumer") + def test_generate_file_consumer(self, patched_consumer, patched_trace_json_mapper, patched_consumer_decorator): + given_consumer = Mock() + patched_consumer_decorator.return_value = given_consumer + + actual_consumer = 
generate_unformatted_xray_event_consumer() + self.assertEqual(given_consumer, actual_consumer) + + patched_trace_json_mapper.assert_called_once() + patched_consumer.assert_called_with() + + @patch("samcli.commands.traces.traces_puller_factory.ObservabilityEventConsumerDecorator") + @patch("samcli.commands.traces.traces_puller_factory.XRayTraceConsoleMapper") + @patch("samcli.commands.traces.traces_puller_factory.XRayTraceConsoleConsumer") + def test_generate_console_consumer( + self, + patched_console_consumer, + patched_console_mapper, + patched_consumer_decorator, + ): + given_consumer = Mock() + patched_consumer_decorator.return_value = given_consumer + + actual_consumer = generate_xray_event_console_consumer() + self.assertEqual(given_consumer, actual_consumer) + + patched_console_mapper.assert_called_once() + patched_console_consumer.assert_called_once() diff --git a/tests/unit/lib/bootstrap/nested_stack/test_nested_stack_builder.py b/tests/unit/lib/bootstrap/nested_stack/test_nested_stack_builder.py new file mode 100644 index 0000000000..fa799a9db6 --- /dev/null +++ b/tests/unit/lib/bootstrap/nested_stack/test_nested_stack_builder.py @@ -0,0 +1,78 @@ +from unittest import TestCase + +from samcli.lib.bootstrap.nested_stack.nested_stack_builder import NestedStackBuilder +from samcli.lib.providers.provider import Function +from samcli.lib.utils.resources import AWS_SERVERLESS_LAYERVERSION +from tests.unit.lib.build_module.test_build_graph import generate_function + + +class TestNestedStackBuilder(TestCase): + def setUp(self) -> None: + self.nested_stack_builder = NestedStackBuilder() + + def test_no_function_added(self): + self.assertFalse(self.nested_stack_builder.is_any_function_added()) + + def test_with_function_added(self): + function_runtime = "runtime" + stack_name = "stack_name" + function_logical_id = "FunctionLogicalId" + layer_contents_folder = "layer/contents/folder" + + function = generate_function(name=function_logical_id, runtime=function_runtime) + self.nested_stack_builder.add_function(stack_name, layer_contents_folder, function) + + self.assertTrue(self.nested_stack_builder.is_any_function_added()) + + nested_template = self.nested_stack_builder.build_as_dict() + resources = nested_template.get("Resources", {}) + outputs = nested_template.get("Outputs", {}) + + self.assertEqual(len(resources), 1) + self.assertEqual(len(outputs), 1) + + layer_logical_id = list(resources.keys())[0] + self.assertTrue(layer_logical_id.startswith(function_logical_id)) + self.assertTrue(layer_logical_id.endswith("DepLayer")) + + layer_resource = list(resources.values())[0] + self.assertEqual(layer_resource.get("Type"), AWS_SERVERLESS_LAYERVERSION) + + layer_properties = layer_resource.get("Properties", {}) + layer_name = layer_properties.get("LayerName") + self.assertTrue(layer_name.startswith(stack_name)) + self.assertIn(function_logical_id, layer_name) + self.assertTrue(layer_name.endswith("DepLayer")) + + self.assertEqual(layer_properties.get("ContentUri"), layer_contents_folder) + self.assertEqual(layer_properties.get("RetentionPolicy"), "Delete") + self.assertIn(function_runtime, layer_properties.get("CompatibleRuntimes")) + + layer_output_key = list(outputs.keys())[0] + self.assertTrue(layer_output_key.startswith(function_logical_id)) + self.assertTrue(layer_output_key.endswith("DepLayer")) + + layer_output = list(outputs.values())[0] + self.assertIn("Value", layer_output.keys()) + + layer_output_value = layer_output.get("Value") + self.assertIn("Ref", layer_output_value) + 
self.assertEqual(layer_output_value.get("Ref"), layer_logical_id) + + def test_get_layer_logical_id(self): + function_logical_id = "function_logical_id" + layer_logical_id = NestedStackBuilder.get_layer_logical_id(function_logical_id) + + self.assertTrue(layer_logical_id.startswith(function_logical_id[:48])) + self.assertTrue(layer_logical_id.endswith("DepLayer")) + self.assertLessEqual(len(layer_logical_id), 64) + + def test_get_layer_name(self): + function_logical_id = "function_logical_id" + stack_name = "function_logical_id" + layer_name = NestedStackBuilder.get_layer_name(stack_name, function_logical_id) + + self.assertTrue(layer_name.startswith(stack_name[:16])) + self.assertTrue(layer_name.endswith("DepLayer")) + self.assertIn(function_logical_id[:22], layer_name) + self.assertLessEqual(len(layer_name), 64) diff --git a/tests/unit/lib/bootstrap/nested_stack/test_nested_stack_manager.py b/tests/unit/lib/bootstrap/nested_stack/test_nested_stack_manager.py new file mode 100644 index 0000000000..89f61bc89d --- /dev/null +++ b/tests/unit/lib/bootstrap/nested_stack/test_nested_stack_manager.py @@ -0,0 +1,207 @@ +import os +from unittest import TestCase +from unittest.mock import Mock, patch, ANY, call + +from parameterized import parameterized + +from samcli.lib.bootstrap.nested_stack.nested_stack_manager import ( + NESTED_STACK_NAME, + NestedStackManager, +) +from samcli.lib.build.app_builder import ApplicationBuildResult +from samcli.lib.sync.exceptions import InvalidRuntimeDefinitionForFunction +from samcli.lib.utils import osutils +from samcli.lib.utils.osutils import BUILD_DIR_PERMISSIONS +from samcli.lib.utils.resources import AWS_SQS_QUEUE, AWS_SERVERLESS_FUNCTION + + +class TestNestedStackManager(TestCase): + def setUp(self) -> None: + self.stack_name = "stack_name" + self.build_dir = "build_dir" + self.stack_location = "stack_location" + + def test_nothing_to_add(self): + template = {} + app_build_result = ApplicationBuildResult(Mock(), {}) + nested_stack_manager = NestedStackManager( + self.stack_name, self.build_dir, self.stack_location, template, app_build_result + ) + result = nested_stack_manager.generate_auto_dependency_layer_stack() + + self.assertEqual(template, result) + + def test_unsupported_resource(self): + template = {"Resources": {"MySqsQueue": {"Type": AWS_SQS_QUEUE}}} + app_build_result = ApplicationBuildResult(Mock(), {}) + nested_stack_manager = NestedStackManager( + self.stack_name, self.build_dir, self.stack_location, template, app_build_result + ) + result = nested_stack_manager.generate_auto_dependency_layer_stack() + + self.assertEqual(template, result) + + def test_image_function(self): + template = { + "Resources": { + "MyFunction": { + "Type": AWS_SERVERLESS_FUNCTION, + "Properties": {"Runtime": "unsupported_runtime", "PackageType": "IMAGE"}, + } + } + } + app_build_result = ApplicationBuildResult(Mock(), {"MyFunction": "path/to/build/dir"}) + nested_stack_manager = NestedStackManager( + self.stack_name, self.build_dir, self.stack_location, template, app_build_result + ) + result = nested_stack_manager.generate_auto_dependency_layer_stack() + + self.assertEqual(template, result) + + def test_unsupported_runtime(self): + template = { + "Resources": { + "MyFunction": {"Type": AWS_SERVERLESS_FUNCTION, "Properties": {"Runtime": "unsupported_runtime"}} + } + } + app_build_result = ApplicationBuildResult(Mock(), {"MyFunction": "path/to/build/dir"}) + nested_stack_manager = NestedStackManager( + self.stack_name, self.build_dir, self.stack_location, template, 
app_build_result + ) + result = nested_stack_manager.generate_auto_dependency_layer_stack() + + self.assertEqual(template, result) + + @patch("samcli.lib.bootstrap.nested_stack.nested_stack_manager.osutils") + def test_no_dependencies_dir(self, patched_osutils): + template = { + "Resources": {"MyFunction": {"Type": AWS_SERVERLESS_FUNCTION, "Properties": {"Runtime": "python3.8"}}} + } + build_graph = Mock() + build_graph.get_function_build_definition_with_logical_id.return_value = None + app_build_result = ApplicationBuildResult(build_graph, {"MyFunction": "path/to/build/dir"}) + nested_stack_manager = NestedStackManager( + self.stack_name, self.build_dir, self.stack_location, template, app_build_result + ) + result = nested_stack_manager.generate_auto_dependency_layer_stack() + + self.assertEqual(template, result) + + @patch("samcli.lib.bootstrap.nested_stack.nested_stack_manager.move_template") + @patch("samcli.lib.bootstrap.nested_stack.nested_stack_manager.osutils") + @patch("samcli.lib.bootstrap.nested_stack.nested_stack_manager.os.path.isdir") + def test_with_zip_function(self, patched_isdir, patched_osutils, patched_move_template): + template = { + "Resources": {"MyFunction": {"Type": AWS_SERVERLESS_FUNCTION, "Properties": {"Runtime": "python3.8"}}} + } + + # prepare build graph + dependencies_dir = Mock() + function = Mock() + function.name = "MyFunction" + functions = [function] + build_graph = Mock() + function_definition_mock = Mock(dependencies_dir=dependencies_dir, functions=functions) + build_graph.get_function_build_definition_with_logical_id.return_value = function_definition_mock + app_build_result = ApplicationBuildResult(build_graph, {"MyFunction": "path/to/build/dir"}) + patched_isdir.return_value = True + + nested_stack_manager = NestedStackManager( + self.stack_name, self.build_dir, self.stack_location, template, app_build_result + ) + + with patch.object(nested_stack_manager, "_add_layer_readme_info") as patched_add_readme: + result = nested_stack_manager.generate_auto_dependency_layer_stack() + + patched_move_template.assert_called_with( + self.stack_location, os.path.join(self.build_dir, "nested_template.yaml"), ANY + ) + self.assertNotEqual(template, result) + + resources = result.get("Resources") + self.assertIn(NESTED_STACK_NAME, resources.keys()) + + self.assertTrue(resources.get("MyFunction", {}).get("Properties", {}).get("Layers", [])) + + def test_adding_readme_file(self): + with patch("builtins.open") as patched_open: + dependencies_dir = "dependencies" + function_name = "function_name" + NestedStackManager._add_layer_readme_info(dependencies_dir, function_name) + patched_open.assert_has_calls( + [ + call(os.path.join(dependencies_dir, "AWS_SAM_CLI_README"), "w+"), + call() + .__enter__() + .write( + f"This layer contains dependencies of function {function_name} and automatically added by AWS SAM CLI command 'sam sync'" + ), + ], + any_order=True, + ) + + def test_update_layer_folder_raise_exception_with_no_runtime(self): + with self.assertRaises(InvalidRuntimeDefinitionForFunction): + NestedStackManager.update_layer_folder(Mock(), Mock(), Mock(), Mock(), None) + + @patch("samcli.lib.bootstrap.nested_stack.nested_stack_manager.Path") + @patch("samcli.lib.bootstrap.nested_stack.nested_stack_manager.shutil") + @patch("samcli.lib.bootstrap.nested_stack.nested_stack_manager.osutils") + @patch("samcli.lib.bootstrap.nested_stack.nested_stack_manager.NestedStackManager._add_layer_readme_info") + 
@patch("samcli.lib.bootstrap.nested_stack.nested_stack_manager.os.path.isdir") + def test_update_layer_folder( + self, patched_isdir, patched_add_layer_readme, patched_osutils, patched_shutil, patched_path + ): + build_dir = "build_dir" + dependencies_dir = "dependencies_dir" + layer_logical_id = "layer_logical_id" + function_logical_id = "function_logical_id" + function_runtime = "python3.9" + + layer_contents_folder = Mock() + layer_root_folder = Mock() + layer_root_folder.exists.return_value = True + layer_root_folder.joinpath.return_value = layer_contents_folder + patched_path.return_value.joinpath.return_value = layer_root_folder + patched_isdir.return_value = True + + layer_folder = NestedStackManager.update_layer_folder( + build_dir, dependencies_dir, layer_logical_id, function_logical_id, function_runtime + ) + + patched_shutil.rmtree.assert_called_with(layer_root_folder) + layer_contents_folder.mkdir.assert_called_with(BUILD_DIR_PERMISSIONS, parents=True) + patched_osutils.copytree.assert_called_with(dependencies_dir, str(layer_contents_folder)) + patched_add_layer_readme.assert_called_with(str(layer_root_folder), function_logical_id) + self.assertEqual(layer_folder, str(layer_root_folder)) + + @patch("samcli.lib.bootstrap.nested_stack.nested_stack_manager.Path") + @patch("samcli.lib.bootstrap.nested_stack.nested_stack_manager.shutil") + @patch("samcli.lib.bootstrap.nested_stack.nested_stack_manager.osutils") + @patch("samcli.lib.bootstrap.nested_stack.nested_stack_manager.NestedStackManager._add_layer_readme_info") + @patch("samcli.lib.bootstrap.nested_stack.nested_stack_manager.os.path.isdir") + def test_skipping_dependency_copy_when_function_has_no_dependencies( + self, patched_isdir, patched_add_layer_readme, patched_osutils, patched_shutil, patched_path + ): + build_dir = "build_dir" + dependencies_dir = "dependencies_dir" + layer_logical_id = "layer_logical_id" + function_logical_id = "function_logical_id" + function_runtime = "python3.9" + + layer_contents_folder = Mock() + layer_root_folder = Mock() + layer_root_folder.exists.return_value = True + layer_root_folder.joinpath.return_value = layer_contents_folder + patched_path.return_value.joinpath.return_value = layer_root_folder + + patched_isdir.return_value = False + + NestedStackManager.update_layer_folder( + build_dir, dependencies_dir, layer_logical_id, function_logical_id, function_runtime + ) + patched_osutils.copytree.assert_not_called() + + @parameterized.expand([("python3.8", True), ("ruby2.7", False)]) + def test_is_runtime_supported(self, runtime, supported): + self.assertEqual(NestedStackManager.is_runtime_supported(runtime), supported) diff --git a/tests/unit/lib/build_module/test_app_builder.py b/tests/unit/lib/build_module/test_app_builder.py index a933ac925e..5e8c7752ec 100644 --- a/tests/unit/lib/build_module/test_app_builder.py +++ b/tests/unit/lib/build_module/test_app_builder.py @@ -89,6 +89,8 @@ def build_layer_return( layer_build_architecture, artifact_dir, layer_env_vars, + dependencies_dir, + download_dependencies, ): return f"{layer_name}_location" @@ -104,7 +106,7 @@ def build_layer_return( build_image_function_mock_return, ] - result = self.builder.build() + result = self.builder.build().artifacts self.maxDiff = None self.assertEqual( @@ -130,6 +132,8 @@ def build_layer_return( ANY, self.func1.metadata, ANY, + ANY, + True, ), call( self.func2.name, @@ -141,6 +145,8 @@ def build_layer_return( ANY, self.func2.metadata, ANY, + ANY, + True, ), call( self.imageFunc1.name, @@ -152,6 +158,8 @@ def 
build_layer_return( ANY, self.imageFunc1.metadata, ANY, + ANY, + True, ), ], any_order=False, @@ -167,6 +175,8 @@ def build_layer_return( self.layer1.build_architecture, ANY, ANY, + ANY, + True, ), call( self.layer2.name, @@ -176,6 +186,8 @@ def build_layer_return( self.layer2.build_architecture, ANY, ANY, + ANY, + True, ), ] ) @@ -183,10 +195,10 @@ def build_layer_return( @patch("samcli.lib.build.build_graph.BuildGraph._write") def test_should_use_function_or_layer_get_build_dir_to_determine_artifact_dir(self, persist_mock): def get_func_call_with_artifact_dir(artifact_dir): - return call(ANY, ANY, ANY, ANY, ANY, ANY, artifact_dir, ANY, ANY) + return call(ANY, ANY, ANY, ANY, ANY, ANY, artifact_dir, ANY, ANY, ANY, True) def get_layer_call_with_artifact_dir(artifact_dir): - return call(ANY, ANY, ANY, ANY, ANY, artifact_dir, ANY) + return call(ANY, ANY, ANY, ANY, ANY, artifact_dir, ANY, ANY, True) build_function_mock = Mock() build_layer_mock = Mock() @@ -256,7 +268,7 @@ def test_should_run_build_for_only_unique_builds(self, persist_mock, read_mock, function1_2.get_build_dir(build_dir), ] - result = builder.build() + result = builder.build().artifacts # result should contain all 3 functions as expected self.assertEqual( @@ -281,6 +293,8 @@ def test_should_run_build_for_only_unique_builds(self, persist_mock, read_mock, ANY, function1_1.metadata, ANY, + ANY, + True, ), call( function2.name, @@ -292,6 +306,8 @@ def test_should_run_build_for_only_unique_builds(self, persist_mock, read_mock, ANY, function2.metadata, ANY, + ANY, + True, ), ], any_order=True, @@ -308,15 +324,18 @@ def test_default_run_should_pick_default_strategy(self, mock_default_build_strat builder = ApplicationBuilder(Mock(), "builddir", "basedir", "cachedir") builder._get_build_graph = get_build_graph_mock - result = builder.build() + result = builder.build().artifacts mock_default_build_strategy.build.assert_called_once() self.assertEqual(result, mock_default_build_strategy.build()) - @patch("samcli.lib.build.app_builder.CachedBuildStrategy") - def test_cached_run_should_pick_cached_strategy(self, mock_cached_build_strategy_class): - mock_cached_build_strategy = Mock() - mock_cached_build_strategy_class.return_value = mock_cached_build_strategy + @patch("samcli.lib.build.app_builder.CachedOrIncrementalBuildStrategyWrapper") + def test_cached_run_should_pick_incremental_strategy( + self, + mock_cached_and_incremental_build_strategy_class, + ): + mock_cached_and_incremental_build_strategy = Mock() + mock_cached_and_incremental_build_strategy_class.return_value = mock_cached_and_incremental_build_strategy build_graph_mock = Mock() get_build_graph_mock = Mock(return_value=build_graph_mock) @@ -324,10 +343,10 @@ def test_cached_run_should_pick_cached_strategy(self, mock_cached_build_strategy builder = ApplicationBuilder(Mock(), "builddir", "basedir", "cachedir", cached=True) builder._get_build_graph = get_build_graph_mock - result = builder.build() + result = builder.build().artifacts - mock_cached_build_strategy.build.assert_called_once() - self.assertEqual(result, mock_cached_build_strategy.build()) + mock_cached_and_incremental_build_strategy.build.assert_called_once() + self.assertEqual(result, mock_cached_and_incremental_build_strategy.build()) @patch("samcli.lib.build.app_builder.ParallelBuildStrategy") def test_parallel_run_should_pick_parallel_strategy(self, mock_parallel_build_strategy_class): @@ -340,29 +359,32 @@ def test_parallel_run_should_pick_parallel_strategy(self, mock_parallel_build_st builder = 
ApplicationBuilder(Mock(), "builddir", "basedir", "cachedir", parallel=True) builder._get_build_graph = get_build_graph_mock - result = builder.build() + result = builder.build().artifacts mock_parallel_build_strategy.build.assert_called_once() self.assertEqual(result, mock_parallel_build_strategy.build()) @patch("samcli.lib.build.app_builder.ParallelBuildStrategy") - @patch("samcli.lib.build.app_builder.CachedBuildStrategy") - def test_parallel_and_cached_run_should_pick_parallel_with_cached_strategy( - self, mock_cached_build_strategy_class, mock_parallel_build_strategy_class + @patch("samcli.lib.build.app_builder.CachedOrIncrementalBuildStrategyWrapper") + def test_parallel_and_cached_run_should_pick_parallel_with_incremental( + self, + mock_cached_and_incremental_build_strategy_class, + mock_parallel_build_strategy_class, ): + mock_cached_and_incremental_build_strategy = Mock() + mock_cached_and_incremental_build_strategy_class.return_value = mock_cached_and_incremental_build_strategy mock_parallel_build_strategy = Mock() mock_parallel_build_strategy_class.return_value = mock_parallel_build_strategy - mock_cached_build_strategy = Mock() - mock_cached_build_strategy_class.return_value = mock_cached_build_strategy - build_graph_mock = Mock() get_build_graph_mock = Mock(return_value=build_graph_mock) - builder = ApplicationBuilder(Mock(), "builddir", "basedir", "cachedir", parallel=True) + builder = ApplicationBuilder(Mock(), "builddir", "basedir", "cachedir", parallel=True, cached=True) builder._get_build_graph = get_build_graph_mock - result = builder.build() + result = builder.build().artifacts + + mock_parallel_build_strategy_class.assert_called_once_with(ANY, mock_cached_and_incremental_build_strategy) mock_parallel_build_strategy.build.assert_called_once() self.assertEqual(result, mock_parallel_build_strategy.build()) @@ -455,6 +477,8 @@ def test_must_build_layer_in_process(self, get_layer_subfolder_mock, osutils_moc "python3.8", ARM64, None, + None, + True, ) @patch("samcli.lib.build.app_builder.get_workflow_config") @@ -950,7 +974,7 @@ def test_must_build_in_process(self, osutils_mock, get_workflow_config_mock): self.builder._build_function(function_name, codeuri, ZIP, runtime, architecture, handler, artifacts_dir) self.builder._build_function_in_process.assert_called_with( - config_mock, code_dir, artifacts_dir, scratch_dir, manifest_path, runtime, architecture, None + config_mock, code_dir, artifacts_dir, scratch_dir, manifest_path, runtime, architecture, None, None, True ) @patch("samcli.lib.build.app_builder.get_workflow_config") @@ -991,7 +1015,7 @@ def test_must_build_in_process_with_metadata(self, osutils_mock, get_workflow_co ) self.builder._build_function_in_process.assert_called_with( - config_mock, code_dir, artifacts_dir, scratch_dir, manifest_path, runtime, architecture, None + config_mock, code_dir, artifacts_dir, scratch_dir, manifest_path, runtime, architecture, None, None, True ) @patch("samcli.lib.build.app_builder.get_workflow_config") @@ -1145,7 +1169,16 @@ def test_must_use_lambda_builder(self, lambda_builder_mock): builder_instance_mock = lambda_builder_mock.return_value = Mock() result = self.builder._build_function_in_process( - config_mock, "source_dir", "artifacts_dir", "scratch_dir", "manifest_path", "runtime", X86_64, None + config_mock, + "source_dir", + "artifacts_dir", + "scratch_dir", + "manifest_path", + "runtime", + X86_64, + None, + None, + True, ) self.assertEqual(result, "artifacts_dir") @@ -1165,6 +1198,9 @@ def 
test_must_use_lambda_builder(self, lambda_builder_mock): mode="mode", options=None, architecture=X86_64, + dependencies_dir=None, + download_dependencies=True, + combine_dependencies=True, ) @patch("samcli.lib.build.app_builder.LambdaBuilder") @@ -1176,7 +1212,16 @@ def test_must_raise_on_error(self, lambda_builder_mock): with self.assertRaises(BuildError): self.builder._build_function_in_process( - config_mock, "source_dir", "artifacts_dir", "scratch_dir", "manifest_path", "runtime", X86_64, None + config_mock, + "source_dir", + "artifacts_dir", + "scratch_dir", + "manifest_path", + "runtime", + X86_64, + None, + None, + True, ) diff --git a/tests/unit/lib/build_module/test_build_graph.py b/tests/unit/lib/build_module/test_build_graph.py index 9065177941..8dd970aa9f 100644 --- a/tests/unit/lib/build_module/test_build_graph.py +++ b/tests/unit/lib/build_module/test_build_graph.py @@ -1,10 +1,12 @@ from unittest import TestCase +from unittest.mock import patch, Mock from uuid import uuid4 from pathlib import Path import tomlkit from samcli.lib.utils.architecture import X86_64, ARM64 from parameterized import parameterized +from typing import Dict, cast from samcli.lib.build.build_graph import ( FunctionBuildDefinition, @@ -15,7 +17,7 @@ PACKAGETYPE_FIELD, METADATA_FIELD, FUNCTIONS_FIELD, - SOURCE_MD5_FIELD, + SOURCE_HASH_FIELD, ENV_VARS_FIELD, LAYER_NAME_FIELD, BUILD_METHOD_FIELD, @@ -27,6 +29,8 @@ BuildGraph, InvalidBuildGraphException, LayerBuildDefinition, + MANIFEST_HASH_FIELD, + BuildHashingInformation, ) from samcli.lib.providers.provider import Function, LayerVersion from samcli.lib.utils import osutils @@ -98,7 +102,14 @@ def generate_layer( class TestConversionFunctions(TestCase): def test_function_build_definition_to_toml_table(self): build_definition = FunctionBuildDefinition( - "runtime", "codeuri", ZIP, X86_64, {"key": "value"}, "source_md5", env_vars={"env_vars": "value1"} + "runtime", + "codeuri", + ZIP, + X86_64, + {"key": "value"}, + "source_hash", + "manifest_hash", + env_vars={"env_vars": "value1"}, ) build_definition.add_function(generate_function()) @@ -108,13 +119,21 @@ def test_function_build_definition_to_toml_table(self): self.assertEqual(toml_table[RUNTIME_FIELD], build_definition.runtime) self.assertEqual(toml_table[METADATA_FIELD], build_definition.metadata) self.assertEqual(toml_table[FUNCTIONS_FIELD], [f.name for f in build_definition.functions]) - self.assertEqual(toml_table[SOURCE_MD5_FIELD], build_definition.source_md5) + self.assertEqual(toml_table[SOURCE_HASH_FIELD], build_definition.source_hash) + self.assertEqual(toml_table[MANIFEST_HASH_FIELD], build_definition.manifest_hash) self.assertEqual(toml_table[ENV_VARS_FIELD], build_definition.env_vars) self.assertEqual(toml_table[ARCHITECTURE_FIELD], build_definition.architecture) def test_layer_build_definition_to_toml_table(self): build_definition = LayerBuildDefinition( - "name", "codeuri", "method", "runtime", ARM64, env_vars={"env_vars": "value"} + "name", + "codeuri", + "method", + ["runtime"], + ARM64, + "source_hash", + "manifest_hash", + env_vars={"env_vars": "value"}, ) build_definition.layer = generate_function() @@ -125,7 +144,8 @@ def test_layer_build_definition_to_toml_table(self): self.assertEqual(toml_table[BUILD_METHOD_FIELD], build_definition.build_method) self.assertEqual(toml_table[COMPATIBLE_RUNTIMES_FIELD], build_definition.compatible_runtimes) self.assertEqual(toml_table[LAYER_FIELD], build_definition.layer.name) - self.assertEqual(toml_table[SOURCE_MD5_FIELD], 
build_definition.source_md5) + self.assertEqual(toml_table[SOURCE_HASH_FIELD], build_definition.source_hash) + self.assertEqual(toml_table[MANIFEST_HASH_FIELD], build_definition.manifest_hash) self.assertEqual(toml_table[ENV_VARS_FIELD], build_definition.env_vars) self.assertEqual(toml_table[ARCHITECTURE_FIELD], build_definition.architecture) @@ -136,7 +156,8 @@ def test_toml_table_to_function_build_definition(self): toml_table[PACKAGETYPE_FIELD] = ZIP toml_table[METADATA_FIELD] = {"key": "value"} toml_table[FUNCTIONS_FIELD] = ["function1"] - toml_table[SOURCE_MD5_FIELD] = "source_md5" + toml_table[SOURCE_HASH_FIELD] = "source_hash" + toml_table[MANIFEST_HASH_FIELD] = "manifest_hash" toml_table[ENV_VARS_FIELD] = {"env_vars": "value"} toml_table[ARCHITECTURE_FIELD] = X86_64 uuid = str(uuid4()) @@ -149,7 +170,8 @@ def test_toml_table_to_function_build_definition(self): self.assertEqual(build_definition.metadata, toml_table[METADATA_FIELD]) self.assertEqual(build_definition.uuid, uuid) self.assertEqual(build_definition.functions, []) - self.assertEqual(build_definition.source_md5, toml_table[SOURCE_MD5_FIELD]) + self.assertEqual(build_definition.source_hash, toml_table[SOURCE_HASH_FIELD]) + self.assertEqual(build_definition.manifest_hash, toml_table[MANIFEST_HASH_FIELD]) self.assertEqual(build_definition.env_vars, toml_table[ENV_VARS_FIELD]) self.assertEqual(build_definition.architecture, toml_table[ARCHITECTURE_FIELD]) @@ -160,7 +182,8 @@ def test_toml_table_to_layer_build_definition(self): toml_table[BUILD_METHOD_FIELD] = "method" toml_table[COMPATIBLE_RUNTIMES_FIELD] = "runtime" toml_table[COMPATIBLE_RUNTIMES_FIELD] = "layer1" - toml_table[SOURCE_MD5_FIELD] = "source_md5" + toml_table[SOURCE_HASH_FIELD] = "source_hash" + toml_table[MANIFEST_HASH_FIELD] = "manifest_hash" toml_table[ENV_VARS_FIELD] = {"env_vars": "value"} toml_table[ARCHITECTURE_FIELD] = ARM64 uuid = str(uuid4()) @@ -173,7 +196,8 @@ def test_toml_table_to_layer_build_definition(self): self.assertEqual(build_definition.uuid, uuid) self.assertEqual(build_definition.compatible_runtimes, toml_table[COMPATIBLE_RUNTIMES_FIELD]) self.assertEqual(build_definition.layer, None) - self.assertEqual(build_definition.source_md5, toml_table[SOURCE_MD5_FIELD]) + self.assertEqual(build_definition.source_hash, toml_table[SOURCE_HASH_FIELD]) + self.assertEqual(build_definition.manifest_hash, toml_table[MANIFEST_HASH_FIELD]) self.assertEqual(build_definition.env_vars, toml_table[ENV_VARS_FIELD]) self.assertEqual(build_definition.architecture, toml_table[ARCHITECTURE_FIELD]) @@ -187,7 +211,9 @@ def test_minimal_function_build_definition_to_toml_table(self): self.assertEqual(toml_table[RUNTIME_FIELD], build_definition.runtime) self.assertEqual(toml_table[METADATA_FIELD], build_definition.metadata) self.assertEqual(toml_table[FUNCTIONS_FIELD], [f.name for f in build_definition.functions]) - self.assertEqual(toml_table[SOURCE_MD5_FIELD], build_definition.source_md5) + if build_definition.source_hash: + self.assertEqual(toml_table[SOURCE_HASH_FIELD], build_definition.source_hash) + self.assertEqual(toml_table[MANIFEST_HASH_FIELD], build_definition.manifest_hash) self.assertEqual(toml_table[ARCHITECTURE_FIELD], build_definition.architecture) def test_minimal_layer_build_definition_to_toml_table(self): @@ -201,7 +227,9 @@ def test_minimal_layer_build_definition_to_toml_table(self): self.assertEqual(toml_table[BUILD_METHOD_FIELD], build_definition.build_method) self.assertEqual(toml_table[COMPATIBLE_RUNTIMES_FIELD], 
build_definition.compatible_runtimes) self.assertEqual(toml_table[LAYER_FIELD], build_definition.layer.name) - self.assertEqual(toml_table[SOURCE_MD5_FIELD], build_definition.source_md5) + if build_definition.source_hash: + self.assertEqual(toml_table[SOURCE_HASH_FIELD], build_definition.source_hash) + self.assertEqual(toml_table[MANIFEST_HASH_FIELD], build_definition.manifest_hash) self.assertEqual(toml_table[ARCHITECTURE_FIELD], build_definition.architecture) def test_minimal_toml_table_to_function_build_definition(self): @@ -219,7 +247,8 @@ def test_minimal_toml_table_to_function_build_definition(self): self.assertEqual(build_definition.metadata, {}) self.assertEqual(build_definition.uuid, uuid) self.assertEqual(build_definition.functions, []) - self.assertEqual(build_definition.source_md5, "") + self.assertEqual(build_definition.source_hash, "") + self.assertEqual(build_definition.manifest_hash, "") self.assertEqual(build_definition.env_vars, {}) self.assertEqual(build_definition.architecture, X86_64) @@ -239,7 +268,8 @@ def test_minimal_toml_table_to_layer_build_definition(self): self.assertEqual(build_definition.uuid, uuid) self.assertEqual(build_definition.compatible_runtimes, toml_table[COMPATIBLE_RUNTIMES_FIELD]) self.assertEqual(build_definition.layer, None) - self.assertEqual(build_definition.source_md5, "") + self.assertEqual(build_definition.source_hash, "") + self.assertEqual(build_definition.manifest_hash, "") self.assertEqual(build_definition.env_vars, {}) self.assertEqual(build_definition.architecture, X86_64) @@ -254,7 +284,8 @@ class TestBuildGraph(TestCase): METADATA = {"Test": "hello", "Test2": "world"} UUID = "3c1c254e-cd4b-4d94-8c74-7ab870b36063" LAYER_UUID = "7dnc257e-cd4b-4d94-8c74-7ab870b3abc3" - SOURCE_MD5 = "cae49aa393d669e850bd49869905099d" + SOURCE_HASH = "cae49aa393d669e850bd49869905099d" + MANIFEST_HASH = "rty87gh393d669e850bd49869905099e" ENV_VARS = {"env_vars": "value"} ARCHITECTURE_FIELD = ARM64 LAYER_ARCHITECTURE = X86_64 @@ -264,7 +295,8 @@ class TestBuildGraph(TestCase): [function_build_definitions.{UUID}] codeuri = "{CODEURI}" runtime = "{RUNTIME}" - source_md5 = "{SOURCE_MD5}" + source_hash = "{SOURCE_HASH}" + manifest_hash = "{MANIFEST_HASH}" packagetype = "{ZIP}" architecture = "{ARCHITECTURE_FIELD}" functions = ["HelloWorldPython", "HelloWorldPython2"] @@ -281,7 +313,8 @@ class TestBuildGraph(TestCase): build_method = "{LAYER_RUNTIME}" compatible_runtimes = ["{LAYER_RUNTIME}"] architecture = "{LAYER_ARCHITECTURE}" - source_md5 = "{SOURCE_MD5}" + source_hash = "{SOURCE_HASH}" + manifest_hash = "{MANIFEST_HASH}" layer = "SumLayer" [layer_build_definitions.{LAYER_UUID}.env_vars] env_vars = "{ENV_VARS['env_vars']}" @@ -314,7 +347,8 @@ def test_should_instantiate_first_time_and_update(self): TestBuildGraph.ZIP, TestBuildGraph.ARCHITECTURE_FIELD, TestBuildGraph.METADATA, - TestBuildGraph.SOURCE_MD5, + TestBuildGraph.SOURCE_HASH, + TestBuildGraph.MANIFEST_HASH, TestBuildGraph.ENV_VARS, ) function1 = generate_function( @@ -326,7 +360,8 @@ def test_should_instantiate_first_time_and_update(self): TestBuildGraph.LAYER_CODEURI, TestBuildGraph.LAYER_RUNTIME, [TestBuildGraph.LAYER_RUNTIME], - TestBuildGraph.SOURCE_MD5, + TestBuildGraph.SOURCE_HASH, + TestBuildGraph.MANIFEST_HASH, TestBuildGraph.ENV_VARS, ) layer1 = generate_layer( @@ -370,13 +405,16 @@ def test_should_read_existing_build_graph(self): self.assertEqual(function_build_definition.packagetype, TestBuildGraph.ZIP) self.assertEqual(function_build_definition.architecture, 
TestBuildGraph.ARCHITECTURE_FIELD) self.assertEqual(function_build_definition.metadata, TestBuildGraph.METADATA) - self.assertEqual(function_build_definition.source_md5, TestBuildGraph.SOURCE_MD5) + self.assertEqual(function_build_definition.source_hash, TestBuildGraph.SOURCE_HASH) + self.assertEqual(function_build_definition.manifest_hash, TestBuildGraph.MANIFEST_HASH) self.assertEqual(function_build_definition.env_vars, TestBuildGraph.ENV_VARS) for layer_build_definition in build_graph.get_layer_build_definitions(): self.assertEqual(layer_build_definition.name, TestBuildGraph.LAYER_NAME) self.assertEqual(layer_build_definition.codeuri, TestBuildGraph.LAYER_CODEURI) self.assertEqual(layer_build_definition.build_method, TestBuildGraph.LAYER_RUNTIME) + self.assertEqual(layer_build_definition.source_hash, TestBuildGraph.SOURCE_HASH) + self.assertEqual(layer_build_definition.manifest_hash, TestBuildGraph.MANIFEST_HASH) self.assertEqual(layer_build_definition.compatible_runtimes, [TestBuildGraph.LAYER_RUNTIME]) self.assertEqual(layer_build_definition.env_vars, TestBuildGraph.ENV_VARS) @@ -396,7 +434,8 @@ def test_functions_should_be_added_existing_build_graph(self): TestBuildGraph.ZIP, TestBuildGraph.ARCHITECTURE_FIELD, TestBuildGraph.METADATA, - TestBuildGraph.SOURCE_MD5, + TestBuildGraph.SOURCE_HASH, + TestBuildGraph.MANIFEST_HASH, TestBuildGraph.ENV_VARS, ) function1 = generate_function( @@ -418,7 +457,8 @@ def test_functions_should_be_added_existing_build_graph(self): TestBuildGraph.ZIP, ARM64, None, - "another_source_md5", + "another_source_hash", + "another_manifest_hash", {"env_vars": "value2"}, ) function2 = generate_function(name="another_function") @@ -445,7 +485,8 @@ def test_layers_should_be_added_existing_build_graph(self): TestBuildGraph.LAYER_RUNTIME, [TestBuildGraph.LAYER_RUNTIME], TestBuildGraph.LAYER_ARCHITECTURE, - TestBuildGraph.SOURCE_MD5, + TestBuildGraph.SOURCE_HASH, + TestBuildGraph.MANIFEST_HASH, TestBuildGraph.ENV_VARS, ) layer1 = generate_layer( @@ -465,7 +506,8 @@ def test_layers_should_be_added_existing_build_graph(self): "another_codeuri", "another_runtime", ["another_runtime"], - "another_source_md5", + "another_source_hash", + "another_manifest_hash", {"env_vars": "value2"}, ) layer2 = generate_layer(arn="arn:aws:lambda:region:account-id:layer:another-layer-name:1") @@ -475,11 +517,152 @@ def test_layers_should_be_added_existing_build_graph(self): self.assertEqual(len(build_definitions), 2) self.assertEqual(build_definitions[1].layer, layer2) + @patch("samcli.lib.build.build_graph.BuildGraph._write_source_hash") + @patch("samcli.lib.build.build_graph.BuildGraph._compare_hash_changes") + def test_update_definition_hash_should_succeed(self, compare_hash_mock, write_hash_mock): + compare_hash_mock.return_value = {"mock": "hash"} + with osutils.mkdir_temp() as temp_base_dir: + build_dir = Path(temp_base_dir, ".aws-sam", "build") + build_dir.mkdir(parents=True) + + build_graph_path = Path(build_dir.parent, "build.toml") + build_graph_path.write_text(TestBuildGraph.BUILD_GRAPH_CONTENTS) + + build_graph = BuildGraph(str(build_dir)) + build_graph.update_definition_hash() + write_hash_mock.assert_called_with({"mock": "hash"}, {"mock": "hash"}) + + def test_compare_hash_changes_should_succeed(self): + with osutils.mkdir_temp() as temp_base_dir: + build_dir = Path(temp_base_dir, ".aws-sam", "build") + build_dir.mkdir(parents=True) + + build_graph_path = Path(build_dir.parent, "build.toml") + build_graph_path.write_text(TestBuildGraph.BUILD_GRAPH_CONTENTS) + + build_graph = 
BuildGraph(str(build_dir)) + + build_definition = FunctionBuildDefinition( + TestBuildGraph.RUNTIME, + TestBuildGraph.CODEURI, + TestBuildGraph.ZIP, + TestBuildGraph.ARCHITECTURE_FIELD, + TestBuildGraph.METADATA, + TestBuildGraph.SOURCE_HASH, + TestBuildGraph.MANIFEST_HASH, + TestBuildGraph.ENV_VARS, + ) + updated_definition = FunctionBuildDefinition( + TestBuildGraph.RUNTIME, + TestBuildGraph.CODEURI, + TestBuildGraph.ZIP, + TestBuildGraph.ARCHITECTURE_FIELD, + TestBuildGraph.METADATA, + "new_value", + "new_manifest_value", + TestBuildGraph.ENV_VARS, + ) + updated_definition.uuid = build_definition.uuid + + layer_definition = LayerBuildDefinition( + TestBuildGraph.LAYER_NAME, + TestBuildGraph.LAYER_CODEURI, + TestBuildGraph.LAYER_RUNTIME, + [TestBuildGraph.LAYER_RUNTIME], + TestBuildGraph.ARCHITECTURE_FIELD, + TestBuildGraph.SOURCE_HASH, + TestBuildGraph.MANIFEST_HASH, + TestBuildGraph.ENV_VARS, + ) + updated_layer = LayerBuildDefinition( + TestBuildGraph.LAYER_NAME, + TestBuildGraph.LAYER_CODEURI, + TestBuildGraph.LAYER_RUNTIME, + [TestBuildGraph.LAYER_RUNTIME], + TestBuildGraph.ARCHITECTURE_FIELD, + "new_value", + "new_manifest_value", + TestBuildGraph.ENV_VARS, + ) + updated_layer.uuid = layer_definition.uuid + + build_graph._function_build_definitions = [build_definition] + build_graph._layer_build_definitions = [layer_definition] + + function_content = BuildGraph._compare_hash_changes( + [updated_definition], build_graph._function_build_definitions + ) + layer_content = BuildGraph._compare_hash_changes([updated_layer], build_graph._layer_build_definitions) + self.assertEqual(function_content, {build_definition.uuid: ("new_value", "new_manifest_value")}) + self.assertEqual(layer_content, {layer_definition.uuid: ("new_value", "new_manifest_value")}) + + @parameterized.expand( + [ + ("manifest_hash", "manifest_hash", False), + ("manifest_hash", "new_manifest_hash", True), + ] + ) + def test_compare_hash_changes_should_preserve_download_dependencies( + self, old_manifest, new_manifest, download_dependencies + ): + updated_definition = FunctionBuildDefinition("runtime", "codeuri", ZIP, X86_64, {}, manifest_hash=old_manifest) + existing_definition = FunctionBuildDefinition("runtime", "codeuri", ZIP, X86_64, {}, manifest_hash=new_manifest) + BuildGraph._compare_hash_changes([updated_definition], [existing_definition]) + self.assertEqual(existing_definition.download_dependencies, download_dependencies) + + def test_write_source_hash_should_succeed(self): + with osutils.mkdir_temp() as temp_base_dir: + build_dir = Path(temp_base_dir, ".aws-sam", "build") + build_dir.mkdir(parents=True) + + build_graph_path = Path(build_dir.parent, "build.toml") + build_graph_path.write_text(TestBuildGraph.BUILD_GRAPH_CONTENTS) + + build_graph = BuildGraph(str(build_dir)) + + build_graph._write_source_hash( + {TestBuildGraph.UUID: BuildHashingInformation("new_value", "new_manifest_value")}, + {TestBuildGraph.LAYER_UUID: BuildHashingInformation("new_value", "new_manifest_value")}, + ) + + txt = build_graph_path.read_text() + document = cast(Dict, tomlkit.loads(txt)) + + self.assertEqual( + document["function_build_definitions"][TestBuildGraph.UUID][SOURCE_HASH_FIELD], "new_value" + ) + self.assertEqual( + document["function_build_definitions"][TestBuildGraph.UUID][MANIFEST_HASH_FIELD], "new_manifest_value" + ) + self.assertEqual( + document["layer_build_definitions"][TestBuildGraph.LAYER_UUID][SOURCE_HASH_FIELD], "new_value" + ) + self.assertEqual( + 
document["layer_build_definitions"][TestBuildGraph.LAYER_UUID][MANIFEST_HASH_FIELD], + "new_manifest_value", + ) + + def test_empty_get_function_build_definition_with_logical_id(self): + build_graph = BuildGraph("build_dir") + self.assertIsNone(build_graph.get_function_build_definition_with_logical_id("function_logical_id")) + + def test_get_function_build_definition_with_logical_id(self): + build_graph = BuildGraph("build_dir") + logical_id = "function_logical_id" + function = Mock() + function.name = logical_id + function_build_definition = Mock(functions=[function]) + build_graph._function_build_definitions = [function_build_definition] + + self.assertEqual( + build_graph.get_function_build_definition_with_logical_id(logical_id), function_build_definition + ) + class TestBuildDefinition(TestCase): def test_single_function_should_return_function_and_handler_name(self): build_definition = FunctionBuildDefinition( - "runtime", "codeuri", ZIP, X86_64, "metadata", "source_md5", {"env_vars": "value"} + "runtime", "codeuri", ZIP, X86_64, "metadata", "source_hash", "manifest_hash", {"env_vars": "value"} ) build_definition.add_function(generate_function()) self.assertEqual(build_definition.get_handler_name(), "handler") @@ -487,24 +670,28 @@ def test_single_function_should_return_function_and_handler_name(self): def test_no_function_should_raise_exception(self): build_definition = FunctionBuildDefinition( - "runtime", "codeuri", ZIP, X86_64, "metadata", "source_md5", {"env_vars": "value"} + "runtime", "codeuri", ZIP, X86_64, "metadata", "source_hash", "manifest_hash", {"env_vars": "value"} ) self.assertRaises(InvalidBuildGraphException, build_definition.get_handler_name) self.assertRaises(InvalidBuildGraphException, build_definition.get_function_name) def test_same_runtime_codeuri_metadata_should_reflect_as_same_object(self): - build_definition1 = FunctionBuildDefinition("runtime", "codeuri", ZIP, ARM64, {"key": "value"}, "source_md5") - build_definition2 = FunctionBuildDefinition("runtime", "codeuri", ZIP, ARM64, {"key": "value"}, "source_md5") + build_definition1 = FunctionBuildDefinition( + "runtime", "codeuri", ZIP, ARM64, {"key": "value"}, "source_hash", "manifest_hash" + ) + build_definition2 = FunctionBuildDefinition( + "runtime", "codeuri", ZIP, ARM64, {"key": "value"}, "source_hash", "manifest_hash" + ) self.assertEqual(build_definition1, build_definition2) def test_same_env_vars_reflect_as_same_object(self): build_definition1 = FunctionBuildDefinition( - "runtime", "codeuri", ZIP, X86_64, {"key": "value"}, "source_md5", {"env_vars": "value"} + "runtime", "codeuri", ZIP, X86_64, {"key": "value"}, "source_hash", "manifest_hash", {"env_vars": "value"} ) build_definition2 = FunctionBuildDefinition( - "runtime", "codeuri", ZIP, X86_64, {"key": "value"}, "source_md5", {"env_vars": "value"} + "runtime", "codeuri", ZIP, X86_64, {"key": "value"}, "source_hash", "manifest_hash", {"env_vars": "value"} ) self.assertEqual(build_definition1, build_definition2) @@ -515,50 +702,50 @@ def test_same_env_vars_reflect_as_same_object(self): "runtime", "codeuri", ({"key": "value"}), - "source_md5", + "source_hash", "runtime", "codeuri", ({"key": "different_value"}), - "source_md5", + "source_hash", ), ( "runtime", "codeuri", ({"key": "value"}), - "source_md5", + "source_hash", "different_runtime", "codeuri", ({"key": "value"}), - "source_md5", + "source_hash", ), ( "runtime", "codeuri", ({"key": "value"}), - "source_md5", + "source_hash", "runtime", "different_codeuri", ({"key": "value"}), - "source_md5", + 
"source_hash", ), # custom build method with Makefile definition should always be identified as different ( "runtime", "codeuri", ({"BuildMethod": "makefile"}), - "source_md5", + "source_hash", "runtime", "codeuri", ({"BuildMethod": "makefile"}), - "source_md5", + "source_hash", ), ] ) def test_different_runtime_codeuri_metadata_should_not_reflect_as_same_object( - self, runtime1, codeuri1, metadata1, source_md5_1, runtime2, codeuri2, metadata2, source_md5_2 + self, runtime1, codeuri1, metadata1, source_hash_1, runtime2, codeuri2, metadata2, source_hash_2 ): - build_definition1 = FunctionBuildDefinition(runtime1, codeuri1, ZIP, ARM64, metadata1, source_md5_1) - build_definition2 = FunctionBuildDefinition(runtime2, codeuri2, ZIP, ARM64, metadata2, source_md5_2) + build_definition1 = FunctionBuildDefinition(runtime1, codeuri1, ZIP, ARM64, metadata1, source_hash_1) + build_definition2 = FunctionBuildDefinition(runtime2, codeuri2, ZIP, ARM64, metadata2, source_hash_2) self.assertNotEqual(build_definition1, build_definition2) @@ -574,21 +761,25 @@ def test_different_architecture_should_not_reflect_as_same_object(self): def test_different_env_vars_should_not_reflect_as_same_object(self): build_definition1 = FunctionBuildDefinition( - "runtime", "codeuri", ZIP, ARM64, {"key": "value"}, "source_md5", {"env_vars": "value1"} + "runtime", "codeuri", ZIP, ARM64, {"key": "value"}, "source_hash", "manifest_hash", {"env_vars": "value1"} ) build_definition2 = FunctionBuildDefinition( - "runtime", "codeuri", ZIP, ARM64, {"key": "value"}, "source_md5", {"env_vars": "value2"} + "runtime", "codeuri", ZIP, ARM64, {"key": "value"}, "source_hash", "manifest_hash", {"env_vars": "value2"} ) self.assertNotEqual(build_definition1, build_definition2) def test_euqality_with_another_object(self): - build_definition = FunctionBuildDefinition("runtime", "codeuri", ZIP, X86_64, None, "source_md5") + build_definition = FunctionBuildDefinition( + "runtime", "codeuri", ZIP, X86_64, None, "source_hash", "manifest_hash" + ) self.assertNotEqual(build_definition, {}) def test_str_representation(self): - build_definition = FunctionBuildDefinition("runtime", "codeuri", ZIP, ARM64, None, "source_md5") + build_definition = FunctionBuildDefinition( + "runtime", "codeuri", ZIP, ARM64, None, "source_hash", "manifest_hash" + ) self.assertEqual( str(build_definition), - f"BuildDefinition(runtime, codeuri, Zip, arm64, source_md5, {build_definition.uuid}, {{}}, {{}}, [])", + f"BuildDefinition(runtime, codeuri, Zip, source_hash, {build_definition.uuid}, {{}}, {{}}, arm64, [])", ) diff --git a/tests/unit/lib/build_module/test_build_strategy.py b/tests/unit/lib/build_module/test_build_strategy.py index 1e93a16aea..ea247d4264 100644 --- a/tests/unit/lib/build_module/test_build_strategy.py +++ b/tests/unit/lib/build_module/test_build_strategy.py @@ -1,15 +1,19 @@ +from copy import deepcopy from unittest import TestCase from unittest.mock import Mock, patch, MagicMock, call, ANY -from samcli.lib.utils.architecture import X86_64, ARM64 +from parameterized import parameterized -from samcli.commands.build.exceptions import MissingBuildMethodException +from samcli.lib.utils.architecture import X86_64, ARM64 +from samcli.lib.build.exceptions import MissingBuildMethodException from samcli.lib.build.build_graph import BuildGraph, FunctionBuildDefinition, LayerBuildDefinition from samcli.lib.build.build_strategy import ( ParallelBuildStrategy, BuildStrategy, DefaultBuildStrategy, CachedBuildStrategy, + CachedOrIncrementalBuildStrategyWrapper, + 
IncrementalBuildStrategy, ) from samcli.lib.utils import osutils from pathlib import Path @@ -158,6 +162,8 @@ def test_build_layers_and_functions(self, mock_copy_tree): self.function_build_definition1.get_build_dir(given_build_dir), self.function_build_definition1.metadata, self.function_build_definition1.env_vars, + self.function_build_definition1.dependencies_dir, + True, ), call( self.function_build_definition2.get_function_name(), @@ -169,6 +175,8 @@ def test_build_layers_and_functions(self, mock_copy_tree): self.function_build_definition2.get_build_dir(given_build_dir), self.function_build_definition2.metadata, self.function_build_definition2.env_vars, + self.function_build_definition2.dependencies_dir, + True, ), ] ) @@ -183,7 +191,9 @@ def test_build_layers_and_functions(self, mock_copy_tree): self.layer1.compatible_runtimes, self.layer1.build_architecture, self.layer1.get_build_dir(given_build_dir), - self.function_build_definition1.env_vars, + self.layer_build_definition1.env_vars, + self.layer_build_definition1.dependencies_dir, + True, ), call( self.layer2.name, @@ -192,7 +202,9 @@ def test_build_layers_and_functions(self, mock_copy_tree): self.layer2.compatible_runtimes, self.layer2.build_architecture, self.layer2.get_build_dir(given_build_dir), - self.function_build_definition2.env_vars, + self.layer_build_definition2.env_vars, + self.layer_build_definition2.dependencies_dir, + True, ), ] ) @@ -228,7 +240,11 @@ def test_build_single_function_definition_image_functions_with_same_metadata(sel # since they have the same metadata, they are put into the same build_definition. build_definition.functions = [function1, function2] - result = default_build_strategy.build_single_function_definition(build_definition) + with patch("samcli.lib.build.build_strategy.deepcopy", wraps=deepcopy) as patched_deepcopy: + result = default_build_strategy.build_single_function_definition(build_definition) + + patched_deepcopy.assert_called_with(build_definition.env_vars) + # both of the function name should show up in results self.assertEqual(result, {"Function": built_image, "Function2": built_image}) @@ -237,7 +253,7 @@ class CachedBuildStrategyTest(BuildStrategyBaseTest): CODEURI = "hello_world_python/" RUNTIME = "python3.8" FUNCTION_UUID = "3c1c254e-cd4b-4d94-8c74-7ab870b36063" - SOURCE_MD5 = "cae49aa393d669e850bd49869905099d" + SOURCE_HASH = "cae49aa393d669e850bd49869905099d" LAYER_UUID = "761ce752-d1c8-4e07-86a0-f64778cdd108" LAYER_METHOD = "nodejs12.x" @@ -247,7 +263,7 @@ class CachedBuildStrategyTest(BuildStrategyBaseTest): codeuri = "{CODEURI}" packagetype = "{ZIP}" runtime = "{RUNTIME}" - source_md5 = "{SOURCE_MD5}" + source_hash = "{SOURCE_HASH}" functions = ["HelloWorldPython", "HelloWorldPython2"] [layer_build_definitions] @@ -256,7 +272,7 @@ class CachedBuildStrategyTest(BuildStrategyBaseTest): codeuri = "sum_layer/" build_method = "nodejs12.x" compatible_runtimes = ["nodejs12.x"] - source_md5 = "{SOURCE_MD5}" + source_hash = "{SOURCE_HASH}" layer = "SumLayer" """ @@ -273,7 +289,7 @@ def test_build_call(self, mock_layer_build, mock_function_build, mock_rmtree, mo self.build_graph, given_build_dir, given_build_function, given_build_layer ) cache_build_strategy = CachedBuildStrategy( - self.build_graph, default_build_strategy, "base_dir", given_build_dir, "cache_dir", True + self.build_graph, default_build_strategy, "base_dir", given_build_dir, "cache_dir" ) cache_build_strategy.build() mock_function_build.assert_called() @@ -283,7 +299,6 @@ def test_build_call(self, mock_layer_build, 
mock_function_build, mock_rmtree, mo @patch("samcli.lib.build.build_strategy.pathlib.Path.exists") @patch("samcli.lib.build.build_strategy.dir_checksum") def test_if_cached_valid_when_build_single_function_definition(self, dir_checksum_mock, exists_mock, copytree_mock): - pass with osutils.mkdir_temp() as temp_base_dir: build_dir = Path(temp_base_dir, ".aws-sam", "build") build_dir.mkdir(parents=True) @@ -291,13 +306,13 @@ def test_if_cached_valid_when_build_single_function_definition(self, dir_checksu cache_dir.mkdir(parents=True) exists_mock.return_value = True - dir_checksum_mock.return_value = CachedBuildStrategyTest.SOURCE_MD5 + dir_checksum_mock.return_value = CachedBuildStrategyTest.SOURCE_HASH build_graph_path = Path(build_dir.parent, "build.toml") build_graph_path.write_text(CachedBuildStrategyTest.BUILD_GRAPH_CONTENTS) build_graph = BuildGraph(str(build_dir)) cached_build_strategy = CachedBuildStrategy( - build_graph, DefaultBuildStrategy, temp_base_dir, build_dir, cache_dir, True + build_graph, DefaultBuildStrategy, temp_base_dir, build_dir, cache_dir ) func1 = Mock() func1.name = "func1_name" @@ -336,7 +351,7 @@ def test_if_cached_invalid_with_no_cached_folder(self, build_layer_mock, build_f build_graph_path.write_text(CachedBuildStrategyTest.BUILD_GRAPH_CONTENTS) build_graph = BuildGraph(str(build_dir)) cached_build_strategy = CachedBuildStrategy( - build_graph, DefaultBuildStrategy, temp_base_dir, build_dir, cache_dir, True + build_graph, DefaultBuildStrategy, temp_base_dir, build_dir, cache_dir ) cached_build_strategy.build_single_function_definition(build_graph.get_function_build_definitions()[0]) cached_build_strategy.build_single_layer_definition(build_graph.get_layer_build_definitions()[0]) @@ -354,7 +369,7 @@ def test_redundant_cached_should_be_clean(self): redundant_cache_folder = Path(cache_dir, "redundant") redundant_cache_folder.mkdir(parents=True) - cached_build_strategy = CachedBuildStrategy(build_graph, Mock(), temp_base_dir, build_dir, cache_dir, True) + cached_build_strategy = CachedBuildStrategy(build_graph, Mock(), temp_base_dir, build_dir, cache_dir) cached_build_strategy._clean_redundant_cached() self.assertTrue(not redundant_cache_folder.exists()) @@ -435,3 +450,143 @@ def test_given_delegate_strategy_it_should_call_delegated_build_methods(self): call(self.layer_build_definition2), ] ) + + +@patch("samcli.lib.build.build_strategy.DependencyHashGenerator") +class TestIncrementalBuildStrategy(TestCase): + def setUp(self): + self.build_function = Mock() + self.build_layer = Mock() + self.build_graph = Mock() + self.delegate_build_strategy = DefaultBuildStrategy( + self.build_graph, Mock(), self.build_function, self.build_layer + ) + self.build_strategy = IncrementalBuildStrategy( + self.build_graph, + self.delegate_build_strategy, + Mock(), + Mock(), + ) + + def test_assert_incremental_build_function(self, patched_manifest_hash): + same_hash = "same_hash" + patched_manifest_hash_instance = Mock(hash=same_hash) + patched_manifest_hash.return_value = patched_manifest_hash_instance + + given_function_build_def = Mock(manifest_hash=same_hash, functions=[Mock()]) + self.build_graph.get_function_build_definitions.return_value = [given_function_build_def] + self.build_graph.get_layer_build_definitions.return_value = [] + + self.build_strategy.build() + self.build_function.assert_called_with(ANY, ANY, ANY, ANY, ANY, ANY, ANY, ANY, ANY, ANY, False) + + def test_assert_incremental_build_layer(self, patched_manifest_hash): + same_hash = "same_hash" + 
patched_manifest_hash_instance = Mock(hash=same_hash) + patched_manifest_hash.return_value = patched_manifest_hash_instance + + given_layer_build_def = Mock(manifest_hash=same_hash, functions=[Mock()]) + self.build_graph.get_function_build_definitions.return_value = [] + self.build_graph.get_layer_build_definitions.return_value = [given_layer_build_def] + + self.build_strategy.build() + self.build_layer.assert_called_with(ANY, ANY, ANY, ANY, ANY, ANY, ANY, ANY, False) + + +@patch("samcli.lib.build.build_graph.BuildGraph._write") +@patch("samcli.lib.build.build_graph.BuildGraph._read") +class TestCachedOrIncrementalBuildStrategyWrapper(TestCase): + def setUp(self) -> None: + self.build_graph = BuildGraph("build/graph/location") + + self.build_strategy = CachedOrIncrementalBuildStrategyWrapper( + self.build_graph, + Mock(), + "base_dir", + "build_dir", + "cache_dir", + "manifest_path_override", + False, + ) + + @parameterized.expand( + [ + ("python3.7", True), + ("nodejs12.x", True), + ("ruby2.7", True), + ("python3.7", False), + ] + ) + @patch("samcli.lib.build.build_strategy.is_experimental_enabled") + def test_will_call_incremental_build_strategy( + self, mocked_read, mocked_write, runtime, experimental_enabled, patched_experimental + ): + patched_experimental.return_value = experimental_enabled + build_definition = FunctionBuildDefinition(runtime, "codeuri", "package_type", X86_64, {}) + self.build_graph.put_function_build_definition(build_definition, Mock()) + with patch.object( + self.build_strategy, "_incremental_build_strategy" + ) as patched_incremental_build_strategy, patch.object( + self.build_strategy, "_cached_build_strategy" + ) as patched_cached_build_strategy: + self.build_strategy.build() + + if experimental_enabled: + patched_incremental_build_strategy.build_single_function_definition.assert_called_with(build_definition) + patched_cached_build_strategy.assert_not_called() + else: + patched_cached_build_strategy.build_single_function_definition.assert_called_with(build_definition) + patched_incremental_build_strategy.assert_not_called() + + @parameterized.expand( + [ + "dotnetcore2.1", + "go1.x", + "java11", + ] + ) + def test_will_call_cached_build_strategy(self, mocked_read, mocked_write, runtime): + build_definition = FunctionBuildDefinition(runtime, "codeuri", "package_type", X86_64, {}) + self.build_graph.put_function_build_definition(build_definition, Mock()) + with patch.object( + self.build_strategy, "_incremental_build_strategy" + ) as patched_incremental_build_strategy, patch.object( + self.build_strategy, "_cached_build_strategy" + ) as patched_cached_build_strategy: + self.build_strategy.build() + + patched_cached_build_strategy.build_single_function_definition.assert_called_with(build_definition) + patched_incremental_build_strategy.assert_not_called() + + @parameterized.expand([(True,), (False,)]) + @patch("samcli.lib.build.build_strategy.CachedBuildStrategy._clean_redundant_cached") + @patch("samcli.lib.build.build_strategy.IncrementalBuildStrategy._clean_redundant_dependencies") + def test_exit_build_strategy_for_specific_resource( + self, is_building_specific_resource, clean_cache_mock, clean_dep_mock, mocked_read, mocked_write + ): + with osutils.mkdir_temp() as temp_base_dir: + build_dir = Path(temp_base_dir, ".aws-sam", "build") + build_dir.mkdir(parents=True) + cache_dir = Path(temp_base_dir, ".aws-sam", "cache") + cache_dir.mkdir(parents=True) + + mocked_build_graph = Mock() + mocked_build_graph.get_layer_build_definitions.return_value = [] + 
mocked_build_graph.get_function_build_definitions.return_value = [] + + cached_build_strategy = CachedOrIncrementalBuildStrategyWrapper( + mocked_build_graph, Mock(), temp_base_dir, build_dir, cache_dir, None, is_building_specific_resource + ) + + cached_build_strategy.build() + + if is_building_specific_resource: + mocked_build_graph.update_definition_hash.assert_called_once() + mocked_build_graph.clean_redundant_definitions_and_update.assert_not_called() + clean_cache_mock.assert_not_called() + clean_dep_mock.assert_not_called() + else: + mocked_build_graph.update_definition_hash.assert_not_called() + mocked_build_graph.clean_redundant_definitions_and_update.assert_called_once() + clean_cache_mock.assert_called_once() + clean_dep_mock.assert_called_once() diff --git a/tests/unit/lib/build_module/test_dependency_hash_generator.py b/tests/unit/lib/build_module/test_dependency_hash_generator.py new file mode 100644 index 0000000000..b3cf3d1c41 --- /dev/null +++ b/tests/unit/lib/build_module/test_dependency_hash_generator.py @@ -0,0 +1,86 @@ +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from samcli.lib.build.dependency_hash_generator import DependencyHashGenerator + + +class TestDependencyHashGenerator(TestCase): + def setUp(self): + self.get_workflow_config_patch = patch("samcli.lib.build.dependency_hash_generator.get_workflow_config") + self.get_workflow_config_mock = self.get_workflow_config_patch.start() + self.get_workflow_config_mock.return_value.manifest_name = "manifest_file" + + self.file_checksum_patch = patch("samcli.lib.build.dependency_hash_generator.file_checksum") + self.file_checksum_mock = self.file_checksum_patch.start() + self.file_checksum_mock.return_value = "checksum" + + def tearDown(self): + self.get_workflow_config_patch.stop() + self.file_checksum_patch.stop() + + @patch("samcli.lib.build.dependency_hash_generator.DependencyHashGenerator._calculate_dependency_hash") + @patch("samcli.lib.build.dependency_hash_generator.pathlib.Path") + def test_init_and_properties(self, path_mock, calculate_hash_mock): + path_mock.return_value.resolve.return_value.__str__.return_value = "code_dir" + calculate_hash_mock.return_value = "dependency_hash" + self.generator = DependencyHashGenerator("code_uri", "base_dir", "runtime") + self.assertEqual(self.generator._code_uri, "code_uri") + self.assertEqual(self.generator._base_dir, "base_dir") + self.assertEqual(self.generator._code_dir, "code_dir") + self.assertEqual(self.generator._runtime, "runtime") + self.assertEqual(self.generator.hash, "dependency_hash") + + path_mock.assert_called_once_with("base_dir", "code_uri") + + @patch("samcli.lib.build.dependency_hash_generator.pathlib.Path") + def test_calculate_manifest_hash(self, path_mock): + code_dir_mock = MagicMock() + code_dir_mock.resolve.return_value.__str__.return_value = "code_dir" + manifest_path_mock = MagicMock() + manifest_path_mock.resolve.return_value.__str__.return_value = "manifest_path" + manifest_path_mock.resolve.return_value.is_file.return_value = True + path_mock.side_effect = [code_dir_mock, manifest_path_mock] + + self.generator = DependencyHashGenerator("code_uri", "base_dir", "runtime") + hash = self.generator.hash + self.file_checksum_mock.assert_called_once_with("manifest_path", hash_generator=None) + self.assertEqual(hash, "checksum") + + path_mock.assert_any_call("base_dir", "code_uri") + path_mock.assert_any_call("code_dir", "manifest_file") + + @patch("samcli.lib.build.dependency_hash_generator.pathlib.Path") + def 
test_calculate_manifest_hash_missing_file(self, path_mock): + code_dir_mock = MagicMock() + code_dir_mock.resolve.return_value.__str__.return_value = "code_dir" + manifest_path_mock = MagicMock() + manifest_path_mock.resolve.return_value.__str__.return_value = "manifest_path" + manifest_path_mock.resolve.return_value.is_file.return_value = False + path_mock.side_effect = [code_dir_mock, manifest_path_mock] + + self.generator = DependencyHashGenerator("code_uri", "base_dir", "runtime") + self.file_checksum_mock.assert_not_called() + self.assertEqual(self.generator.hash, None) + + path_mock.assert_any_call("base_dir", "code_uri") + path_mock.assert_any_call("code_dir", "manifest_file") + + @patch("samcli.lib.build.dependency_hash_generator.pathlib.Path") + def test_calculate_manifest_hash_manifest_override(self, path_mock): + code_dir_mock = MagicMock() + code_dir_mock.resolve.return_value.__str__.return_value = "code_dir" + manifest_path_mock = MagicMock() + manifest_path_mock.resolve.return_value.__str__.return_value = "manifest_path" + manifest_path_mock.resolve.return_value.is_file.return_value = True + path_mock.side_effect = [code_dir_mock, manifest_path_mock] + + self.generator = DependencyHashGenerator( + "code_uri", "base_dir", "runtime", manifest_path_override="manifest_override" + ) + hash = self.generator.hash + self.get_workflow_config_mock.assert_not_called() + self.file_checksum_mock.assert_called_once_with("manifest_path", hash_generator=None) + self.assertEqual(hash, "checksum") + + path_mock.assert_any_call("base_dir", "code_uri") + path_mock.assert_any_call("code_dir", "manifest_override") diff --git a/tests/unit/lib/deploy/test_deployer.py b/tests/unit/lib/deploy/test_deployer.py index 667e52df96..72d5da3b10 100644 --- a/tests/unit/lib/deploy/test_deployer.py +++ b/tests/unit/lib/deploy/test_deployer.py @@ -1,3 +1,4 @@ +from logging import captureWarnings import uuid import time from datetime import datetime, timedelta @@ -754,3 +755,107 @@ def test_wait_for_execute_with_outputs(self, patched_time): self.deployer.get_stack_outputs = MagicMock(return_value=outputs["Stacks"][0]["Outputs"]) self.deployer.wait_for_execute("test", "CREATE", False) self.assertEqual(self.deployer._display_stack_outputs.call_count, 1) + + def test_sync_update_stack(self): + self.deployer.has_stack = MagicMock(return_value=True) + self.deployer.wait_for_execute = MagicMock() + self.deployer.sync( + stack_name="test", + cfn_template=" ", + parameter_values=[ + {"ParameterKey": "a", "ParameterValue": "b"}, + ], + capabilities=["CAPABILITY_IAM"], + role_arn="role-arn", + notification_arns=[], + s3_uploader=S3Uploader(s3_client=self.s3_client, bucket_name="test_bucket"), + tags={"unit": "true"}, + ) + + self.assertEqual(self.deployer._client.update_stack.call_count, 1) + self.deployer._client.update_stack.assert_called_with( + Capabilities=["CAPABILITY_IAM"], + NotificationARNs=[], + Parameters=[{"ParameterKey": "a", "ParameterValue": "b"}], + RoleARN="role-arn", + StackName="test", + Tags={"unit": "true"}, + TemplateURL=ANY, + ) + + def test_sync_update_stack_exception(self): + self.deployer.has_stack = MagicMock(return_value=True) + self.deployer.wait_for_execute = MagicMock() + self.deployer._client.update_stack = MagicMock(side_effect=Exception) + with self.assertRaises(DeployFailedError): + self.deployer.sync( + stack_name="test", + cfn_template=" ", + parameter_values=[ + {"ParameterKey": "a", "ParameterValue": "b"}, + ], + capabilities=["CAPABILITY_IAM"], + role_arn="role-arn", + 
notification_arns=[], + s3_uploader=S3Uploader(s3_client=self.s3_client, bucket_name="test_bucket"), + tags={"unit": "true"}, + ) + + def test_sync_create_stack(self): + self.deployer.has_stack = MagicMock(return_value=False) + self.deployer.wait_for_execute = MagicMock() + self.deployer.sync( + stack_name="test", + cfn_template=" ", + parameter_values=[ + {"ParameterKey": "a", "ParameterValue": "b"}, + ], + capabilities=["CAPABILITY_IAM"], + role_arn="role-arn", + notification_arns=[], + s3_uploader=S3Uploader(s3_client=self.s3_client, bucket_name="test_bucket"), + tags={"unit": "true"}, + ) + + self.assertEqual(self.deployer._client.create_stack.call_count, 1) + self.deployer._client.create_stack.assert_called_with( + Capabilities=["CAPABILITY_IAM"], + NotificationARNs=[], + Parameters=[{"ParameterKey": "a", "ParameterValue": "b"}], + RoleARN="role-arn", + StackName="test", + Tags={"unit": "true"}, + TemplateURL=ANY, + ) + + def test_sync_create_stack_exception(self): + self.deployer.has_stack = MagicMock(return_value=False) + self.deployer.wait_for_execute = MagicMock() + self.deployer._client.create_stack = MagicMock(side_effect=Exception) + with self.assertRaises(DeployFailedError): + self.deployer.sync( + stack_name="test", + cfn_template=" ", + parameter_values=[ + {"ParameterKey": "a", "ParameterValue": "b"}, + ], + capabilities=["CAPABILITY_IAM"], + role_arn="role-arn", + notification_arns=[], + s3_uploader=S3Uploader(s3_client=self.s3_client, bucket_name="test_bucket"), + tags={"unit": "true"}, + ) + + def test_process_kwargs(self): + kwargs = {"Capabilities": []} + capabilities = ["CAPABILITY_IAM"] + role_arn = "role-arn" + notification_arns = ["arn"] + + expected = { + "Capabilities": ["CAPABILITY_IAM"], + "RoleARN": "role-arn", + "NotificationARNs": ["arn"], + } + result = self.deployer._process_kwargs(kwargs, None, capabilities, role_arn, notification_arns) + self.assertEqual(expected, result) diff --git a/tests/unit/lib/observability/cw_logs/test_cw_log_formatters.py b/tests/unit/lib/observability/cw_logs/test_cw_log_formatters.py index f864ff1fe7..652615a81e 100644 --- a/tests/unit/lib/observability/cw_logs/test_cw_log_formatters.py +++ b/tests/unit/lib/observability/cw_logs/test_cw_log_formatters.py @@ -10,6 +10,8 @@ CWColorizeErrorsFormatter, CWKeywordHighlighterFormatter, CWJsonFormatter, + CWAddNewLineIfItDoesntExist, + CWLogEventJSONMapper, ) @@ -118,3 +120,47 @@ def test_ignore_non_json(self, input_msg): result = self.formatter.map(event) self.assertEqual(result.message, input_msg) + + +class TestCWAddNewLineIfItDoesntExist(TestCase): + def setUp(self) -> None: + self.formatter = CWAddNewLineIfItDoesntExist() + + @parameterized.expand( + [ + (CWLogEvent("log_group", {"message": "input"}),), + (CWLogEvent("log_group", {"message": "input\n"}),), + ] + ) + def test_cw_log_event(self, log_event): + mapped_event = self.formatter.map(log_event) + self.assertEqual(mapped_event.message, "input\n") + + @parameterized.expand( + [ + ("input",), + ("input\n",), + ] + ) + def test_str_event(self, str_event): + mapped_event = self.formatter.map(str_event) + self.assertEqual(mapped_event, "input\n") + + @parameterized.expand( + [ + ({"some": "dict"},), + (5,), + ] + ) + def test_other_events(self, event): + mapped_event = self.formatter.map(event) + self.assertEqual(mapped_event, event) + + +class TestCWLogEventJSONMapper(TestCase): + def test_mapper(self): + given_event = CWLogEvent("log_group", {"message": "input"}) + mapper = CWLogEventJSONMapper() + + mapped_event = 
mapper.map(given_event) + self.assertEqual(mapped_event.message, json.dumps(given_event.event)) diff --git a/tests/unit/lib/observability/cw_logs/test_cw_log_group_provider.py b/tests/unit/lib/observability/cw_logs/test_cw_log_group_provider.py index 295ad6d898..893fa501a5 100644 --- a/tests/unit/lib/observability/cw_logs/test_cw_log_group_provider.py +++ b/tests/unit/lib/observability/cw_logs/test_cw_log_group_provider.py @@ -1,11 +1,72 @@ from unittest import TestCase +from unittest.mock import Mock, ANY, patch +from samcli.commands._utils.experimental import set_experimental, ExperimentalFlag from samcli.lib.observability.cw_logs.cw_log_group_provider import LogGroupProvider +@patch("samcli.commands._utils.experimental._update_experimental_context") class TestLogGroupProvider_for_lambda_function(TestCase): - def test_must_return_log_group_name(self): - expected = "/aws/lambda/myfunctionname" - result = LogGroupProvider.for_lambda_function("myfunctionname") + def setUp(self) -> None: + set_experimental(config_entry=ExperimentalFlag.Accelerate, enabled=True) + + def test_must_return_log_group_name(self, patched_update_experimental_context): + expected = "/aws/lambda/my_function_name" + result = LogGroupProvider.for_lambda_function("my_function_name") + + self.assertEqual(expected, result) + + def test_rest_api_log_group_name(self, patched_update_experimental_context): + expected = "API-Gateway-Execution-Logs_my_function_name/Prod" + result = LogGroupProvider.for_resource(Mock(), "AWS::ApiGateway::RestApi", "my_function_name") self.assertEqual(expected, result) + + def test_http_api_log_group_name(self, patched_update_experimental_context): + given_client_provider = Mock() + given_client_provider(ANY).get_stage.return_value = { + "AccessLogSettings": {"DestinationArn": "test:my_log_group"} + } + expected = "my_log_group" + result = LogGroupProvider.for_resource(given_client_provider, "AWS::ApiGatewayV2::Api", "my_function_name") + + self.assertEqual(expected, result) + + def test_http_api_log_group_name_not_exist(self, patched_update_experimental_context): + given_client_provider = Mock() + given_client_provider(ANY).get_stage.return_value = {} + result = LogGroupProvider.for_resource(given_client_provider, "AWS::ApiGatewayV2::Api", "my_function_name") + + self.assertIsNone(result) + + def test_step_functions(self, patched_update_experimental_context): + given_client_provider = Mock() + given_cw_log_group_name = "sam-app-logs-command-test-MyStateMachineLogGroup-ucwMaQpNBJTD" + given_client_provider(ANY).describe_state_machine.return_value = { + "loggingConfiguration": { + "destinations": [ + { + "cloudWatchLogsLogGroup": { + "logGroupArn": f"arn:aws:logs:us-west-2:694866504768:log-group:{given_cw_log_group_name}:*" + } + } + ] + } + } + + result = LogGroupProvider.for_resource( + given_client_provider, "AWS::StepFunctions::StateMachine", "my_state_machine" + ) + + self.assertIsNotNone(result) + self.assertEqual(result, given_cw_log_group_name) + + def test_invalid_step_functions(self, patched_update_experimental_context): + given_client_provider = Mock() + given_client_provider(ANY).describe_state_machine.return_value = {"loggingConfiguration": {"destinations": []}} + + result = LogGroupProvider.for_resource( + given_client_provider, "AWS::StepFunctions::StateMachine", "my_state_machine" + ) + + self.assertIsNone(result) diff --git a/tests/unit/lib/observability/cw_logs/test_cw_log_puller.py b/tests/unit/lib/observability/cw_logs/test_cw_log_puller.py index 98f4e6d3de..d71e2585b4 100644 
--- a/tests/unit/lib/observability/cw_logs/test_cw_log_puller.py +++ b/tests/unit/lib/observability/cw_logs/test_cw_log_puller.py @@ -100,6 +100,37 @@ def test_must_fetch_logs_with_all_params(self): for event in self.expected_events: self.assertIn(event, call_args) + @patch("samcli.lib.observability.cw_logs.cw_log_puller.LOG") + def test_must_print_resource_not_found_only_once(self, patched_log): + pattern = "foobar" + start = datetime.utcnow() + end = datetime.utcnow() + + expected_params = { + "logGroupName": self.log_group_name, + "interleaved": True, + "startTime": to_timestamp(start), + "endTime": to_timestamp(end), + "filterPattern": pattern, + } + + self.client_stubber.add_client_error( + "filter_log_events", expected_params=expected_params, service_error_code="ResourceNotFoundException" + ) + self.client_stubber.add_client_error( + "filter_log_events", expected_params=expected_params, service_error_code="ResourceNotFoundException" + ) + self.client_stubber.add_response("filter_log_events", self.mock_api_response, expected_params) + + with self.client_stubber: + self.assertFalse(self.fetcher._invalid_log_group) + self.fetcher.load_time_period(start_time=start, end_time=end, filter_pattern=pattern) + self.assertTrue(self.fetcher._invalid_log_group) + self.fetcher.load_time_period(start_time=start, end_time=end, filter_pattern=pattern) + self.assertTrue(self.fetcher._invalid_log_group) + self.fetcher.load_time_period(start_time=start, end_time=end, filter_pattern=pattern) + self.assertFalse(self.fetcher._invalid_log_group) + def test_must_paginate_using_next_token(self): """Make three API calls, first two returns a nextToken and last does not.""" token = "token" @@ -320,3 +351,31 @@ def test_without_start_time(self, time_mock): self.assertEqual([], expected_consumer_call_args) self.assertEqual(expected_load_time_period_calls, patched_load_time_period.call_args_list) self.assertEqual(expected_sleep_calls, time_mock.sleep.call_args_list) + + @patch("samcli.lib.observability.cw_logs.cw_log_puller.time") + def test_with_throttling(self, time_mock): + expected_params = { + "logGroupName": self.log_group_name, + "interleaved": True, + "startTime": 0, + "filterPattern": self.filter_pattern, + } + + for _ in range(self.max_retries): + self.client_stubber.add_client_error( + "filter_log_events", expected_params=expected_params, service_error_code="ThrottlingException" + ) + + expected_load_time_period_calls = [call(to_datetime(0), filter_pattern=ANY) for _ in range(self.max_retries)] + + expected_time_calls = [call(2), call(4), call(16)] + + with patch.object( + self.fetcher, "load_time_period", wraps=self.fetcher.load_time_period + ) as patched_load_time_period: + with self.client_stubber: + self.fetcher.tail(filter_pattern=self.filter_pattern) + + self.consumer.consume.assert_not_called() + self.assertEqual(expected_load_time_period_calls, patched_load_time_period.call_args_list) + time_mock.sleep.assert_has_calls(expected_time_calls, any_order=True) diff --git a/tests/unit/lib/observability/test_observability_info_puller.py b/tests/unit/lib/observability/test_observability_info_puller.py index 3fbbb9fe34..2b3b6b2016 100644 --- a/tests/unit/lib/observability/test_observability_info_puller.py +++ b/tests/unit/lib/observability/test_observability_info_puller.py @@ -1,9 +1,12 @@ from unittest import TestCase -from unittest.mock import Mock +from unittest.mock import Mock, patch, call from parameterized import parameterized, param -from samcli.lib.observability.observability_info_puller import 
ObservabilityEventConsumerDecorator +from samcli.lib.observability.observability_info_puller import ( + ObservabilityEventConsumerDecorator, + ObservabilityCombinedPuller, +) class TestObservabilityEventConsumerDecorator(TestCase): @@ -48,3 +51,105 @@ def test_decorator_with_mappers(self, mappers): actual_consumer.consume.assert_called_with(event) for mapper in mappers: mapper.map.assert_called_with(event) + + +class TestObservabilityCombinedPuller(TestCase): + @patch("samcli.lib.observability.observability_info_puller.AsyncContext") + def test_tail(self, patched_async_context): + mocked_async_context = Mock() + patched_async_context.return_value = mocked_async_context + + mock_puller_1 = Mock() + mock_puller_2 = Mock() + + combined_puller = ObservabilityCombinedPuller([mock_puller_1, mock_puller_2]) + + given_start_time = Mock() + given_filter_pattern = Mock() + combined_puller.tail(given_start_time, given_filter_pattern) + + patched_async_context.assert_called_once() + mocked_async_context.assert_has_calls( + [ + call.add_async_task(mock_puller_1.tail, given_start_time, given_filter_pattern), + call.add_async_task(mock_puller_2.tail, given_start_time, given_filter_pattern), + call.run_async(), + ] + ) + + @patch("samcli.lib.observability.observability_info_puller.AsyncContext") + def test_tail_cancel(self, patched_async_context): + mocked_async_context = Mock() + mocked_async_context.run_async.side_effect = KeyboardInterrupt() + patched_async_context.return_value = mocked_async_context + + mock_puller_1 = Mock() + mock_puller_2 = Mock() + + combined_puller = ObservabilityCombinedPuller([mock_puller_1, mock_puller_2]) + + given_start_time = Mock() + given_filter_pattern = Mock() + combined_puller.tail(given_start_time, given_filter_pattern) + + patched_async_context.assert_called_once() + mocked_async_context.assert_has_calls( + [ + call.add_async_task(mock_puller_1.tail, given_start_time, given_filter_pattern), + call.add_async_task(mock_puller_2.tail, given_start_time, given_filter_pattern), + call.run_async(), + ] + ) + + self.assertTrue(mock_puller_1.cancelled) + self.assertTrue(mock_puller_2.cancelled) + + @patch("samcli.lib.observability.observability_info_puller.AsyncContext") + def test_load_time_period(self, patched_async_context): + mocked_async_context = Mock() + patched_async_context.return_value = mocked_async_context + + mock_puller_1 = Mock() + mock_puller_2 = Mock() + + combined_puller = ObservabilityCombinedPuller([mock_puller_1, mock_puller_2]) + + given_start_time = Mock() + given_end_time = Mock() + given_filter_pattern = Mock() + combined_puller.load_time_period(given_start_time, given_end_time, given_filter_pattern) + + patched_async_context.assert_called_once() + mocked_async_context.assert_has_calls( + [ + call.add_async_task( + mock_puller_1.load_time_period, given_start_time, given_end_time, given_filter_pattern + ), + call.add_async_task( + mock_puller_2.load_time_period, given_start_time, given_end_time, given_filter_pattern + ), + call.run_async(), + ] + ) + + @patch("samcli.lib.observability.observability_info_puller.AsyncContext") + def test_load_events(self, patched_async_context): + mocked_async_context = Mock() + patched_async_context.return_value = mocked_async_context + + mock_puller_1 = Mock() + mock_puller_2 = Mock() + + combined_puller = ObservabilityCombinedPuller([mock_puller_1, mock_puller_2]) + + given_event_ids = [Mock(), Mock()] + combined_puller.load_events(given_event_ids) + + patched_async_context.assert_called_once() + 
mocked_async_context.assert_has_calls( + [ + call.add_async_task(mock_puller_1.load_events, given_event_ids), + call.add_async_task(mock_puller_2.load_events, given_event_ids), + call.run_async(), + ] + ) diff --git a/tests/unit/lib/observability/xray_traces/test_xray_event_mappers.py b/tests/unit/lib/observability/xray_traces/test_xray_event_mappers.py new file mode 100644 index 0000000000..6abd276c74 --- /dev/null +++ b/tests/unit/lib/observability/xray_traces/test_xray_event_mappers.py @@ -0,0 +1,203 @@ +import json +import time +import uuid +import logging +from datetime import datetime, timezone +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from samcli.lib.observability.xray_traces.xray_event_mappers import ( + XRayTraceConsoleMapper, + XRayTraceJSONMapper, + XRayServiceGraphConsoleMapper, + XRayServiceGraphJSONMapper, +) +from samcli.lib.observability.xray_traces.xray_events import XRayTraceEvent, XRayServiceGraphEvent +from samcli.lib.utils.time import to_utc, utc_to_timestamp, timestamp_to_iso + +LOG = logging.getLogger() +logging.basicConfig() + + +class AbstractXRayTraceMapperTest(TestCase): + def setUp(self): + self.trace_event = XRayTraceEvent( + { + "Id": str(uuid.uuid4()), + "Duration": 2.1, + "Segments": [ + { + "Id": str(uuid.uuid4()), + "Document": json.dumps( + { + "name": str(uuid.uuid4()), + "start_time": 1634603579.27, # 2021-10-18T17:32:59.270000 + "end_time": time.time(), + "http": {"response": {"status": 200}}, + } + ), + }, + { + "Id": str(uuid.uuid4()), + "Document": json.dumps( + { + "name": str(uuid.uuid4()), + "start_time": time.time(), + "end_time": time.time(), + "http": {"response": {"status": 200}}, + "subsegments": [ + { + "Id": str(uuid.uuid4()), + "name": str(uuid.uuid4()), + "start_time": time.time(), + "end_time": time.time(), + "http": {"response": {"status": 200}}, + } + ], + } + ), + }, + ], + } + ) + + +class TestXRayTraceConsoleMapper(AbstractXRayTraceMapperTest): + def test_console_mapper(self): + with patch("samcli.lib.observability.xray_traces.xray_event_mappers.datetime") as fromtimestamp_mock: + fromtimestamp_mock.side_effect = lambda *args, **kw: datetime(*args, **kw) + fromtimestamp_mock.fromtimestamp.return_value = datetime(2021, 10, 18, 17, 32, 59, 270000) + + console_mapper = XRayTraceConsoleMapper() + mapped_event = console_mapper.map(self.trace_event) + + self.assertTrue(isinstance(mapped_event, XRayTraceEvent)) + + event_timestamp = "2021-10-18T17:32:59.270000" + LOG.info(mapped_event.message) + self.assertTrue( + f"XRay Event at ({event_timestamp}) with id ({self.trace_event.id}) and duration ({self.trace_event.duration:.3f}s)" + in mapped_event.message + ) + + self.validate_segments(self.trace_event.segments, mapped_event.message) + + def validate_segments(self, segments, message): + for segment in segments: + + if segment.http_status: + self.assertTrue( + f" - {segment.get_duration():.3f}s - {segment.name} [HTTP: {segment.http_status}]" in message + ) + else: + self.assertTrue(f" - {segment.get_duration():.3f}s - {segment.name}" in message) + self.validate_segments(segment.sub_segments, message) + + +class TestXRayTraceJSONMapper(AbstractXRayTraceMapperTest): + def test_escaped_json_will_be_dict(self): + json_mapper = XRayTraceJSONMapper() + mapped_event = json_mapper.map(self.trace_event) + + segments = mapped_event.event.get("Segments") + self.assertTrue(isinstance(segments, list)) + for segment in segments: + self.assertTrue(isinstance(segment, dict)) + self.assertEqual(mapped_event.event, 
json.loads(mapped_event.message)) + + +class AbstractXRayServiceGraphMapperTest(TestCase): + def setUp(self): + self.service_graph_event = XRayServiceGraphEvent( + { + "StartTime": datetime(2015, 1, 1), + "EndTime": datetime(2015, 1, 1), + "Services": [ + { + "ReferenceId": 123, + "Name": "string", + "Root": True, + "Type": "string", + "StartTime": datetime(2015, 1, 1), + "EndTime": datetime(2015, 1, 1), + "Edges": [ + { + "ReferenceId": 123, + "StartTime": datetime(2015, 1, 1), + "EndTime": datetime(2015, 1, 1), + }, + ], + "SummaryStatistics": { + "OkCount": 123, + "ErrorStatistics": {"TotalCount": 123}, + "FaultStatistics": {"TotalCount": 123}, + "TotalCount": 123, + "TotalResponseTime": 123.0, + }, + }, + ], + } + ) + + +class TestXRayServiceGraphConsoleMapper(AbstractXRayServiceGraphMapperTest): + def test_console_mapper(self): + console_mapper = XRayServiceGraphConsoleMapper() + mapped_event = console_mapper.map(self.service_graph_event) + + self.assertTrue(isinstance(mapped_event, XRayServiceGraphEvent)) + + self.assertTrue("\nNew XRay Service Graph" in mapped_event.message) + self.assertTrue(f"\n Start time: {self.service_graph_event.start_time}" in mapped_event.message) + self.assertTrue(f"\n End time: {self.service_graph_event.end_time}" in mapped_event.message) + + self.validate_services(self.service_graph_event.services, mapped_event.message) + + def validate_services(self, services, message): + for service in services: + self.assertTrue(f"Reference Id: {service.id}" in message) + if service.is_root: + self.assertTrue("(Root)" in message) + else: + self.assertFalse("(Root)" in message) + self.assertTrue(f" {service.type} - {service.name}" in message) + edge_id_str = str(service.edge_ids) + self.assertTrue(f"Edges: {edge_id_str}" in message) + self.validate_summary_statistics(service, message) + + def validate_summary_statistics(self, service, message): + self.assertTrue("Summary_statistics:" in message) + self.assertTrue(f"total requests: {service.total_count}" in message) + self.assertTrue(f"ok count(2XX): {service.ok_count}" in message) + self.assertTrue(f"error count(4XX): {service.error_count}" in message) + self.assertTrue(f"fault count(5XX): {service.fault_count}" in message) + self.assertTrue(f"total response time: {service.response_time}" in message) + + +class TestXRayServiceGraphJSONMapper(AbstractXRayServiceGraphMapperTest): + def test_datetime_object_convert_to_iso_string(self): + actual_datetime = datetime(2015, 1, 1) + json_mapper = XRayServiceGraphJSONMapper() + mapped_event = json_mapper.map(self.service_graph_event) + mapped_dict = mapped_event.event + + self.validate_start_and_end_time(actual_datetime, mapped_dict) + services = mapped_dict.get("Services", []) + for service in services: + self.validate_start_and_end_time(actual_datetime, service) + edges = service.get("Edges", []) + for edge in edges: + self.validate_start_and_end_time(actual_datetime, edge) + self.assertEqual(mapped_event.event, json.loads(mapped_event.message)) + + def validate_start_and_end_time(self, datetime_obj, event_dict): + self.validate_datetime_object_to_iso_string("StartTime", datetime_obj, event_dict) + self.validate_datetime_object_to_iso_string("EndTime", datetime_obj, event_dict) + + def validate_datetime_object_to_iso_string(self, datetime_key, datetime_obj, event_dict): + datetime_str = event_dict.get(datetime_key) + self.assertTrue(isinstance(datetime_str, str)) + expected_utc_datetime = to_utc(datetime_obj) + expected_timestamp = 
utc_to_timestamp(expected_utc_datetime) + expected_iso_str = timestamp_to_iso(expected_timestamp) + self.assertEqual(datetime_str, expected_iso_str) diff --git a/tests/unit/lib/observability/xray_traces/test_xray_event_puller.py b/tests/unit/lib/observability/xray_traces/test_xray_event_puller.py new file mode 100644 index 0000000000..a7b737aa54 --- /dev/null +++ b/tests/unit/lib/observability/xray_traces/test_xray_event_puller.py @@ -0,0 +1,166 @@ +import time +import uuid +from itertools import zip_longest +from unittest import TestCase +from unittest.mock import patch, mock_open, call, Mock, ANY + +from botocore.exceptions import ClientError +from parameterized import parameterized + +from samcli.lib.observability.xray_traces.xray_event_puller import XRayTracePuller + + +class TestXrayTracePuller(TestCase): + def setUp(self): + self.xray_client = Mock() + self.consumer = Mock() + + self.max_retries = 4 + self.xray_trace_puller = XRayTracePuller(self.xray_client, self.consumer, self.max_retries) + + @parameterized.expand([(i,) for i in range(1, 15)]) + @patch("samcli.lib.observability.xray_traces.xray_event_puller.XRayTraceEvent") + def test_load_events(self, size, patched_xray_trace_event): + ids = [str(uuid.uuid4()) for _ in range(size)] + batch_ids = list(zip_longest(*([iter(ids)] * 5))) + + given_paginators = [Mock() for _ in batch_ids] + self.xray_client.get_paginator.side_effect = given_paginators + + given_results = [] + for i in range(len(batch_ids)): + given_result = [{"Traces": [Mock() for _ in batch]} for batch in batch_ids] + given_paginators[i].paginate.return_value = given_result + given_results.append(given_result) + + collected_events = [] + + def dynamic_mock(trace): + mocked_trace_event = Mock(trace=trace) + mocked_trace_event.get_latest_event_time.return_value = time.time() + collected_events.append(mocked_trace_event) + return mocked_trace_event + + patched_xray_trace_event.side_effect = dynamic_mock + + self.xray_trace_puller.load_events(ids) + + for i in range(len(batch_ids)): + self.xray_client.get_paginator.assert_called_with("batch_get_traces") + given_paginators[i].assert_has_calls([call.paginate(TraceIds=list(filter(None, batch_ids[i])))]) + self.consumer.assert_has_calls([call.consume(event) for event in collected_events]) + for event in collected_events: + event.get_latest_event_time.assert_called_once() + + def test_load_events_with_no_event_ids(self): + self.xray_trace_puller.load_events([]) + self.consumer.assert_not_called() + + def test_load_events_with_no_event_returned(self): + event_ids = [str(uuid.uuid4())] + + given_paginator = Mock() + given_paginator.paginate.return_value = [{"Traces": []}] + self.xray_client.get_paginator.return_value = given_paginator + + self.xray_trace_puller.load_events(event_ids) + given_paginator.paginate.assert_called_with(TraceIds=event_ids) + self.consumer.assert_not_called() + + def test_load_time_period(self): + given_paginator = Mock() + self.xray_client.get_paginator.return_value = given_paginator + + given_trace_summaries = [{"TraceSummaries": [{"Id": str(uuid.uuid4())} for _ in range(10)]}] + given_paginator.paginate.return_value = given_trace_summaries + + start_time = "start_time" + end_time = "end_time" + with patch.object(self.xray_trace_puller, "load_events") as patched_load_events: + self.xray_trace_puller.load_time_period(start_time, end_time) + given_paginator.paginate.assert_called_with(TimeRangeType="TraceId", StartTime=start_time, EndTime=end_time) + + collected_trace_ids = [item.get("Id") for item in 
given_trace_summaries[0].get("TraceSummaries", [])] + patched_load_events.assert_called_with(collected_trace_ids) + + def test_load_time_period_with_partial_result(self): + given_paginator = Mock() + self.xray_client.get_paginator.return_value = given_paginator + + given_trace_summaries = [{"TraceSummaries": [{"Id": str(uuid.uuid4()), "IsPartial": True} for _ in range(10)]}] + given_paginator.paginate.return_value = given_trace_summaries + + start_time = "start_time" + end_time = "end_time" + with patch.object(self.xray_trace_puller, "load_events") as patched_load_events: + self.xray_trace_puller.load_time_period(start_time, end_time) + given_paginator.paginate.assert_called_with(TimeRangeType="TraceId", StartTime=start_time, EndTime=end_time) + + patched_load_events.assert_called_with([]) + + @patch("samcli.lib.observability.xray_traces.xray_event_puller.time") + @patch("samcli.lib.observability.xray_traces.xray_event_puller.to_timestamp") + @patch("samcli.lib.observability.xray_traces.xray_event_puller.to_datetime") + def test_tail_with_no_data(self, patched_to_datetime, patched_to_timestamp, patched_time): + start_time = Mock() + + with patch.object(self.xray_trace_puller, "load_time_period") as patched_load_time_period: + self.xray_trace_puller.tail(start_time) + + patched_to_timestamp.assert_called_with(start_time) + + patched_to_datetime.assert_has_calls( + [call(self.xray_trace_puller.latest_event_time) for _ in range(self.max_retries)] + ) + + patched_time.sleep.assert_has_calls( + [call(self.xray_trace_puller._poll_interval) for _ in range(self.max_retries)] + ) + + patched_load_time_period.assert_has_calls([call(ANY, ANY) for _ in range(self.max_retries)]) + + @patch("samcli.lib.observability.xray_traces.xray_event_puller.time") + @patch("samcli.lib.observability.xray_traces.xray_event_puller.to_timestamp") + @patch("samcli.lib.observability.xray_traces.xray_event_puller.to_datetime") + def test_tail_with_with_data(self, patched_to_datetime, patched_to_timestamp, patched_time): + start_time = Mock() + given_start_time = 5 + patched_to_timestamp.return_value = 5 + with patch.object(self.xray_trace_puller, "_had_data") as patched_had_data: + patched_had_data.side_effect = [True, False] + + with patch.object(self.xray_trace_puller, "load_time_period") as patched_load_time_period: + self.xray_trace_puller.tail(start_time) + + patched_to_timestamp.assert_called_with(start_time) + + patched_to_datetime.assert_has_calls( + [ + call(given_start_time), + ], + any_order=True, + ) + patched_to_datetime.assert_has_calls([call(given_start_time + 1) for _ in range(self.max_retries)]) + + patched_time.sleep.assert_has_calls( + [call(self.xray_trace_puller._poll_interval) for _ in range(self.max_retries + 1)] + ) + + patched_load_time_period.assert_has_calls([call(ANY, ANY) for _ in range(self.max_retries + 1)]) + + @patch("samcli.lib.observability.xray_traces.xray_event_puller.time") + def test_with_throttling(self, patched_time): + with patch.object( + self.xray_trace_puller, "load_time_period", wraps=self.xray_trace_puller.load_time_period + ) as patched_load_time_period: + patched_load_time_period.side_effect = [ + ClientError({"Error": {"Code": "ThrottlingException"}}, "operation") for _ in range(self.max_retries) + ] + + self.xray_trace_puller.tail() + + patched_load_time_period.assert_has_calls([call(ANY, ANY) for _ in range(self.max_retries)]) + + patched_time.sleep.assert_has_calls([call(2), call(4), call(16), call(256)]) + + self.assertEqual(self.xray_trace_puller._poll_interval, 256) 
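The throttling tests above (and the matching one for the service graph puller below) pin down the pullers' back-off contract: each ThrottlingException raised by load_time_period squares the poll interval, so the waits grow 2 -> 4 -> 16 -> 256 seconds, the puller gives up after max_retries attempts, and the interval is left at 256. A minimal sketch of a retry loop consistent with those asserted calls; the function name tail_with_backoff, the starting interval of 2, and the guard that skips the final squaring are assumptions for illustration, not the SAM CLI implementation itself:

import time
from botocore.exceptions import ClientError

def tail_with_backoff(load_time_period, start_time, end_time, max_retries=4, poll_interval=2):
    # Retry the pull while X-Ray throttles, squaring the wait between
    # attempts (2 -> 4 -> 16 -> 256) and giving up after max_retries.
    retries_left = max_retries
    while retries_left > 0:
        try:
            load_time_period(start_time, end_time)
            return
        except ClientError as err:
            if err.response.get("Error", {}).get("Code") != "ThrottlingException":
                raise
            retries_left -= 1
            time.sleep(poll_interval)
            if retries_left:  # skip the final squaring so the interval ends at 256
                poll_interval **= 2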
diff --git a/tests/unit/lib/observability/xray_traces/test_xray_events.py b/tests/unit/lib/observability/xray_traces/test_xray_events.py new file mode 100644 index 0000000000..c1db606dd1 --- /dev/null +++ b/tests/unit/lib/observability/xray_traces/test_xray_events.py @@ -0,0 +1,198 @@ +import json +import time +import uuid +from unittest import TestCase + +from samcli.lib.observability.xray_traces.xray_events import XRayTraceSegment, XRayTraceEvent, XRayServiceGraphEvent +from samcli.lib.utils.hash import str_checksum + +LATEST_EVENT_TIME = 9621490723 + + +class AbstractXRayEventTextTest(TestCase): + def validate_segment(self, segment, event_dict): + self.assertEqual(segment.id, event_dict.get("Id")) + self.assertEqual(segment.name, event_dict.get("name")) + self.assertEqual(segment.start_time, event_dict.get("start_time")) + self.assertEqual(segment.end_time, event_dict.get("end_time")) + self.assertEqual(segment.http_status, event_dict.get("http", {}).get("response", {}).get("status", None)) + event_subsegments = event_dict.get("subsegments", []) + self.assertEqual(len(segment.sub_segments), len(event_subsegments)) + + for event_subsegment in event_subsegments: + subsegment = next(x for x in segment.sub_segments if x.id == event_subsegment.get("Id")) + self.validate_segment(subsegment, event_subsegment) + + +class TestXRayTraceEvent(AbstractXRayEventTextTest): + def setUp(self): + self.first_segment_date = time.time() - 1000 + self.segment_1 = { + "Id": str(uuid.uuid4()), + "name": f"Second {str(uuid.uuid4())}", + "start_time": time.time(), + "end_time": time.time(), + "http": {"response": {"status": 200}}, + } + self.segment_2 = { + "Id": str(uuid.uuid4()), + "name": f"First {str(uuid.uuid4())}", + "start_time": self.first_segment_date, + "end_time": LATEST_EVENT_TIME, + "http": {"response": {"status": 200}}, + } + self.event_dict = { + "Id": str(uuid.uuid4()), + "Duration": 400, + "Segments": [ + {"Id": self.segment_1.get("Id"), "Document": json.dumps(self.segment_1)}, + {"Id": self.segment_2.get("Id"), "Document": json.dumps(self.segment_2)}, + ], + } + + def test_xray_trace_event(self): + xray_trace_event = XRayTraceEvent(self.event_dict) + self.assertEqual(xray_trace_event.id, self.event_dict.get("Id")) + self.assertEqual(xray_trace_event.duration, self.event_dict.get("Duration")) + segments = self.event_dict.get("Segments", []) + self.assertEqual(len(xray_trace_event.segments), len(segments)) + + for segment in segments: + subsegment = next(x for x in xray_trace_event.segments if x.id == segment.get("Id")) + self.validate_segment(subsegment, json.loads(segment.get("Document"))) + + def test_latest_event_time(self): + xray_trace_event = XRayTraceEvent(self.event_dict) + self.assertEqual(xray_trace_event.get_latest_event_time(), LATEST_EVENT_TIME) + + def test_first_event_time(self): + xray_trace_event = XRayTraceEvent(self.event_dict) + self.assertEqual(xray_trace_event.timestamp, self.first_segment_date) + + def test_segment_order(self): + xray_trace_event = XRayTraceEvent(self.event_dict) + + self.assertEqual(len(xray_trace_event.segments), 2) + self.assertIn("First", xray_trace_event.segments[0].name) + self.assertIn("Second", xray_trace_event.segments[1].name) + + +class TestXRayTraceSegment(AbstractXRayEventTextTest): + def setUp(self): + self.event_dict = { + "Id": uuid.uuid4(), + "name": uuid.uuid4(), + "start_time": time.time(), + "end_time": time.time(), + "http": {"response": {"status": 200}}, + "subsegments": [ + { + "Id": uuid.uuid4(), + "name": uuid.uuid4(), + 
"start_time": time.time(), + "end_time": time.time(), + "http": {"response": {"status": 200}}, + }, + { + "Id": uuid.uuid4(), + "name": uuid.uuid4(), + "start_time": time.time(), + "end_time": time.time(), + "http": {"response": {"status": 200}}, + "subsegments": [ + { + "Id": uuid.uuid4(), + "name": uuid.uuid4(), + "start_time": time.time(), + "end_time": LATEST_EVENT_TIME, + "http": {"response": {"status": 200}}, + } + ], + }, + ], + } + + def test_xray_trace_segment_duration(self): + xray_trace_segment = XRayTraceSegment(self.event_dict) + self.assertEqual( + xray_trace_segment.get_duration(), self.event_dict.get("end_time") - self.event_dict.get("start_time") + ) + + def test_xray_latest_event_time(self): + xray_trace_segment = XRayTraceSegment(self.event_dict) + self.assertEqual(xray_trace_segment.get_latest_event_time(), LATEST_EVENT_TIME) + + def test_xray_trace_segment(self): + xray_trace_segment = XRayTraceSegment(self.event_dict) + self.validate_segment(xray_trace_segment, self.event_dict) + + +class AbstractXRayServiceTest(TestCase): + def validate_service(self, service, service_dict): + self.assertEqual(service.id, service_dict.get("ReferenceId")) + self.assertEqual(service.name, service_dict.get("Name")) + self.assertEqual(service.is_root, service_dict.get("Root")) + self.assertEqual(service.type, service_dict.get("Type")) + self.assertEqual(service.name, service_dict.get("Name")) + edges = service_dict.get("Edges") + self.assertEqual(len(service.edge_ids), len(edges)) + summary_statistics = service_dict.get("SummaryStatistics") + self.assertEqual(service.ok_count, summary_statistics.get("OkCount")) + self.assertEqual(service.error_count, summary_statistics.get("ErrorStatistics").get("TotalCount")) + self.assertEqual(service.fault_count, summary_statistics.get("FaultStatistics").get("TotalCount")) + self.assertEqual(service.total_count, summary_statistics.get("TotalCount")) + self.assertEqual(service.response_time, summary_statistics.get("TotalResponseTime")) + + +class TestXRayServiceGraphEvent(AbstractXRayServiceTest): + def setUp(self): + self.service_1 = { + "ReferenceId": 0, + "Name": "test1", + "Root": True, + "Type": "Lambda", + "Edges": [ + { + "ReferenceId": 1, + }, + ], + "SummaryStatistics": { + "OkCount": 1, + "ErrorStatistics": {"TotalCount": 2}, + "FaultStatistics": {"TotalCount": 3}, + "TotalCount": 6, + "TotalResponseTime": 123.0, + }, + } + + self.service_2 = { + "ReferenceId": 1, + "Name": "test2", + "Root": False, + "Type": "Api", + "Edges": [], + "SummaryStatistics": { + "OkCount": 2, + "ErrorStatistics": {"TotalCount": 3}, + "FaultStatistics": {"TotalCount": 3}, + "TotalCount": 8, + "TotalResponseTime": 200.0, + }, + } + self.event_dict = { + "Services": [self.service_1, self.service_2], + } + + def test_xray_service_graph_event(self): + xray_service_graph_event = XRayServiceGraphEvent(self.event_dict) + services_array = self.event_dict.get("Services", []) + services = xray_service_graph_event.services + self.assertEqual(len(services), len(services_array)) + + for service, service_dict in zip(services, services_array): + self.validate_service(service, service_dict) + + def test__xray_service_graph_event_get_hash(self): + xray_service_graph_event = XRayServiceGraphEvent(self.event_dict) + expected_hash = str_checksum(str(self.event_dict["Services"])) + self.assertEqual(expected_hash, xray_service_graph_event.get_hash()) diff --git a/tests/unit/lib/observability/xray_traces/test_xray_service_grpah_event_puller.py 
b/tests/unit/lib/observability/xray_traces/test_xray_service_grpah_event_puller.py new file mode 100644 index 0000000000..8bd604a6ee --- /dev/null +++ b/tests/unit/lib/observability/xray_traces/test_xray_service_grpah_event_puller.py @@ -0,0 +1,146 @@ +import time +import uuid +from itertools import zip_longest +from unittest import TestCase +from unittest.mock import patch, mock_open, call, Mock, ANY + +from botocore.exceptions import ClientError +from parameterized import parameterized + +from samcli.lib.observability.xray_traces.xray_event_puller import XRayTracePuller +from samcli.lib.observability.xray_traces.xray_service_graph_event_puller import XRayServiceGraphPuller + + +class TestXRayServiceGraphPuller(TestCase): + def setUp(self): + self.xray_client = Mock() + self.consumer = Mock() + + self.max_retries = 4 + self.xray_service_graph_puller = XRayServiceGraphPuller(self.xray_client, self.consumer, self.max_retries) + + @patch("samcli.lib.observability.xray_traces.xray_service_graph_event_puller.XRayServiceGraphEvent") + @patch("samcli.lib.observability.xray_traces.xray_service_graph_event_puller.to_utc") + @patch("samcli.lib.observability.xray_traces.xray_service_graph_event_puller.utc_to_timestamp") + def test_load_time_period(self, patched_utc_to_timestamp, patched_to_utc, patched_xray_service_graph_event): + given_paginator = Mock() + self.xray_client.get_paginator.return_value = given_paginator + + given_services = [{"EndTime": "endtime", "Services": [{"id": 1}]}] + given_paginator.paginate.return_value = given_services + + start_time = "start_time" + end_time = "end_time" + patched_utc_to_timestamp.return_value = 1 + self.xray_service_graph_puller.load_time_period(start_time, end_time) + patched_utc_to_timestamp.assert_called() + patched_to_utc.assert_called() + given_paginator.paginate.assert_called_with(StartTime=start_time, EndTime=end_time) + patched_xray_service_graph_event.assert_called_with({"EndTime": "endtime", "Services": [{"id": 1}]}) + self.consumer.consume.assert_called() + + @patch("samcli.lib.observability.xray_traces.xray_service_graph_event_puller.XRayServiceGraphEvent") + @patch("samcli.lib.observability.xray_traces.xray_service_graph_event_puller.to_utc") + @patch("samcli.lib.observability.xray_traces.xray_service_graph_event_puller.utc_to_timestamp") + def test_load_time_period_with_same_event_twice( + self, patched_utc_to_timestamp, patched_to_utc, patched_xray_service_graph_event + ): + given_paginator = Mock() + self.xray_client.get_paginator.return_value = given_paginator + + given_services = [{"EndTime": "endtime", "Services": [{"id": 1}]}] + given_paginator.paginate.return_value = given_services + + start_time = "start_time" + end_time = "end_time" + patched_utc_to_timestamp.return_value = 1 + self.xray_service_graph_puller.load_time_period(start_time, end_time) + # called with the same event twice + self.xray_service_graph_puller.load_time_period(start_time, end_time) + patched_utc_to_timestamp.assert_called() + patched_to_utc.assert_called() + given_paginator.paginate.assert_called_with(StartTime=start_time, EndTime=end_time) + patched_xray_service_graph_event.assert_called_with({"EndTime": "endtime", "Services": [{"id": 1}]}) + # consumer should only get called once + self.consumer.consume.assert_called_once() + + @patch("samcli.lib.observability.xray_traces.xray_service_graph_event_puller.XRayServiceGraphEvent") + def test_load_time_period_with_no_service(self, patched_xray_service_graph_event): + given_paginator = Mock() + 
self.xray_client.get_paginator.return_value = given_paginator + + given_services = [{"EndTime": "endtime", "Services": []}] + given_paginator.paginate.return_value = given_services + + start_time = "start_time" + end_time = "end_time" + self.xray_service_graph_puller.load_time_period(start_time, end_time) + patched_xray_service_graph_event.assert_not_called() + self.consumer.consume.assert_not_called() + + @patch("samcli.lib.observability.xray_traces.xray_event_puller.time") + @patch("samcli.lib.observability.xray_traces.xray_event_puller.to_timestamp") + @patch("samcli.lib.observability.xray_traces.xray_event_puller.to_datetime") + def test_tail_with_no_data(self, patched_to_datetime, patched_to_timestamp, patched_time): + start_time = Mock() + + with patch.object(self.xray_service_graph_puller, "load_time_period") as patched_load_time_period: + self.xray_service_graph_puller.tail(start_time) + + patched_to_timestamp.assert_called_with(start_time) + + patched_to_datetime.assert_has_calls( + [call(self.xray_service_graph_puller.latest_event_time) for _ in range(self.max_retries)] + ) + + patched_time.sleep.assert_has_calls( + [call(self.xray_service_graph_puller._poll_interval) for _ in range(self.max_retries)] + ) + + patched_load_time_period.assert_has_calls([call(ANY, ANY) for _ in range(self.max_retries)]) + + @patch("samcli.lib.observability.xray_traces.xray_event_puller.time") + @patch("samcli.lib.observability.xray_traces.xray_event_puller.to_timestamp") + @patch("samcli.lib.observability.xray_traces.xray_event_puller.to_datetime") + def test_tail_with_data(self, patched_to_datetime, patched_to_timestamp, patched_time): + start_time = Mock() + given_start_time = 5 + patched_to_timestamp.return_value = 5 + with patch.object(self.xray_service_graph_puller, "_had_data") as patched_had_data: + patched_had_data.side_effect = [True, False] + + with patch.object(self.xray_service_graph_puller, "load_time_period") as patched_load_time_period: + self.xray_service_graph_puller.tail(start_time) + + patched_to_timestamp.assert_called_with(start_time) + + patched_to_datetime.assert_has_calls( + [ + call(given_start_time), + ], + any_order=True, + ) + patched_to_datetime.assert_has_calls([call(given_start_time + 1) for _ in range(self.max_retries)]) + + patched_time.sleep.assert_has_calls( + [call(self.xray_service_graph_puller._poll_interval) for _ in range(self.max_retries + 1)] + ) + + patched_load_time_period.assert_has_calls([call(ANY, ANY) for _ in range(self.max_retries + 1)]) + + @patch("samcli.lib.observability.xray_traces.xray_event_puller.time") + def test_with_throttling(self, patched_time): + with patch.object( + self.xray_service_graph_puller, "load_time_period", wraps=self.xray_service_graph_puller.load_time_period + ) as patched_load_time_period: + patched_load_time_period.side_effect = [ + ClientError({"Error": {"Code": "ThrottlingException"}}, "operation") for _ in range(self.max_retries) + ] + + self.xray_service_graph_puller.tail() + + patched_load_time_period.assert_has_calls([call(ANY, ANY) for _ in range(self.max_retries)]) + + patched_time.sleep.assert_has_calls([call(2), call(4), call(16), call(256)]) + + self.assertEqual(self.xray_service_graph_puller._poll_interval, 256) diff --git a/tests/unit/lib/sync/__init__.py b/tests/unit/lib/sync/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/lib/sync/flows/__init__.py b/tests/unit/lib/sync/flows/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/tests/unit/lib/sync/flows/test_alias_version_sync_flow.py b/tests/unit/lib/sync/flows/test_alias_version_sync_flow.py new file mode 100644 index 0000000000..ff7762473f --- /dev/null +++ b/tests/unit/lib/sync/flows/test_alias_version_sync_flow.py @@ -0,0 +1,61 @@ +import os +import hashlib + +from samcli.lib.sync.sync_flow import SyncFlow +from unittest import TestCase +from unittest.mock import ANY, MagicMock, call, mock_open, patch + +from samcli.lib.sync.flows.alias_version_sync_flow import AliasVersionSyncFlow + + +class TestAliasVersionSyncFlow(TestCase): + def create_sync_flow(self): + sync_flow = AliasVersionSyncFlow( + "Function1", + "Alias1", + build_context=MagicMock(), + deploy_context=MagicMock(), + physical_id_mapping={}, + stacks=[MagicMock()], + ) + return sync_flow + + @patch("samcli.lib.sync.sync_flow.Session") + def test_set_up(self, session_mock): + sync_flow = self.create_sync_flow() + sync_flow.set_up() + session_mock.return_value.client.assert_any_call("lambda") + + @patch("samcli.lib.sync.sync_flow.Session") + def test_sync_direct(self, session_mock): + sync_flow = self.create_sync_flow() + + sync_flow.get_physical_id = MagicMock() + sync_flow.get_physical_id.return_value = "PhysicalFunction1" + + sync_flow.set_up() + + sync_flow._lambda_client.publish_version.return_value = {"Version": "2"} + + sync_flow.sync() + + sync_flow._lambda_client.publish_version.assert_called_once_with(FunctionName="PhysicalFunction1") + sync_flow._lambda_client.update_alias.assert_called_once_with( + FunctionName="PhysicalFunction1", Name="Alias1", FunctionVersion="2" + ) + + def test_equality_keys(self): + sync_flow = self.create_sync_flow() + self.assertEqual(sync_flow._equality_keys(), ("Function1", "Alias1")) + + def test_gather_dependencies(self): + sync_flow = self.create_sync_flow() + self.assertEqual(sync_flow.gather_dependencies(), []) + + def test_get_resource_api_calls(self): + sync_flow = self.create_sync_flow() + self.assertEqual(sync_flow._get_resource_api_calls(), []) + + def test_compare_remote(self): + sync_flow = self.create_sync_flow() + self.assertFalse(sync_flow.compare_remote()) diff --git a/tests/unit/lib/sync/flows/test_auto_dependency_layer_sync_flow.py b/tests/unit/lib/sync/flows/test_auto_dependency_layer_sync_flow.py new file mode 100644 index 0000000000..bfa0190976 --- /dev/null +++ b/tests/unit/lib/sync/flows/test_auto_dependency_layer_sync_flow.py @@ -0,0 +1,176 @@ +import os.path +from unittest import TestCase +from unittest.mock import Mock, patch, ANY + +from samcli.lib.build.build_graph import BuildGraph +from samcli.lib.sync.exceptions import ( + MissingFunctionBuildDefinition, + InvalidRuntimeDefinitionForFunction, + NoLayerVersionsFoundError, +) +from samcli.lib.sync.flows.auto_dependency_layer_sync_flow import ( + AutoDependencyLayerParentSyncFlow, + AutoDependencyLayerSyncFlow, +) +from samcli.lib.sync.flows.layer_sync_flow import FunctionLayerReferenceSync + + +class TestAutoDependencyLayerParentSyncFlow(TestCase): + def setUp(self) -> None: + self.sync_flow = AutoDependencyLayerParentSyncFlow( + "function_identifier", Mock(), Mock(stack_name="stack_name"), Mock(), [Mock()] + ) + + @patch("samcli.lib.sync.flows.auto_dependency_layer_sync_flow.super") + def test_gather_dependencies(self, patched_super): + patched_super.return_value.gather_dependencies.return_value = [] + with patch.object(self.sync_flow, "_build_graph") as patched_build_graph: + patched_build_graph.get_function_build_definitions.return_value = [Mock(download_dependencies=True)] + + 
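# a build definition that downloads its dependencies should yield exactly one child sync flow
+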
dependencies = self.sync_flow.gather_dependencies() + self.assertEqual(len(dependencies), 1) + self.assertIsInstance(dependencies[0], AutoDependencyLayerSyncFlow) + + @patch("samcli.lib.sync.flows.auto_dependency_layer_sync_flow.super") + def test_skip_gather_dependencies(self, patched_super): + patched_super.return_value.gather_dependencies.return_value = [] + with patch.object(self.sync_flow, "_build_graph") as patched_build_graph: + patched_build_graph.get_function_build_definitions.return_value = [Mock(download_dependencies=False)] + + dependencies = self.sync_flow.gather_dependencies() + self.assertEqual(dependencies, []) + + def test_combine_dependencies(self): + self.assertFalse(self.sync_flow._combine_dependencies()) + + +class TestAutoDependencyLayerSyncFlow(TestCase): + def setUp(self) -> None: + self.build_graph = Mock(spec=BuildGraph) + self.stack_name = "stack_name" + self.build_dir = "build_dir" + self.function_identifier = "function_identifier" + self.sync_flow = AutoDependencyLayerSyncFlow( + self.function_identifier, + self.build_graph, + Mock(build_dir=self.build_dir), + Mock(stack_name=self.stack_name), + Mock(), + [Mock()], + ) + + def test_gather_resources_fail_when_no_function_build_definition_found(self): + self.build_graph.get_function_build_definitions.return_value = [] + with self.assertRaises(MissingFunctionBuildDefinition): + self.sync_flow.gather_resources() + + @patch("samcli.lib.sync.flows.auto_dependency_layer_sync_flow.SamFunctionProvider") + def test_gather_resources_fail_when_no_runtime_defined_for_function(self, patched_function_provider): + self.build_graph.get_function_build_definitions.return_value = [Mock()] + patched_function_provider.return_value.get.return_value = Mock(runtime=None) + with self.assertRaises(InvalidRuntimeDefinitionForFunction): + self.sync_flow.gather_resources() + + @patch("samcli.lib.sync.flows.auto_dependency_layer_sync_flow.uuid") + @patch("samcli.lib.sync.flows.auto_dependency_layer_sync_flow.file_checksum") + @patch("samcli.lib.sync.flows.auto_dependency_layer_sync_flow.make_zip") + @patch("samcli.lib.sync.flows.auto_dependency_layer_sync_flow.tempfile") + @patch("samcli.lib.sync.flows.auto_dependency_layer_sync_flow.NestedStackManager") + def test_gather_resources( + self, + patched_nested_stack_manager, + patched_tempfile, + patched_make_zip, + patched_file_checksum, + patched_uuid, + ): + layer_root_folder = "layer_root_folder" + dependencies_dir = "dependencies_dir" + tmpdir = "tmpdir" + uuid_hex = "uuid_hex" + runtime = "runtime" + zipfile = "zipfile" + + patched_nested_stack_manager.update_layer_folder.return_value = layer_root_folder + patched_tempfile.gettempdir.return_value = tmpdir + patched_uuid.uuid4.return_value = Mock(hex=uuid_hex) + patched_make_zip.return_value = zipfile + self.build_graph.get_function_build_definitions.return_value = [Mock(dependencies_dir=dependencies_dir)] + + with patch.object(self.sync_flow, "_get_compatible_runtimes") as patched_comp_runtimes: + patched_comp_runtimes.return_value = [runtime] + self.sync_flow.gather_resources() + + self.assertEqual(self.sync_flow._artifact_folder, layer_root_folder) + patched_nested_stack_manager.update_layer_folder.assert_called_with( + "build_dir", dependencies_dir, ANY, self.function_identifier, runtime + ) + patched_make_zip.assert_called_with( + os.path.join(tmpdir, f"data-{uuid_hex}"), self.sync_flow._artifact_folder + ) + patched_file_checksum.assert_called_with(zipfile, ANY) + + def test_empty_gather_dependencies(self): + with 
patch.object(self.sync_flow, "_get_dependent_functions") as patched_get_dependent_functions: + patched_get_dependent_functions.return_value = [] + self.assertEqual(self.sync_flow.gather_dependencies(), []) + + def test_gather_dependencies(self): + layer_identifier = "layer_identifier" + self.sync_flow._layer_identifier = layer_identifier + with patch.object(self.sync_flow, "_get_dependent_functions") as patched_get_dependent_functions: + patched_get_dependent_functions.return_value = [ + Mock(layers=[Mock(full_path=layer_identifier)], full_path="Function") + ] + dependencies = self.sync_flow.gather_dependencies() + self.assertEqual(len(dependencies), 1) + self.assertIsInstance(dependencies[0], FunctionLayerReferenceSync) + + @patch("samcli.lib.sync.flows.auto_dependency_layer_sync_flow.SamFunctionProvider") + def test_get_dependent_functions(self, patched_function_provider): + given_function_in_template = Mock() + patched_function_provider.return_value.get.return_value = given_function_in_template + + self.assertEqual(self.sync_flow._get_dependent_functions(), [given_function_in_template]) + + @patch("samcli.lib.sync.flows.auto_dependency_layer_sync_flow.SamFunctionProvider") + def test_get_compatible_runtimes(self, patched_function_provider): + given_runtime = "python3.9" + given_function_in_template = Mock(runtime=given_runtime) + patched_function_provider.return_value.get.return_value = given_function_in_template + + self.assertEqual(self.sync_flow._get_compatible_runtimes(), [given_runtime]) + + @patch("samcli.lib.sync.flows.auto_dependency_layer_sync_flow.NestedStackBuilder") + @patch("samcli.lib.sync.flows.auto_dependency_layer_sync_flow.super") + def test_setup(self, patched_super, patched_nested_stack_builder): + layer_name = "layer_name" + patched_nested_stack_builder.get_layer_name.return_value = layer_name + + patched_lambda_client = Mock() + self.sync_flow._lambda_client = patched_lambda_client + layer_physical_name = "layer_physical_name" + patched_lambda_client.list_layer_versions.return_value = { + "LayerVersions": [{"LayerVersionArn": f"{layer_physical_name}:0"}] + } + + self.sync_flow.set_up() + + self.assertEqual(self.sync_flow._layer_arn, layer_physical_name) + patched_nested_stack_builder.get_layer_name.assert_called_with( + self.sync_flow._deploy_context.stack_name, self.sync_flow._function_identifier + ) + patched_lambda_client.list_layer_versions.assert_called_with(LayerName=layer_name) + + @patch("samcli.lib.sync.flows.auto_dependency_layer_sync_flow.NestedStackBuilder") + @patch("samcli.lib.sync.flows.auto_dependency_layer_sync_flow.super") + def test_setup_with_no_layer_version(self, patched_super, patched_nested_stack_builder): + layer_name = "layer_name" + patched_nested_stack_builder.get_layer_name.return_value = layer_name + + patched_lambda_client = Mock() + self.sync_flow._lambda_client = patched_lambda_client + patched_lambda_client.list_layer_versions.return_value = {"LayerVersions": []} + + with self.assertRaises(NoLayerVersionsFoundError): + self.sync_flow.set_up() diff --git a/tests/unit/lib/sync/flows/test_function_sync_flow.py b/tests/unit/lib/sync/flows/test_function_sync_flow.py new file mode 100644 index 0000000000..c2d3e12f6c --- /dev/null +++ b/tests/unit/lib/sync/flows/test_function_sync_flow.py @@ -0,0 +1,50 @@ +from samcli.lib.providers.provider import ResourceIdentifier +from unittest import TestCase +from unittest.mock import ANY, MagicMock, call, patch + +from samcli.lib.sync.sync_flow import SyncFlow, ResourceAPICall +from 
samcli.lib.sync.flows.function_sync_flow import FunctionSyncFlow +from samcli.lib.utils.lock_distributor import LockChain + + +class TestFunctionSyncFlow(TestCase): + def create_function_sync_flow(self): + sync_flow = FunctionSyncFlow( + "Function1", + build_context=MagicMock(), + deploy_context=MagicMock(), + physical_id_mapping={}, + stacks=[MagicMock()], + ) + sync_flow.gather_resources = MagicMock() + sync_flow.compare_remote = MagicMock() + sync_flow.sync = MagicMock() + sync_flow._get_resource_api_calls = MagicMock() + return sync_flow + + @patch("samcli.lib.sync.sync_flow.Session") + @patch.multiple(FunctionSyncFlow, __abstractmethods__=set()) + def test_sets_up_clients(self, session_mock): + sync_flow = self.create_function_sync_flow() + sync_flow.set_up() + session_mock.return_value.client.assert_called_once_with("lambda") + sync_flow._lambda_client.get_waiter.assert_called_once_with("function_updated") + + @patch("samcli.lib.sync.flows.function_sync_flow.AliasVersionSyncFlow") + @patch("samcli.lib.sync.sync_flow.Session") + @patch.multiple(FunctionSyncFlow, __abstractmethods__=set()) + def test_gather_dependencies(self, session_mock, alias_version_mock): + sync_flow = self.create_function_sync_flow() + sync_flow.get_physical_id = lambda x: "PhysicalFunction1" + sync_flow._get_resource = lambda x: MagicMock() + + sync_flow.set_up() + result = sync_flow.gather_dependencies() + + sync_flow._lambda_waiter.wait.assert_called_once_with(FunctionName="PhysicalFunction1", WaiterConfig=ANY) + self.assertEqual(result, [alias_version_mock.return_value]) + + @patch.multiple(FunctionSyncFlow, __abstractmethods__=set()) + def test_equality_keys(self): + sync_flow = self.create_function_sync_flow() + self.assertEqual(sync_flow._equality_keys(), "Function1") diff --git a/tests/unit/lib/sync/flows/test_http_api_sync_flow.py b/tests/unit/lib/sync/flows/test_http_api_sync_flow.py new file mode 100644 index 0000000000..0a7745b88c --- /dev/null +++ b/tests/unit/lib/sync/flows/test_http_api_sync_flow.py @@ -0,0 +1,109 @@ +from unittest import TestCase +from unittest.mock import ANY, MagicMock, mock_open, patch + +from samcli.lib.sync.flows.http_api_sync_flow import HttpApiSyncFlow +from samcli.lib.providers.exceptions import MissingLocalDefinition + + +class TestHttpApiSyncFlow(TestCase): + def create_sync_flow(self): + sync_flow = HttpApiSyncFlow( + "Api1", + build_context=MagicMock(), + deploy_context=MagicMock(), + physical_id_mapping={}, + stacks=[MagicMock()], + ) + return sync_flow + + @patch("samcli.lib.sync.sync_flow.Session") + def test_set_up(self, session_mock): + sync_flow = self.create_sync_flow() + sync_flow.set_up() + session_mock.return_value.client.assert_any_call("apigatewayv2") + + @patch("samcli.lib.sync.sync_flow.Session") + def test_sync_direct(self, session_mock): + sync_flow = self.create_sync_flow() + + sync_flow.get_physical_id = MagicMock() + sync_flow.get_physical_id.return_value = "PhysicalApi1" + + sync_flow._get_definition_file = MagicMock() + sync_flow._get_definition_file.return_value = "file.yaml" + + sync_flow.set_up() + with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file: + sync_flow.gather_resources() + + sync_flow._api_client.reimport_api.return_value = {"Response": "success"} + + sync_flow.sync() + + sync_flow._api_client.reimport_api.assert_called_once_with( + ApiId="PhysicalApi1", Body='{"key": "value"}'.encode("utf-8") + ) + + @patch("samcli.lib.sync.flows.generic_api_sync_flow.get_resource_by_id") + def 
test_get_definition_file(self, get_resource_mock): + sync_flow = self.create_sync_flow() + + get_resource_mock.return_value = {"Properties": {"DefinitionUri": "test_uri"}} + result_uri = sync_flow._get_definition_file("test") + + self.assertEqual(result_uri, "test_uri") + + get_resource_mock.return_value = {"Properties": {}} + result_uri = sync_flow._get_definition_file("test") + + self.assertEqual(result_uri, None) + + def test_process_definition_file(self): + sync_flow = self.create_sync_flow() + sync_flow._definition_uri = "path" + with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file: + data = sync_flow._process_definition_file() + self.assertEqual(data, '{"key": "value"}'.encode("utf-8")) + + @patch("samcli.lib.sync.sync_flow.Session") + def test_failed_gather_resources(self, session_mock): + sync_flow = self.create_sync_flow() + + sync_flow.get_physical_id = MagicMock() + sync_flow.get_physical_id.return_value = "PhysicalApi1" + + sync_flow._get_definition_file = MagicMock() + sync_flow._get_definition_file.return_value = "file.yaml" + + sync_flow.set_up() + sync_flow._definition_uri = None + + with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file: + with self.assertRaises(MissingLocalDefinition): + sync_flow.sync() + + def test_compare_remote(self): + sync_flow = self.create_sync_flow() + self.assertFalse(sync_flow.compare_remote()) + + def test_gather_dependencies(self): + sync_flow = self.create_sync_flow() + self.assertEqual(sync_flow.gather_dependencies(), []) + + def test_equality_keys(self): + sync_flow = self.create_sync_flow() + self.assertEqual(sync_flow._equality_keys(), sync_flow._api_identifier) + + def test_get_resource_api_calls(self): + sync_flow = self.create_sync_flow() + self.assertEqual(sync_flow._get_resource_api_calls(), []) + + @patch("samcli.lib.sync.flows.generic_api_sync_flow.get_resource_by_id") + def test_gather_with_no_definition_uri_and_swagger(self, patched_get_resource_by_id): + patched_get_resource_by_id.return_value = None + + sync_flow = self.create_sync_flow() + sync_flow.gather_resources() + + self.assertIsNone(sync_flow._definition_uri) + self.assertIsNone(sync_flow._swagger_body) diff --git a/tests/unit/lib/sync/flows/test_image_function_sync_flow.py b/tests/unit/lib/sync/flows/test_image_function_sync_flow.py new file mode 100644 index 0000000000..de81277797 --- /dev/null +++ b/tests/unit/lib/sync/flows/test_image_function_sync_flow.py @@ -0,0 +1,124 @@ +from samcli.lib.sync.sync_flow import SyncFlow +from unittest import TestCase +from unittest.mock import ANY, MagicMock, call, patch + +from samcli.lib.sync.flows.image_function_sync_flow import ImageFunctionSyncFlow + + +class TestImageFunctionSyncFlow(TestCase): + def create_function_sync_flow(self): + sync_flow = ImageFunctionSyncFlow( + "Function1", + build_context=MagicMock(), + deploy_context=MagicMock(), + physical_id_mapping={}, + stacks=[MagicMock()], + docker_client=MagicMock(), + ) + return sync_flow + + @patch("samcli.lib.sync.sync_flow.Session") + def test_set_up(self, session_mock): + sync_flow = self.create_function_sync_flow() + sync_flow.set_up() + session_mock.return_value.client.assert_any_call("lambda") + session_mock.return_value.client.assert_any_call("ecr") + + @patch("samcli.lib.sync.flows.image_function_sync_flow.ApplicationBuilder") + @patch("samcli.lib.sync.sync_flow.Session") + def test_gather_resources(self, session_mock, builder_mock): + get_mock = MagicMock() + 
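# stub the builder so the artifact lookup for Function1 resolves to a known image name
+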
get_mock.return_value = "ImageName1" + builder_mock.return_value.build.return_value.artifacts.get = get_mock + sync_flow = self.create_function_sync_flow() + + sync_flow.set_up() + sync_flow.gather_resources() + + get_mock.assert_called_once_with("Function1") + self.assertEqual(sync_flow._image_name, "ImageName1") + + @patch("samcli.lib.sync.flows.image_function_sync_flow.ECRUploader") + @patch("samcli.lib.sync.sync_flow.Session") + def test_sync_context_image_repo(self, session_mock, uploader_mock): + sync_flow = self.create_function_sync_flow() + sync_flow._image_name = "ImageName1" + + uploader_mock.return_value.upload.return_value = "image_uri" + + sync_flow.get_physical_id = MagicMock() + sync_flow.get_physical_id.return_value = "PhysicalFunction1" + sync_flow._deploy_context.image_repository = "repo_uri" + + sync_flow.set_up() + sync_flow.sync() + + uploader_mock.return_value.upload.assert_called_once_with("ImageName1", "Function1") + uploader_mock.assert_called_once_with(sync_flow._docker_client, sync_flow._ecr_client, "repo_uri", None) + sync_flow._lambda_client.update_function_code.assert_called_once_with( + FunctionName="PhysicalFunction1", ImageUri="image_uri" + ) + + @patch("samcli.lib.sync.flows.image_function_sync_flow.ECRUploader") + @patch("samcli.lib.sync.sync_flow.Session") + def test_sync_context_image_repos(self, session_mock, uploader_mock): + sync_flow = self.create_function_sync_flow() + sync_flow._image_name = "ImageName1" + + uploader_mock.return_value.upload.return_value = "image_uri" + + sync_flow.get_physical_id = MagicMock() + sync_flow.get_physical_id.return_value = "PhysicalFunction1" + sync_flow._deploy_context.image_repository = "" + sync_flow._deploy_context.image_repositories = {"Function1": "repo_uri"} + + sync_flow.set_up() + sync_flow.sync() + + uploader_mock.return_value.upload.assert_called_once_with("ImageName1", "Function1") + uploader_mock.assert_called_once_with(sync_flow._docker_client, sync_flow._ecr_client, "repo_uri", None) + sync_flow._lambda_client.update_function_code.assert_called_once_with( + FunctionName="PhysicalFunction1", ImageUri="image_uri" + ) + + @patch("samcli.lib.sync.flows.image_function_sync_flow.ECRUploader") + @patch("samcli.lib.sync.sync_flow.Session") + def test_sync_remote_image_repo(self, session_mock, uploader_mock): + sync_flow = self.create_function_sync_flow() + sync_flow._image_name = "ImageName1" + + uploader_mock.return_value.upload.return_value = "image_uri" + + sync_flow.get_physical_id = MagicMock() + sync_flow.get_physical_id.return_value = "PhysicalFunction1" + sync_flow._deploy_context.image_repository = "" + sync_flow._deploy_context.image_repositories = {} + + sync_flow.set_up() + + sync_flow._lambda_client.get_function = MagicMock() + sync_flow._lambda_client.get_function.return_value = {"Code": {"ImageUri": "repo_uri:tag"}} + + sync_flow.sync() + + uploader_mock.return_value.upload.assert_called_once_with("ImageName1", "Function1") + uploader_mock.assert_called_once_with(sync_flow._docker_client, sync_flow._ecr_client, "repo_uri", None) + sync_flow._lambda_client.update_function_code.assert_called_once_with( + FunctionName="PhysicalFunction1", ImageUri="image_uri" + ) + + @patch("samcli.lib.sync.flows.image_function_sync_flow.ECRUploader") + @patch("samcli.lib.sync.sync_flow.Session") + def test_sync_with_no_image(self, session_mock, uploader_mock): + sync_flow = self.create_function_sync_flow() + sync_flow._image_name = None + sync_flow.sync() + uploader_mock.return_value.upload.assert_not_called() + + 
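
Taken together, the three repository tests above pin down the lookup order an image sync uses to find an ECR repo: an explicit image_repository on the deploy context wins, then the per-function image_repositories mapping, and finally the repository of the image already attached to the deployed function. A minimal sketch of that order (resolve_repo_uri is a hypothetical helper inferred from the assertions, not the implementation's actual API):

    def resolve_repo_uri(deploy_context, function_id, lambda_client, physical_id):
        # 1) an explicit repository set on the deploy context wins
        if deploy_context.image_repository:
            return deploy_context.image_repository
        # 2) then the per-function repository mapping
        if deploy_context.image_repositories and function_id in deploy_context.image_repositories:
            return deploy_context.image_repositories[function_id]
        # 3) finally the repo of the currently deployed image, e.g. "repo_uri:tag" -> "repo_uri"
        image_uri = lambda_client.get_function(FunctionName=physical_id)["Code"]["ImageUri"]
        return image_uri.rsplit(":", 1)[0]

+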
def test_compare_remote(self): + sync_flow = self.create_function_sync_flow() + self.assertFalse(sync_flow.compare_remote()) + + def test_get_resource_api_calls(self): + sync_flow = self.create_function_sync_flow() + self.assertEqual(sync_flow._get_resource_api_calls(), []) diff --git a/tests/unit/lib/sync/flows/test_layer_sync_flow.py b/tests/unit/lib/sync/flows/test_layer_sync_flow.py new file mode 100644 index 0000000000..551fa54842 --- /dev/null +++ b/tests/unit/lib/sync/flows/test_layer_sync_flow.py @@ -0,0 +1,430 @@ +import base64 +import hashlib +from unittest import TestCase +from unittest.mock import MagicMock, Mock, patch, call, ANY, mock_open, PropertyMock + +from parameterized import parameterized + +from samcli.lib.sync.exceptions import MissingPhysicalResourceError, NoLayerVersionsFoundError +from samcli.lib.sync.flows.layer_sync_flow import LayerSyncFlow, FunctionLayerReferenceSync +from samcli.lib.sync.sync_flow import SyncFlow + + +class TestLayerSyncFlow(TestCase): + def setUp(self): + self.layer_identifier = "LayerA" + self.build_context_mock = Mock() + self.deploy_context_mock = Mock() + + self.layer_sync_flow = LayerSyncFlow( + self.layer_identifier, + self.build_context_mock, + self.deploy_context_mock, + {self.layer_identifier: "layer_version_arn"}, + [], + ) + + def test_setup(self): + with patch.object(self.layer_sync_flow, "_session") as patched_session: + with patch.object(SyncFlow, "set_up") as patched_super_setup: + self.layer_sync_flow.set_up() + + patched_super_setup.assert_called_once() + patched_session.assert_has_calls( + [ + call.client("lambda"), + ] + ) + + @patch("samcli.lib.sync.flows.layer_sync_flow.get_resource_by_id") + def test_setup_with_serverless_layer(self, get_resource_by_id_mock): + given_layer_name_with_hashes = f"{self.layer_identifier}abcdefghij" + self.layer_sync_flow._physical_id_mapping = {given_layer_name_with_hashes: "layer_version_arn"} + get_resource_by_id_mock.return_value = False + with patch.object(self.layer_sync_flow, "_session") as patched_session: + with patch.object(SyncFlow, "set_up") as patched_super_setup: + self.layer_sync_flow.set_up() + + patched_super_setup.assert_called_once() + patched_session.assert_has_calls( + [ + call.client("lambda"), + ] + ) + + self.assertEqual(self.layer_sync_flow._layer_arn, "layer_version_arn") + + def test_setup_with_unknown_layer(self): + given_layer_name_with_hashes = f"SomeOtherLayerabcdefghij" + self.layer_sync_flow._physical_id_mapping = {given_layer_name_with_hashes: "layer_version_arn"} + with patch.object(self.layer_sync_flow, "_session") as _: + with patch.object(SyncFlow, "set_up") as _: + with self.assertRaises(MissingPhysicalResourceError): + self.layer_sync_flow.set_up() + + @patch("samcli.lib.sync.flows.layer_sync_flow.ApplicationBuilder") + @patch("samcli.lib.sync.flows.layer_sync_flow.tempfile") + @patch("samcli.lib.sync.flows.layer_sync_flow.make_zip") + @patch("samcli.lib.sync.flows.layer_sync_flow.file_checksum") + @patch("samcli.lib.sync.flows.layer_sync_flow.os") + def test_setup_gather_resources( + self, patched_os, patched_file_checksum, patched_make_zip, patched_tempfile, patched_app_builder + ): + given_collect_build_resources = Mock() + self.build_context_mock.collect_build_resources.return_value = given_collect_build_resources + + given_app_builder = Mock() + given_artifact_folder = Mock() + given_app_builder.build().artifacts.get.return_value = given_artifact_folder + patched_app_builder.return_value = given_app_builder + + given_zip_location = Mock() + 
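# the zip location and checksum stubs below stand in for the freshly packaged layer artifact
+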
patched_make_zip.return_value = given_zip_location + + given_file_checksum = Mock() + patched_file_checksum.return_value = given_file_checksum + + self.layer_sync_flow._get_lock_chain = MagicMock() + + self.layer_sync_flow.gather_resources() + + self.build_context_mock.collect_build_resources.assert_called_with(self.layer_identifier) + + patched_app_builder.assert_called_with( + given_collect_build_resources, + self.build_context_mock.build_dir, + self.build_context_mock.base_dir, + self.build_context_mock.cache_dir, + cached=True, + is_building_specific_resource=True, + manifest_path_override=self.build_context_mock.manifest_path_override, + container_manager=self.build_context_mock.container_manager, + mode=self.build_context_mock.mode, + ) + + patched_tempfile.gettempdir.assert_called_once() + patched_os.path.join.assert_called_with(ANY, ANY) + patched_make_zip.assert_called_with(ANY, self.layer_sync_flow._artifact_folder) + + patched_file_checksum.assert_called_with(ANY, ANY) + + self.assertEqual(self.layer_sync_flow._artifact_folder, given_artifact_folder) + self.assertEqual(self.layer_sync_flow._zip_file, given_zip_location) + self.assertEqual(self.layer_sync_flow._local_sha, given_file_checksum) + + self.layer_sync_flow._get_lock_chain.assert_called_once() + self.layer_sync_flow._get_lock_chain.return_value.__enter__.assert_called_once() + self.layer_sync_flow._get_lock_chain.return_value.__exit__.assert_called_once() + + def test_compare_remote(self): + given_lambda_client = Mock() + self.layer_sync_flow._lambda_client = given_lambda_client + + given_sha256 = base64.b64encode(b"checksum") + given_layer_info = {"Content": {"CodeSha256": given_sha256}} + given_lambda_client.get_layer_version.return_value = given_layer_info + + self.layer_sync_flow._local_sha = base64.b64decode(given_sha256).hex() + + with patch.object(self.layer_sync_flow, "_get_latest_layer_version") as patched_get_latest_layer_version: + given_layer_name = Mock() + given_latest_layer_version = Mock() + self.layer_sync_flow._layer_arn = given_layer_name + patched_get_latest_layer_version.return_value = given_latest_layer_version + + compare_result = self.layer_sync_flow.compare_remote() + + self.assertTrue(compare_result) + + def test_sync(self): + with patch.object(self.layer_sync_flow, "_publish_new_layer_version") as patched_publish_new_layer_version: + with patch.object(self.layer_sync_flow, "_delete_old_layer_version") as patched_delete_old_layer_version: + given_layer_version = Mock() + patched_publish_new_layer_version.return_value = given_layer_version + + self.layer_sync_flow.sync() + self.assertEqual(self.layer_sync_flow._new_layer_version, given_layer_version) + + patched_publish_new_layer_version.assert_called_once() + patched_delete_old_layer_version.assert_called_once() + + def test_publish_new_layer_version(self): + given_layer_name = Mock() + + given_lambda_client = Mock() + self.layer_sync_flow._lambda_client = given_lambda_client + + given_zip_file = Mock() + self.layer_sync_flow._zip_file = given_zip_file + + self.layer_sync_flow._layer_arn = given_layer_name + + with patch.object(self.layer_sync_flow, "_get_resource") as patched_get_resource: + with patch("builtins.open", mock_open(read_data="data")) as mock_file: + given_publish_layer_result = {"Version": 24} + given_lambda_client.publish_layer_version.return_value = given_publish_layer_result + + given_layer_resource = Mock() + patched_get_resource.return_value = given_layer_resource + + result_version = 
self.layer_sync_flow._publish_new_layer_version() + + patched_get_resource.assert_called_with(self.layer_identifier) + given_lambda_client.publish_layer_version.assert_called_with( + LayerName=given_layer_name, + Content={"ZipFile": "data"}, + CompatibleRuntimes=given_layer_resource.get("Properties", {}).get("CompatibleRuntimes", []), + ) + + self.assertEqual(result_version, given_publish_layer_result.get("Version")) + + def test_delete_old_layer_version(self): + given_layer_name = Mock() + given_layer_version = Mock() + + given_lambda_client = Mock() + self.layer_sync_flow._lambda_client = given_lambda_client + + self.layer_sync_flow._layer_arn = given_layer_name + self.layer_sync_flow._old_layer_version = given_layer_version + + self.layer_sync_flow._delete_old_layer_version() + + given_lambda_client.delete_layer_version.assert_called_with( + LayerName=given_layer_name, VersionNumber=given_layer_version + ) + + @patch("samcli.lib.sync.flows.layer_sync_flow.os") + @patch("samcli.lib.sync.flows.layer_sync_flow.SamFunctionProvider") + @patch("samcli.lib.sync.flows.layer_sync_flow.FunctionLayerReferenceSync") + def test_gather_dependencies(self, patched_function_ref_sync, patched_function_provider, os_mock): + self.layer_sync_flow._new_layer_version = "given_new_layer_version_arn" + + given_function_provider = Mock() + patched_function_provider.return_value = given_function_provider + + mock_some_random_layer = PropertyMock() + mock_some_random_layer.full_path = "SomeRandomLayer" + + mock_given_layer = PropertyMock() + mock_given_layer.full_path = self.layer_identifier + + mock_some_nested_layer = PropertyMock() + mock_some_nested_layer.full_path = "NestedStack1/" + self.layer_identifier + + mock_function_a = PropertyMock(layers=[mock_some_random_layer]) + mock_function_a.full_path = "FunctionA" + + mock_function_b = PropertyMock(layers=[mock_given_layer]) + mock_function_b.full_path = "FunctionB" + + mock_function_c = PropertyMock(layers=[mock_some_nested_layer]) + mock_function_c.full_path = "NestedStack1/FunctionC" + + given_layers = [ + mock_function_a, + mock_function_b, + mock_function_c, + ] + given_function_provider.get_all.return_value = given_layers + + self.layer_sync_flow._stacks = Mock() + + given_layer_physical_name = Mock() + self.layer_sync_flow._layer_arn = given_layer_physical_name + + self.layer_sync_flow._zip_file = Mock() + + dependencies = self.layer_sync_flow.gather_dependencies() + + patched_function_ref_sync.assert_called_once_with( + "FunctionB", + given_layer_physical_name, + self.layer_sync_flow._new_layer_version, + self.layer_sync_flow._build_context, + self.layer_sync_flow._deploy_context, + self.layer_sync_flow._physical_id_mapping, + self.layer_sync_flow._stacks, + ) + + self.assertEqual(len(dependencies), 1) + + @patch("samcli.lib.sync.flows.layer_sync_flow.os") + @patch("samcli.lib.sync.flows.layer_sync_flow.SamFunctionProvider") + @patch("samcli.lib.sync.flows.layer_sync_flow.FunctionLayerReferenceSync") + def test_gather_dependencies_nested_stack(self, patched_function_ref_sync, patched_function_provider, os_mock): + self.layer_identifier = "NestedStack1/Layer1" + self.layer_sync_flow._layer_identifier = "NestedStack1/Layer1" + self.layer_sync_flow._new_layer_version = "given_new_layer_version_arn" + + given_function_provider = Mock() + patched_function_provider.return_value = given_function_provider + + mock_some_random_layer = PropertyMock() + mock_some_random_layer.full_path = "Layer1" + + mock_given_layer = PropertyMock() + mock_given_layer.full_path 
= self.layer_identifier + + mock_some_nested_layer = PropertyMock() + mock_some_nested_layer.full_path = "NestedStack1/Layer2" + + mock_function_a = PropertyMock(layers=[mock_some_random_layer]) + mock_function_a.full_path = "FunctionA" + + mock_function_b = PropertyMock(layers=[mock_given_layer]) + mock_function_b.full_path = "NestedStack1/FunctionB" + + mock_function_c = PropertyMock(layers=[mock_some_nested_layer]) + mock_function_c.full_path = "NestedStack1/FunctionC" + + given_layers = [ + mock_function_a, + mock_function_b, + mock_function_c, + ] + given_function_provider.get_all.return_value = given_layers + + self.layer_sync_flow._stacks = Mock() + + given_layer_physical_name = Mock() + self.layer_sync_flow._layer_arn = given_layer_physical_name + + self.layer_sync_flow._zip_file = Mock() + + dependencies = self.layer_sync_flow.gather_dependencies() + + patched_function_ref_sync.assert_called_once_with( + "NestedStack1/FunctionB", + given_layer_physical_name, + self.layer_sync_flow._new_layer_version, + self.layer_sync_flow._build_context, + self.layer_sync_flow._deploy_context, + self.layer_sync_flow._physical_id_mapping, + self.layer_sync_flow._stacks, + ) + + self.assertEqual(len(dependencies), 1) + + def test_get_latest_layer_version(self): + given_version = Mock() + given_layer_name = Mock() + given_lambda_client = Mock() + given_lambda_client.list_layer_versions.return_value = {"LayerVersions": [{"Version": given_version}]} + self.layer_sync_flow._lambda_client = given_lambda_client + self.layer_sync_flow._layer_arn = given_layer_name + + latest_layer_version = self.layer_sync_flow._get_latest_layer_version() + + given_lambda_client.list_layer_versions.assert_called_with(LayerName=given_layer_name) + self.assertEqual(latest_layer_version, given_version) + + def test_get_latest_layer_version_error(self): + given_layer_name = Mock() + given_lambda_client = Mock() + given_lambda_client.list_layer_versions.return_value = {"LayerVersions": []} + self.layer_sync_flow._lambda_client = given_lambda_client + self.layer_sync_flow._layer_arn = given_layer_name + + with self.assertRaises(NoLayerVersionsFoundError): + self.layer_sync_flow._get_latest_layer_version() + + def test_equality_keys(self): + self.assertEqual(self.layer_sync_flow._equality_keys(), self.layer_identifier) + + @patch("samcli.lib.sync.flows.layer_sync_flow.ResourceAPICall") + def test_get_resource_api_calls(self, resource_api_call_mock): + result = self.layer_sync_flow._get_resource_api_calls() + self.assertEqual(len(result), 1) + resource_api_call_mock.assert_called_once_with(self.layer_identifier, ["Build"]) + + +class TestFunctionLayerReferenceSync(TestCase): + def setUp(self): + self.function_identifier = "function" + self.layer_name = "Layer1" + self.old_layer_version = 1 + self.new_layer_version = 2 + + self.function_layer_sync = FunctionLayerReferenceSync( + self.function_identifier, self.layer_name, self.new_layer_version, Mock(), Mock(), {}, [] + ) + + def test_setup(self): + with patch.object(self.function_layer_sync, "_session") as patched_session: + with patch.object(SyncFlow, "set_up") as patched_super_setup: + self.function_layer_sync.set_up() + + patched_super_setup.assert_called_once() + patched_session.assert_has_calls( + [ + call.client("lambda"), + ] + ) + + def test_sync(self): + given_lambda_client = Mock() + self.function_layer_sync._lambda_client = given_lambda_client + + other_layer_version_arn = "SomeOtherLayerVersionArn" + given_function_result = {"Configuration": {"Layers": [{"Arn": 
"Layer1:1"}, {"Arn": other_layer_version_arn}]}} + given_lambda_client.get_function.return_value = given_function_result + + with patch.object(self.function_layer_sync, "get_physical_id") as patched_get_physical_id: + with patch.object(self.function_layer_sync, "_locks") as patched_locks: + given_physical_id = Mock() + patched_get_physical_id.return_value = given_physical_id + + self.function_layer_sync.sync() + + patched_get_physical_id.assert_called_with(self.function_identifier) + + patched_locks.get.assert_called_with( + SyncFlow._get_lock_key( + self.function_identifier, FunctionLayerReferenceSync.UPDATE_FUNCTION_CONFIGURATION + ) + ) + + given_lambda_client.get_function.assert_called_with(FunctionName=given_physical_id) + + given_lambda_client.update_function_configuration.assert_called_with( + FunctionName=given_physical_id, Layers=[other_layer_version_arn, "Layer1:2"] + ) + + def test_sync_with_existing_new_layer_version_arn(self): + given_lambda_client = Mock() + self.function_layer_sync._lambda_client = given_lambda_client + + given_function_result = {"Configuration": {"Layers": [{"Arn": "Layer1:2"}]}} + given_lambda_client.get_function.return_value = given_function_result + + with patch.object(self.function_layer_sync, "get_physical_id") as patched_get_physical_id: + with patch.object(self.function_layer_sync, "_locks") as patched_locks: + given_physical_id = Mock() + patched_get_physical_id.return_value = given_physical_id + + self.function_layer_sync.sync() + + patched_locks.get.assert_called_with( + SyncFlow._get_lock_key( + self.function_identifier, FunctionLayerReferenceSync.UPDATE_FUNCTION_CONFIGURATION + ) + ) + + patched_get_physical_id.assert_called_with(self.function_identifier) + + given_lambda_client.get_function.assert_called_with(FunctionName=given_physical_id) + + given_lambda_client.update_function_configuration.assert_not_called() + + def test_equality_keys(self): + self.assertEqual( + self.function_layer_sync._equality_keys(), + (self.function_identifier, self.layer_name, self.new_layer_version), + ) + + def test_compare_remote(self): + self.assertFalse(self.function_layer_sync.compare_remote()) + + def test_gather_dependencies(self): + self.assertEqual(self.function_layer_sync.gather_dependencies(), []) diff --git a/tests/unit/lib/sync/flows/test_rest_api_sync_flow.py b/tests/unit/lib/sync/flows/test_rest_api_sync_flow.py new file mode 100644 index 0000000000..c0f7c4fe9b --- /dev/null +++ b/tests/unit/lib/sync/flows/test_rest_api_sync_flow.py @@ -0,0 +1,100 @@ +from abc import abstractmethod, ABC +from unittest import TestCase +from unittest.mock import ANY, MagicMock, mock_open, patch + +from samcli.lib.sync.flows.rest_api_sync_flow import RestApiSyncFlow +from samcli.lib.providers.exceptions import MissingLocalDefinition + + +class TestRestApiSyncFlow(TestCase): + def create_sync_flow(self): + sync_flow = RestApiSyncFlow( + "Api1", + build_context=MagicMock(), + deploy_context=MagicMock(), + physical_id_mapping={}, + stacks=[MagicMock()], + ) + return sync_flow + + @patch("samcli.lib.sync.sync_flow.Session") + def test_set_up(self, session_mock): + sync_flow = self.create_sync_flow() + sync_flow.set_up() + session_mock.return_value.client.assert_any_call("apigateway") + + @patch("samcli.lib.sync.sync_flow.Session") + def test_sync_direct(self, session_mock): + sync_flow = self.create_sync_flow() + + sync_flow.get_physical_id = MagicMock() + sync_flow.get_physical_id.return_value = "PhysicalApi1" + + sync_flow._get_definition_file = MagicMock() + 
sync_flow._get_definition_file.return_value = "file.yaml"
+
+ sync_flow.set_up()
+ with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file:
+ sync_flow.gather_resources()
+
+ sync_flow._api_client.put_rest_api.return_value = {"Response": "success"}
+
+ sync_flow.sync()
+
+ sync_flow._api_client.put_rest_api.assert_called_once_with(
+ restApiId="PhysicalApi1", mode="overwrite", body='{"key": "value"}'.encode("utf-8")
+ )
+
+ @patch("samcli.lib.sync.flows.generic_api_sync_flow.get_resource_by_id")
+ def test_get_definition_file(self, get_resource_mock):
+ sync_flow = self.create_sync_flow()
+
+ get_resource_mock.return_value = {"Properties": {"DefinitionUri": "test_uri"}}
+ result_uri = sync_flow._get_definition_file("test")
+
+ self.assertEqual(result_uri, "test_uri")
+
+ get_resource_mock.return_value = {"Properties": {}}
+ result_uri = sync_flow._get_definition_file("test")
+
+ self.assertEqual(result_uri, None)
+
+ def test_process_definition_file(self):
+ sync_flow = self.create_sync_flow()
+ sync_flow._definition_uri = "path"
+ with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file:
+ data = sync_flow._process_definition_file()
+ self.assertEqual(data, '{"key": "value"}'.encode("utf-8"))
+
+ @patch("samcli.lib.sync.sync_flow.Session")
+ def test_failed_gather_resources(self, session_mock):
+ sync_flow = self.create_sync_flow()
+
+ sync_flow.get_physical_id = MagicMock()
+ sync_flow.get_physical_id.return_value = "PhysicalApi1"
+
+ sync_flow._get_definition_file = MagicMock()
+ sync_flow._get_definition_file.return_value = "file.yaml"
+
+ sync_flow.set_up()
+ sync_flow._definition_uri = None
+
+ with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file:
+ with self.assertRaises(MissingLocalDefinition):
+ sync_flow.sync()
+
+ def test_compare_remote(self):
+ sync_flow = self.create_sync_flow()
+ self.assertFalse(sync_flow.compare_remote())
+
+ def test_gather_dependencies(self):
+ sync_flow = self.create_sync_flow()
+ self.assertEqual(sync_flow.gather_dependencies(), [])
+
+ def test_equality_keys(self):
+ sync_flow = self.create_sync_flow()
+ self.assertEqual(sync_flow._equality_keys(), sync_flow._api_identifier)
+
+ def test_get_resource_api_calls(self):
+ sync_flow = self.create_sync_flow()
+ self.assertEqual(sync_flow._get_resource_api_calls(), []) diff --git a/tests/unit/lib/sync/flows/test_stepfunctions_sync_flow.py b/tests/unit/lib/sync/flows/test_stepfunctions_sync_flow.py new file mode 100644 index 0000000000..eeaaa85c7d --- /dev/null +++ b/tests/unit/lib/sync/flows/test_stepfunctions_sync_flow.py @@ -0,0 +1,124 @@
+from samcli.lib.providers.exceptions import MissingLocalDefinition
+from unittest import TestCase
+from unittest.mock import ANY, MagicMock, mock_open, patch
+
+from samcli.lib.sync.flows.stepfunctions_sync_flow import StepFunctionsSyncFlow
+from samcli.lib.sync.exceptions import InfraSyncRequiredError
+
+
+class TestStepFunctionsSyncFlow(TestCase):
+ def setUp(self) -> None:
+ get_resource_patch = patch("samcli.lib.sync.flows.stepfunctions_sync_flow.get_resource_by_id")
+ self.get_resource_mock = get_resource_patch.start()
+ self.get_resource_mock.return_value = {"Properties": {"DefinitionUri": "test_uri"}}
+ self.addCleanup(get_resource_patch.stop)
+
+ def create_sync_flow(self):
+ sync_flow = StepFunctionsSyncFlow(
+ "StateMachine1",
+ build_context=MagicMock(),
+
deploy_context=MagicMock(), + physical_id_mapping={}, + stacks=[MagicMock()], + ) + return sync_flow + + @patch("samcli.lib.sync.sync_flow.Session") + def test_set_up(self, session_mock): + sync_flow = self.create_sync_flow() + sync_flow.set_up() + session_mock.return_value.client.assert_any_call("stepfunctions") + + @patch("samcli.lib.sync.sync_flow.Session") + def test_sync_direct(self, session_mock): + sync_flow = self.create_sync_flow() + + sync_flow.get_physical_id = MagicMock() + sync_flow.get_physical_id.return_value = "PhysicalId1" + + sync_flow._get_definition_file = MagicMock() + sync_flow._get_definition_file.return_value = "file.yaml" + + sync_flow.set_up() + with patch("builtins.open", mock_open(read_data='{"key": "value"}')) as mock_file: + sync_flow.gather_resources() + + sync_flow._stepfunctions_client.update_state_machine.return_value = {"Response": "success"} + + sync_flow.sync() + + sync_flow._stepfunctions_client.update_state_machine.assert_called_once_with( + stateMachineArn="PhysicalId1", definition='{"key": "value"}' + ) + + @patch("samcli.lib.sync.flows.stepfunctions_sync_flow.get_resource_by_id") + def test_get_definition_file(self, get_resource_mock): + sync_flow = self.create_sync_flow() + + sync_flow._resource = {"Properties": {"DefinitionUri": "test_uri"}} + result_uri = sync_flow._get_definition_file("test") + + self.assertEqual(result_uri, "test_uri") + + sync_flow._resource = {"Properties": {}} + result_uri = sync_flow._get_definition_file("test") + + self.assertEqual(result_uri, None) + + def test_process_definition_file(self): + sync_flow = self.create_sync_flow() + sync_flow._definition_uri = "path" + with patch("builtins.open", mock_open(read_data='{"key": "value"}')) as mock_file: + data = sync_flow._process_definition_file() + self.assertEqual(data, '{"key": "value"}') + + @patch("samcli.lib.sync.sync_flow.Session") + def test_failed_gather_resources_definition_substitution(self, session_mock): + self.get_resource_mock.return_value = {"Properties": {"DefinitionSubstitutions": {"a": "b"}}} + sync_flow = self.create_sync_flow() + + sync_flow.get_physical_id = MagicMock() + sync_flow.get_physical_id.return_value = "PhysicalApi1" + + sync_flow._get_definition_file = MagicMock() + sync_flow._get_definition_file.return_value = "file.yaml" + + sync_flow.set_up() + sync_flow._definition_uri = None + + with patch("builtins.open", mock_open(read_data='{"key": "value"}')) as mock_file: + with self.assertRaises(InfraSyncRequiredError): + sync_flow.gather_resources() + + @patch("samcli.lib.sync.sync_flow.Session") + def test_failed_gather_resources(self, session_mock): + sync_flow = self.create_sync_flow() + + sync_flow.get_physical_id = MagicMock() + sync_flow.get_physical_id.return_value = "PhysicalApi1" + + sync_flow._get_definition_file = MagicMock() + sync_flow._get_definition_file.return_value = "file.yaml" + + sync_flow.set_up() + sync_flow._definition_uri = None + + with patch("builtins.open", mock_open(read_data='{"key": "value"}')) as mock_file: + with self.assertRaises(MissingLocalDefinition): + sync_flow.sync() + + def test_gather_dependencies(self): + sync_flow = self.create_sync_flow() + self.assertEqual(sync_flow.gather_dependencies(), []) + + def test_compare_remote(self): + sync_flow = self.create_sync_flow() + self.assertFalse(sync_flow.compare_remote()) + + def test_get_resource_api_calls(self): + sync_flow = self.create_sync_flow() + self.assertEqual(sync_flow._get_resource_api_calls(), []) + + def test_equality_keys(self): + sync_flow = 
self.create_sync_flow() + self.assertEqual(sync_flow._equality_keys(), sync_flow._state_machine_identifier) diff --git a/tests/unit/lib/sync/flows/test_zip_function_sync_flow.py b/tests/unit/lib/sync/flows/test_zip_function_sync_flow.py new file mode 100644 index 0000000000..92b0643be7 --- /dev/null +++ b/tests/unit/lib/sync/flows/test_zip_function_sync_flow.py @@ -0,0 +1,180 @@ +import os +import hashlib + +from samcli.lib.sync.sync_flow import SyncFlow +from unittest import TestCase +from unittest.mock import ANY, MagicMock, call, mock_open, patch + +from samcli.lib.sync.flows.zip_function_sync_flow import ZipFunctionSyncFlow + + +class TestZipFunctionSyncFlow(TestCase): + def create_function_sync_flow(self): + sync_flow = ZipFunctionSyncFlow( + "Function1", + build_context=MagicMock(), + deploy_context=MagicMock(), + physical_id_mapping={}, + stacks=[MagicMock()], + ) + sync_flow._get_resource_api_calls = MagicMock() + return sync_flow + + @patch("samcli.lib.sync.sync_flow.Session") + def test_set_up(self, session_mock): + sync_flow = self.create_function_sync_flow() + sync_flow.set_up() + session_mock.return_value.client.assert_any_call("lambda") + session_mock.return_value.client.assert_any_call("s3") + + @patch("samcli.lib.sync.flows.zip_function_sync_flow.hashlib.sha256") + @patch("samcli.lib.sync.flows.zip_function_sync_flow.uuid.uuid4") + @patch("samcli.lib.sync.flows.zip_function_sync_flow.file_checksum") + @patch("samcli.lib.sync.flows.zip_function_sync_flow.make_zip") + @patch("samcli.lib.sync.flows.zip_function_sync_flow.tempfile.gettempdir") + @patch("samcli.lib.sync.flows.zip_function_sync_flow.ApplicationBuilder") + @patch("samcli.lib.sync.sync_flow.Session") + def test_gather_resources( + self, session_mock, builder_mock, gettempdir_mock, make_zip_mock, file_checksum_mock, uuid4_mock, sha256_mock + ): + get_mock = MagicMock() + get_mock.return_value = "ArtifactFolder1" + builder_mock.return_value.build.return_value.artifacts.get = get_mock + uuid4_mock.return_value.hex = "uuid_value" + gettempdir_mock.return_value = "temp_folder" + make_zip_mock.return_value = "zip_file" + file_checksum_mock.return_value = "sha256_value" + sync_flow = self.create_function_sync_flow() + + sync_flow._get_lock_chain = MagicMock() + + sync_flow.set_up() + sync_flow.gather_resources() + + get_mock.assert_called_once_with("Function1") + self.assertEqual(sync_flow._artifact_folder, "ArtifactFolder1") + make_zip_mock.assert_called_once_with("temp_folder" + os.sep + "data-uuid_value", "ArtifactFolder1") + file_checksum_mock.assert_called_once_with("zip_file", sha256_mock.return_value) + self.assertEqual("sha256_value", sync_flow._local_sha) + sync_flow._get_lock_chain.assert_called_once() + sync_flow._get_lock_chain.return_value.__enter__.assert_called_once() + sync_flow._get_lock_chain.return_value.__exit__.assert_called_once() + + @patch("samcli.lib.sync.flows.zip_function_sync_flow.base64.b64decode") + @patch("samcli.lib.sync.sync_flow.Session") + def test_compare_remote_true(self, session_mock, b64decode_mock): + b64decode_mock.return_value.hex.return_value = "sha256_value" + sync_flow = self.create_function_sync_flow() + sync_flow._local_sha = "sha256_value" + + sync_flow.get_physical_id = MagicMock() + sync_flow.get_physical_id.return_value = "PhysicalFunction1" + + sync_flow.set_up() + + sync_flow._lambda_client.get_function.return_value = {"Configuration": {"CodeSha256": "sha256_value_b64"}} + + result = sync_flow.compare_remote() + + 
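# the decoded remote CodeSha256 equals the local hash, so the code is reported as unchanged
+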
sync_flow._lambda_client.get_function.assert_called_once_with(FunctionName="PhysicalFunction1") + b64decode_mock.assert_called_once_with("sha256_value_b64") + self.assertTrue(result) + + @patch("samcli.lib.sync.flows.zip_function_sync_flow.base64.b64decode") + @patch("samcli.lib.sync.sync_flow.Session") + def test_compare_remote_false(self, session_mock, b64decode_mock): + b64decode_mock.return_value.hex.return_value = "sha256_value_2" + sync_flow = self.create_function_sync_flow() + sync_flow._local_sha = "sha256_value" + + sync_flow.get_physical_id = MagicMock() + sync_flow.get_physical_id.return_value = "PhysicalFunction1" + + sync_flow.set_up() + + sync_flow._lambda_client.get_function.return_value = {"Configuration": {"CodeSha256": "sha256_value_b64"}} + + result = sync_flow.compare_remote() + + sync_flow._lambda_client.get_function.assert_called_once_with(FunctionName="PhysicalFunction1") + b64decode_mock.assert_called_once_with("sha256_value_b64") + self.assertFalse(result) + + @patch("samcli.lib.sync.flows.zip_function_sync_flow.open", mock_open(read_data=b"zip_content"), create=True) + @patch("samcli.lib.sync.flows.zip_function_sync_flow.os.remove") + @patch("samcli.lib.sync.flows.zip_function_sync_flow.os.path.exists") + @patch("samcli.lib.sync.flows.zip_function_sync_flow.S3Uploader") + @patch("samcli.lib.sync.flows.zip_function_sync_flow.os.path.getsize") + @patch("samcli.lib.sync.sync_flow.Session") + def test_sync_direct(self, session_mock, getsize_mock, uploader_mock, exists_mock, remove_mock): + getsize_mock.return_value = 49 * 1024 * 1024 + exists_mock.return_value = True + sync_flow = self.create_function_sync_flow() + sync_flow._zip_file = "zip_file" + + sync_flow.get_physical_id = MagicMock() + sync_flow.get_physical_id.return_value = "PhysicalFunction1" + + sync_flow.set_up() + + sync_flow.sync() + + sync_flow._lambda_client.update_function_code.assert_called_once_with( + FunctionName="PhysicalFunction1", ZipFile=b"zip_content" + ) + remove_mock.assert_called_once_with("zip_file") + + @patch("samcli.lib.sync.flows.zip_function_sync_flow.open", mock_open(read_data=b"zip_content"), create=True) + @patch("samcli.lib.sync.flows.zip_function_sync_flow.os.remove") + @patch("samcli.lib.sync.flows.zip_function_sync_flow.os.path.exists") + @patch("samcli.lib.sync.flows.zip_function_sync_flow.S3Uploader") + @patch("samcli.lib.sync.flows.zip_function_sync_flow.os.path.getsize") + @patch("samcli.lib.sync.sync_flow.Session") + def test_sync_s3(self, session_mock, getsize_mock, uploader_mock, exists_mock, remove_mock): + getsize_mock.return_value = 51 * 1024 * 1024 + exists_mock.return_value = True + uploader_mock.return_value.upload_with_dedup.return_value = "s3://bucket_name/bucket/key" + sync_flow = self.create_function_sync_flow() + sync_flow._zip_file = "zip_file" + sync_flow._deploy_context.s3_bucket = "bucket_name" + + sync_flow.get_physical_id = MagicMock() + sync_flow.get_physical_id.return_value = "PhysicalFunction1" + + sync_flow.set_up() + + sync_flow.sync() + + uploader_mock.return_value.upload_with_dedup.assert_called_once_with("zip_file") + + sync_flow._lambda_client.update_function_code.assert_called_once_with( + FunctionName="PhysicalFunction1", S3Bucket="bucket_name", S3Key="bucket/key" + ) + remove_mock.assert_called_once_with("zip_file") + + @patch("samcli.lib.sync.flows.zip_function_sync_flow.ResourceAPICall") + def test_get_resource_api_calls(self, resource_api_call_mock): + build_context = MagicMock() + layer1 = MagicMock() + layer2 = MagicMock() + 
layer1.full_path = "Layer1" + layer2.full_path = "Layer2" + function_mock = MagicMock() + function_mock.layers = [layer1, layer2] + build_context.function_provider.functions.get.return_value = function_mock + sync_flow = ZipFunctionSyncFlow( + "Function1", + build_context=build_context, + deploy_context=MagicMock(), + physical_id_mapping={}, + stacks=[MagicMock()], + ) + + result = sync_flow._get_resource_api_calls() + self.assertEqual(len(result), 2) + resource_api_call_mock.assert_any_call("Layer1", ["Build"]) + resource_api_call_mock.assert_any_call("Layer2", ["Build"]) + + def test_combine_dependencies(self): + sync_flow = self.create_function_sync_flow() + self.assertTrue(sync_flow._combine_dependencies()) diff --git a/tests/unit/lib/sync/test_continuous_sync_flow_executor.py b/tests/unit/lib/sync/test_continuous_sync_flow_executor.py new file mode 100644 index 0000000000..d9c526abfe --- /dev/null +++ b/tests/unit/lib/sync/test_continuous_sync_flow_executor.py @@ -0,0 +1,144 @@ +from multiprocessing.managers import ValueProxy +from queue import Queue +from samcli.lib.sync.continuous_sync_flow_executor import ContinuousSyncFlowExecutor, DelayedSyncFlowTask +from samcli.lib.sync.sync_flow import SyncFlow + +from botocore.exceptions import ClientError +from samcli.lib.sync.exceptions import ( + MissingPhysicalResourceError, + NoLayerVersionsFoundError, + SyncFlowException, +) +from unittest import TestCase +from unittest.mock import ANY, MagicMock, call, patch + +from samcli.lib.sync.sync_flow_executor import ( + SyncFlowExecutor, + SyncFlowResult, + SyncFlowTask, + default_exception_handler, + HELP_TEXT_FOR_SYNC_INFRA, +) + + +class TestContinuousSyncFlowExecutor(TestCase): + def setUp(self): + self.thread_pool_executor_patch = patch("samcli.lib.sync.sync_flow_executor.ThreadPoolExecutor") + self.thread_pool_executor_mock = self.thread_pool_executor_patch.start() + self.thread_pool_executor = self.thread_pool_executor_mock.return_value + self.thread_pool_executor.__enter__.return_value = self.thread_pool_executor + self.lock_distributor_patch = patch("samcli.lib.sync.sync_flow_executor.LockDistributor") + self.lock_distributor_mock = self.lock_distributor_patch.start() + self.lock_distributor = self.lock_distributor_mock.return_value + self.executor = ContinuousSyncFlowExecutor() + + def tearDown(self) -> None: + self.thread_pool_executor_patch.stop() + self.lock_distributor_patch.stop() + + @patch("samcli.lib.sync.continuous_sync_flow_executor.time.time") + @patch("samcli.lib.sync.continuous_sync_flow_executor.DelayedSyncFlowTask") + def test_add_delayed_sync_flow(self, task_mock, time_mock): + add_sync_flow_task_mock = MagicMock() + task = MagicMock() + task_mock.return_value = task + time_mock.return_value = 1000 + self.executor._add_sync_flow_task = add_sync_flow_task_mock + sync_flow = MagicMock() + + self.executor.add_delayed_sync_flow(sync_flow, False, 15) + + task_mock.assert_called_once_with(sync_flow, False, 1000, 15) + add_sync_flow_task_mock.assert_called_once_with(task) + + def test_add_sync_flow_task(self): + sync_flow = MagicMock() + task = DelayedSyncFlowTask(sync_flow, False, 1000, 15) + + self.executor._add_sync_flow_task(task) + + sync_flow.set_locks_with_distributor.assert_called_once_with(self.executor._lock_distributor) + + queue_task = self.executor._flow_queue.get() + self.assertEqual(sync_flow, queue_task.sync_flow) + + def test_stop_without_manager(self): + self.executor.stop() + self.assertTrue(self.executor._stop_flag) + + def 
test_should_stop_without_manager(self): + self.executor._stop_flag = True + self.assertTrue(self.executor.should_stop()) + + @patch("samcli.lib.sync.continuous_sync_flow_executor.time.time") + @patch("samcli.lib.sync.sync_flow_executor.time.sleep") + def test_execute_high_level_logic(self, sleep_mock, time_mock): + exception_handler_mock = MagicMock() + time_mock.return_value = 1001 + + flow1 = MagicMock() + flow2 = MagicMock() + flow3 = MagicMock() + + task1 = DelayedSyncFlowTask(flow1, False, 1000, 0) + task2 = DelayedSyncFlowTask(flow2, False, 1000, 0) + task3 = DelayedSyncFlowTask(flow3, False, 1000, 0) + + result1 = SyncFlowResult(flow1, [flow3]) + + future1 = MagicMock() + future2 = MagicMock() + future3 = MagicMock() + + exception1 = MagicMock(spec=Exception) + sync_flow_exception = MagicMock(spec=SyncFlowException) + sync_flow_exception.sync_flow = flow2 + sync_flow_exception.exception = exception1 + + future1.done.side_effect = [False, False, True] + future1.exception.return_value = None + future1.result.return_value = result1 + + future2.done.side_effect = [False, False, False, True] + future2.exception.return_value = sync_flow_exception + + future3.done.side_effect = [False, False, False, True] + future3.exception.return_value = None + + self.thread_pool_executor.submit = MagicMock() + self.thread_pool_executor.submit.side_effect = [future1, future2, future3] + + self.executor._flow_queue.put(task1) + self.executor._flow_queue.put(task2) + + self.executor.add_sync_flow = MagicMock() + self.executor.add_sync_flow.side_effect = lambda x: self.executor._flow_queue.put(task3) + + self.executor.should_stop = MagicMock() + self.executor.should_stop.side_effect = [ + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, + ] + + self.executor.execute(exception_handler=exception_handler_mock) + + self.thread_pool_executor.submit.assert_has_calls( + [ + call(SyncFlowExecutor._sync_flow_execute_wrapper, flow1), + call(SyncFlowExecutor._sync_flow_execute_wrapper, flow2), + call(SyncFlowExecutor._sync_flow_execute_wrapper, flow3), + ] + ) + self.executor.add_sync_flow.assert_called_once_with(flow3) + + exception_handler_mock.assert_called_once_with(sync_flow_exception) + self.assertEqual(len(sleep_mock.mock_calls), 10) diff --git a/tests/unit/lib/sync/test_exceptions.py b/tests/unit/lib/sync/test_exceptions.py new file mode 100644 index 0000000000..08d685f106 --- /dev/null +++ b/tests/unit/lib/sync/test_exceptions.py @@ -0,0 +1,50 @@ +from unittest import TestCase +from unittest.mock import MagicMock +from samcli.lib.sync.exceptions import ( + MissingPhysicalResourceError, + NoLayerVersionsFoundError, + SyncFlowException, + MissingFunctionBuildDefinition, + InvalidRuntimeDefinitionForFunction, +) + + +class TestSyncFlowException(TestCase): + def test_exception(self): + sync_flow_mock = MagicMock() + exception_mock = MagicMock() + exception = SyncFlowException(sync_flow_mock, exception_mock) + self.assertEqual(exception.sync_flow, sync_flow_mock) + self.assertEqual(exception.exception, exception_mock) + + +class TestMissingPhysicalResourceError(TestCase): + def test_exception(self): + exception = MissingPhysicalResourceError("A") + self.assertEqual(exception.resource_identifier, "A") + + def test_exception_with_mapping(self): + physical_mapping = MagicMock() + exception = MissingPhysicalResourceError("A", physical_mapping) + self.assertEqual(exception.resource_identifier, "A") + self.assertEqual(exception.physical_resource_mapping, physical_mapping) + 
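
The exception types in this file are thin carriers of context; paired with the executor above, a handler (such as the default_exception_handler these tests import) can unwrap the failing flow and its root cause. A minimal sketch of that unwrapping, assuming only the attribute names asserted in these tests and the imports shown above (handle itself is a hypothetical name):

    def handle(sync_flow_exception):
        flow = sync_flow_exception.sync_flow  # the flow that raised
        cause = sync_flow_exception.exception  # the underlying error
        if isinstance(cause, MissingPhysicalResourceError):
            # the resource has no deployed counterpart yet; only a full infra sync recovers
            print(f"{flow.log_prefix}{cause.resource_identifier} is not deployed. {HELP_TEXT_FOR_SYNC_INFRA}")
        else:
            raise cause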
+ +class TestNoLayerVersionsFoundError(TestCase): + def test_exception(self): + exception = NoLayerVersionsFoundError("layer_name_arn") + self.assertEqual(exception.layer_name_arn, "layer_name_arn") + + +class TestMissingFunctionBuildDefinition(TestCase): + def test_exception(self): + function_logical_id = "function_logical_id" + exception = MissingFunctionBuildDefinition(function_logical_id) + self.assertEqual(exception.function_logical_id, function_logical_id) + + +class TestInvalidRuntimeDefinitionForFunction(TestCase): + def test_exception(self): + function_logical_id = "function_logical_id" + exception = InvalidRuntimeDefinitionForFunction(function_logical_id) + self.assertEqual(exception.function_logical_id, function_logical_id) diff --git a/tests/unit/lib/sync/test_sync_flow.py b/tests/unit/lib/sync/test_sync_flow.py new file mode 100644 index 0000000000..caca763eb8 --- /dev/null +++ b/tests/unit/lib/sync/test_sync_flow.py @@ -0,0 +1,119 @@ +from samcli.lib.providers.provider import ResourceIdentifier +from unittest import TestCase +from unittest.mock import MagicMock, call, patch + +from samcli.lib.sync.sync_flow import SyncFlow, ResourceAPICall +from samcli.lib.utils.lock_distributor import LockChain + + +class TestSyncFlow(TestCase): + def create_sync_flow(self): + sync_flow = SyncFlow( + build_context=MagicMock(), + deploy_context=MagicMock(), + physical_id_mapping={}, + log_name="log-name", + stacks=[MagicMock()], + ) + sync_flow.gather_resources = MagicMock() + sync_flow.compare_remote = MagicMock() + sync_flow.sync = MagicMock() + sync_flow.gather_dependencies = MagicMock() + sync_flow._get_resource_api_calls = MagicMock() + return sync_flow + + @patch("samcli.lib.sync.sync_flow.Session") + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_execute_all_steps(self, session_mock): + sync_flow = self.create_sync_flow() + sync_flow.compare_remote.return_value = False + sync_flow.gather_dependencies.return_value = ["A"] + result = sync_flow.execute() + + sync_flow.gather_resources.assert_called_once() + sync_flow.compare_remote.assert_called_once() + sync_flow.sync.assert_called_once() + sync_flow.gather_dependencies.assert_called_once() + self.assertEqual(result, ["A"]) + + @patch("samcli.lib.sync.sync_flow.Session") + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_execute_skip_after_compare(self, session_mock): + sync_flow = self.create_sync_flow() + sync_flow.compare_remote.return_value = True + sync_flow.gather_dependencies.return_value = ["A"] + result = sync_flow.execute() + + sync_flow.gather_resources.assert_called_once() + sync_flow.compare_remote.assert_called_once() + sync_flow.sync.assert_not_called() + sync_flow.gather_dependencies.assert_not_called() + self.assertEqual(result, []) + + @patch("samcli.lib.sync.sync_flow.Session") + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_set_up(self, session_mock): + sync_flow = self.create_sync_flow() + sync_flow.set_up() + session_mock.assert_called_once() + self.assertIsNotNone(sync_flow._session) + + @patch("samcli.lib.sync.sync_flow.Session") + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_set_locks_with_distributor(self, session_mock): + sync_flow = self.create_sync_flow() + distributor = MagicMock() + locks = {"A": 1, "B": 2} + distributor.get_locks.return_value = locks + sync_flow.set_locks_with_distributor(distributor) + self.assertEqual(locks, sync_flow._locks) + + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_get_lock_keys(self): + 
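+        # Each ResourceAPICall(name, api) below is expected to produce a "<name>_<api>" lock key.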
sync_flow = self.create_sync_flow() + sync_flow._get_resource_api_calls.return_value = [ResourceAPICall("A", "1"), ResourceAPICall("B", "2")] + result = sync_flow.get_lock_keys() + self.assertEqual(result, ["A_1", "B_2"]) + + @patch("samcli.lib.sync.sync_flow.LockChain") + @patch("samcli.lib.sync.sync_flow.Session") + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_get_lock_chain(self, session_mock, lock_chain_mock): + sync_flow = self.create_sync_flow() + locks = {"A": 1, "B": 2} + sync_flow._locks = locks + result = sync_flow._get_lock_chain() + lock_chain_mock.assert_called_once_with(locks) + + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_log_prefix(self): + sync_flow = self.create_sync_flow() + sync_flow._log_name = "A" + self.assertEqual(sync_flow.log_prefix, "SyncFlow [A]: ") + + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_eq_true(self): + sync_flow_1 = self.create_sync_flow() + sync_flow_1._equality_keys = MagicMock() + sync_flow_1._equality_keys.return_value = "A" + sync_flow_2 = self.create_sync_flow() + sync_flow_2._equality_keys = MagicMock() + sync_flow_2._equality_keys.return_value = "A" + self.assertTrue(sync_flow_1 == sync_flow_2) + + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_eq_false(self): + sync_flow_1 = self.create_sync_flow() + sync_flow_1._equality_keys = MagicMock() + sync_flow_1._equality_keys.return_value = "A" + sync_flow_2 = self.create_sync_flow() + sync_flow_2._equality_keys = MagicMock() + sync_flow_2._equality_keys.return_value = "B" + self.assertFalse(sync_flow_1 == sync_flow_2) + + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_hash(self): + sync_flow = self.create_sync_flow() + sync_flow._equality_keys = MagicMock() + sync_flow._equality_keys.return_value = "A" + self.assertEqual(hash(sync_flow), hash((type(sync_flow), "A"))) diff --git a/tests/unit/lib/sync/test_sync_flow_executor.py b/tests/unit/lib/sync/test_sync_flow_executor.py new file mode 100644 index 0000000000..74840defe9 --- /dev/null +++ b/tests/unit/lib/sync/test_sync_flow_executor.py @@ -0,0 +1,262 @@ +from multiprocessing.managers import ValueProxy +from queue import Queue +from samcli.lib.sync.sync_flow import SyncFlow + +from botocore.exceptions import ClientError +from samcli.lib.providers.exceptions import MissingLocalDefinition +from samcli.lib.sync.exceptions import ( + MissingPhysicalResourceError, + NoLayerVersionsFoundError, + SyncFlowException, + MissingFunctionBuildDefinition, + InvalidRuntimeDefinitionForFunction, +) +from unittest import TestCase +from unittest.mock import ANY, MagicMock, call, patch + +from samcli.lib.sync.sync_flow_executor import ( + SyncFlowExecutor, + SyncFlowResult, + SyncFlowTask, + default_exception_handler, + HELP_TEXT_FOR_SYNC_INFRA, +) + + +class TestSyncFlowExecutor(TestCase): + def setUp(self): + self.thread_pool_executor_patch = patch("samcli.lib.sync.sync_flow_executor.ThreadPoolExecutor") + self.thread_pool_executor_mock = self.thread_pool_executor_patch.start() + self.thread_pool_executor = self.thread_pool_executor_mock.return_value + self.thread_pool_executor.__enter__.return_value = self.thread_pool_executor + self.lock_distributor_patch = patch("samcli.lib.sync.sync_flow_executor.LockDistributor") + self.lock_distributor_mock = self.lock_distributor_patch.start() + self.lock_distributor = self.lock_distributor_mock.return_value + self.executor = SyncFlowExecutor() + + def tearDown(self) -> None: + self.thread_pool_executor_patch.stop() + 
self.lock_distributor_patch.stop() + + @patch("samcli.lib.sync.sync_flow_executor.LOG") + def test_default_exception_handler_missing_physical_resource_error(self, log_mock): + sync_flow_exception = MagicMock(spec=SyncFlowException) + exception = MagicMock(spec=MissingPhysicalResourceError) + exception.resource_identifier = "Resource1" + sync_flow_exception.exception = exception + + default_exception_handler(sync_flow_exception) + log_mock.error.assert_called_once_with( + "Cannot find resource %s in remote.%s", "Resource1", HELP_TEXT_FOR_SYNC_INFRA + ) + + @patch("samcli.lib.sync.sync_flow_executor.LOG") + def test_default_exception_handler_client_error_valid(self, log_mock): + sync_flow_exception = MagicMock(spec=SyncFlowException) + exception = MagicMock(spec=ClientError) + exception.resource_identifier = "Resource1" + exception.response = {"Error": {"Code": "ResourceNotFoundException", "Message": "MessageContent"}} + sync_flow_exception.exception = exception + + default_exception_handler(sync_flow_exception) + log_mock.error.assert_has_calls( + [call("Cannot find resource in remote.%s", HELP_TEXT_FOR_SYNC_INFRA), call("MessageContent")] + ) + + @patch("samcli.lib.sync.sync_flow_executor.LOG") + def test_default_exception_no_layer_versions_found(self, log_mock): + sync_flow_exception = MagicMock(spec=SyncFlowException) + exception = MagicMock(spec=NoLayerVersionsFoundError) + exception.layer_name_arn = "layer_name" + sync_flow_exception.exception = exception + + default_exception_handler(sync_flow_exception) + log_mock.error.assert_has_calls( + [ + call( + "Cannot find any versions for layer %s.%s", + exception.layer_name_arn, + HELP_TEXT_FOR_SYNC_INFRA, + ) + ] + ) + + @patch("samcli.lib.sync.sync_flow_executor.LOG") + def test_default_exception_handler_missing_function_build_exception(self, log_mock): + sync_flow_exception = MagicMock(spec=SyncFlowException) + exception = MagicMock(spec=MissingFunctionBuildDefinition) + exception.function_logical_id = "function_logical_id" + sync_flow_exception.exception = exception + + default_exception_handler(sync_flow_exception) + log_mock.error.assert_has_calls( + [ + call( + "Cannot find build definition for function %s.%s", + exception.function_logical_id, + HELP_TEXT_FOR_SYNC_INFRA, + ) + ] + ) + + @patch("samcli.lib.sync.sync_flow_executor.LOG") + def test_default_exception_missing_local_definition(self, log_mock): + sync_flow_exception = MagicMock(spec=SyncFlowException) + exception = MagicMock(spec=MissingLocalDefinition) + exception.resource_identifier = "resource" + exception.property_name = "property" + sync_flow_exception.exception = exception + + default_exception_handler(sync_flow_exception) + log_mock.error.assert_has_calls( + [ + call( + "Resource %s does not have %s specified. 
Skipping the sync.%s", + exception.resource_identifier, + exception.property_name, + HELP_TEXT_FOR_SYNC_INFRA, + ) + ] + ) + + @patch("samcli.lib.sync.sync_flow_executor.LOG") + def test_default_exception_handler_invalid_runtime_exception(self, log_mock): + sync_flow_exception = MagicMock(spec=SyncFlowException) + exception = MagicMock(spec=InvalidRuntimeDefinitionForFunction) + exception.function_logical_id = "function_logical_id" + sync_flow_exception.exception = exception + + default_exception_handler(sync_flow_exception) + log_mock.error.assert_has_calls( + [ + call( + "No Runtime information found for function resource named %s", + exception.function_logical_id, + ) + ] + ) + + @patch("samcli.lib.sync.sync_flow_executor.LOG") + def test_default_exception_handler_client_error_invalid_code(self, log_mock): + sync_flow_exception = MagicMock(spec=SyncFlowException) + exception = ClientError({"Error": {"Code": "RandomException", "Message": "MessageContent"}}, "") + exception.resource_identifier = "Resource1" + sync_flow_exception.exception = exception + with self.assertRaises(ClientError): + default_exception_handler(sync_flow_exception) + + @patch("samcli.lib.sync.sync_flow_executor.LOG") + def test_default_exception_handler_client_error_invalid_exception(self, log_mock): + sync_flow_exception = MagicMock(spec=SyncFlowException) + + class RandomException(Exception): + pass + + exception = RandomException() + exception.resource_identifier = "Resource1" + sync_flow_exception.exception = exception + with self.assertRaises(RandomException): + default_exception_handler(sync_flow_exception) + + @patch("samcli.lib.sync.sync_flow_executor.time.time") + @patch("samcli.lib.sync.sync_flow_executor.SyncFlowTask") + def test_add_sync_flow(self, task_mock, time_mock): + add_sync_flow_task_mock = MagicMock() + task = MagicMock() + task_mock.return_value = task + time_mock.return_value = 1000 + self.executor._add_sync_flow_task = add_sync_flow_task_mock + sync_flow = MagicMock() + + self.executor.add_sync_flow(sync_flow, False) + + task_mock.assert_called_once_with(sync_flow, False) + add_sync_flow_task_mock.assert_called_once_with(task) + + def test_add_sync_flow_task(self): + sync_flow = MagicMock() + task = SyncFlowTask(sync_flow, False) + + self.executor._add_sync_flow_task(task) + + sync_flow.set_locks_with_distributor.assert_called_once_with(self.executor._lock_distributor) + + queue_task = self.executor._flow_queue.get() + self.assertEqual(sync_flow, queue_task.sync_flow) + + def test_add_sync_flow_task_dedup(self): + sync_flow = MagicMock() + + task1 = SyncFlowTask(sync_flow, True) + task2 = SyncFlowTask(sync_flow, True) + + self.executor._add_sync_flow_task(task1) + self.executor._add_sync_flow_task(task2) + + sync_flow.set_locks_with_distributor.assert_called_once_with(self.executor._lock_distributor) + + queue_task = self.executor._flow_queue.get() + self.assertEqual(sync_flow, queue_task.sync_flow) + self.assertTrue(self.executor._flow_queue.empty()) + + def test_is_running_without_manager(self): + self.executor._running_flag = True + self.assertTrue(self.executor.is_running()) + + @patch("samcli.lib.sync.sync_flow_executor.time.time") + @patch("samcli.lib.sync.sync_flow_executor.time.sleep") + def test_execute_high_level_logic(self, sleep_mock, time_mock): + exception_handler_mock = MagicMock() + time_mock.return_value = 1001 + + flow1 = MagicMock() + flow2 = MagicMock() + flow3 = MagicMock() + + task1 = SyncFlowTask(flow1, False) + task2 = SyncFlowTask(flow2, False) + task3 = 
SyncFlowTask(flow3, False) + + result1 = SyncFlowResult(flow1, [flow3]) + + future1 = MagicMock() + future2 = MagicMock() + future3 = MagicMock() + + exception1 = MagicMock(spec=Exception) + sync_flow_exception = MagicMock(spec=SyncFlowException) + sync_flow_exception.sync_flow = flow2 + sync_flow_exception.exception = exception1 + + future1.done.side_effect = [False, False, True] + future1.exception.return_value = None + future1.result.return_value = result1 + + future2.done.side_effect = [False, False, False, True] + future2.exception.return_value = sync_flow_exception + + future3.done.side_effect = [False, False, False, True] + future3.exception.return_value = None + + self.thread_pool_executor.submit = MagicMock() + self.thread_pool_executor.submit.side_effect = [future1, future2, future3] + + self.executor._flow_queue.put(task1) + self.executor._flow_queue.put(task2) + + self.executor.add_sync_flow = MagicMock() + self.executor.add_sync_flow.side_effect = lambda x: self.executor._flow_queue.put(task3) + + self.executor.execute(exception_handler=exception_handler_mock) + + self.thread_pool_executor.submit.assert_has_calls( + [ + call(SyncFlowExecutor._sync_flow_execute_wrapper, flow1), + call(SyncFlowExecutor._sync_flow_execute_wrapper, flow2), + call(SyncFlowExecutor._sync_flow_execute_wrapper, flow3), + ] + ) + self.executor.add_sync_flow.assert_called_once_with(flow3) + + exception_handler_mock.assert_called_once_with(sync_flow_exception) + self.assertEqual(len(sleep_mock.mock_calls), 6) diff --git a/tests/unit/lib/sync/test_sync_flow_factory.py b/tests/unit/lib/sync/test_sync_flow_factory.py new file mode 100644 index 0000000000..41545766bb --- /dev/null +++ b/tests/unit/lib/sync/test_sync_flow_factory.py @@ -0,0 +1,138 @@ +from unittest import TestCase +from unittest.mock import MagicMock, patch, Mock + +from samcli.lib.sync.sync_flow_factory import SyncFlowFactory + + +class TestSyncFlowFactory(TestCase): + def create_factory(self, auto_dependency_layer: bool = False): + factory = SyncFlowFactory( + build_context=MagicMock(), + deploy_context=MagicMock(), + stacks=[MagicMock(), MagicMock()], + auto_dependency_layer=auto_dependency_layer, + ) + return factory + + @patch("samcli.lib.sync.sync_flow_factory.get_physical_id_mapping") + @patch("samcli.lib.sync.sync_flow_factory.get_boto_resource_provider_with_config") + def test_load_physical_id_mapping(self, get_boto_resource_provider_mock, get_physical_id_mapping_mock): + get_physical_id_mapping_mock.return_value = {"Resource1": "PhysicalResource1", "Resource2": "PhysicalResource2"} + + factory = self.create_factory() + factory.load_physical_id_mapping() + + self.assertEqual(len(factory._physical_id_mapping), 2) + self.assertEqual( + factory._physical_id_mapping, {"Resource1": "PhysicalResource1", "Resource2": "PhysicalResource2"} + ) + + @patch("samcli.lib.sync.sync_flow_factory.ImageFunctionSyncFlow") + @patch("samcli.lib.sync.sync_flow_factory.ZipFunctionSyncFlow") + def test_create_lambda_flow_zip(self, zip_function_mock, image_function_mock): + factory = self.create_factory() + resource = {"Properties": {"PackageType": "Zip"}} + result = factory._create_lambda_flow("Function1", resource) + self.assertEqual(result, zip_function_mock.return_value) + + @patch("samcli.lib.sync.sync_flow_factory.ImageFunctionSyncFlow") + @patch("samcli.lib.sync.sync_flow_factory.ZipFunctionSyncFlow") + @patch("samcli.lib.sync.sync_flow_factory.AutoDependencyLayerParentSyncFlow") + def test_create_lambda_flow_zip_with_auto_dependency_layer( + self, 
auto_dependency_layer_mock, zip_function_mock, image_function_mock + ): + factory = self.create_factory(True) + resource = {"Properties": {"PackageType": "Zip", "Runtime": "python3.8"}} + result = factory._create_lambda_flow("Function1", resource) + self.assertEqual(result, auto_dependency_layer_mock.return_value) + + @patch("samcli.lib.sync.sync_flow_factory.ImageFunctionSyncFlow") + @patch("samcli.lib.sync.sync_flow_factory.ZipFunctionSyncFlow") + @patch("samcli.lib.sync.sync_flow_factory.AutoDependencyLayerParentSyncFlow") + def test_create_lambda_flow_zip_with_unsupported_runtime_auto_dependency_layer( + self, auto_dependency_layer_mock, zip_function_mock, image_function_mock + ): + factory = self.create_factory(True) + resource = {"Properties": {"PackageType": "Zip", "Runtime": "ruby2.7"}} + result = factory._create_lambda_flow("Function1", resource) + self.assertEqual(result, zip_function_mock.return_value) + + @patch("samcli.lib.sync.sync_flow_factory.ImageFunctionSyncFlow") + @patch("samcli.lib.sync.sync_flow_factory.ZipFunctionSyncFlow") + def test_create_lambda_flow_image(self, zip_function_mock, image_function_mock): + factory = self.create_factory() + resource = {"Properties": {"PackageType": "Image"}} + result = factory._create_lambda_flow("Function1", resource) + self.assertEqual(result, image_function_mock.return_value) + + @patch("samcli.lib.sync.sync_flow_factory.LayerSyncFlow") + def test_create_layer_flow(self, layer_sync_mock): + factory = self.create_factory() + result = factory._create_layer_flow("Layer1", {}) + self.assertEqual(result, layer_sync_mock.return_value) + + @patch("samcli.lib.sync.sync_flow_factory.ImageFunctionSyncFlow") + @patch("samcli.lib.sync.sync_flow_factory.ZipFunctionSyncFlow") + def test_create_lambda_flow_other(self, zip_function_mock, image_function_mock): + factory = self.create_factory() + resource = {"Properties": {"PackageType": "Other"}} + result = factory._create_lambda_flow("Function1", resource) + self.assertEqual(result, None) + + @patch("samcli.lib.sync.sync_flow_factory.RestApiSyncFlow") + def test_create_rest_api_flow(self, rest_api_sync_mock): + factory = self.create_factory() + result = factory._create_rest_api_flow("API1", {}) + self.assertEqual(result, rest_api_sync_mock.return_value) + + @patch("samcli.lib.sync.sync_flow_factory.HttpApiSyncFlow") + def test_create_api_flow(self, http_api_sync_mock): + factory = self.create_factory() + result = factory._create_api_flow("API1", {}) + self.assertEqual(result, http_api_sync_mock.return_value) + + @patch("samcli.lib.sync.sync_flow_factory.StepFunctionsSyncFlow") + def test_create_stepfunctions_flow(self, stepfunctions_sync_mock): + factory = self.create_factory() + result = factory._create_stepfunctions_flow("StateMachine1", {}) + self.assertEqual(result, stepfunctions_sync_mock.return_value) + + @patch("samcli.lib.sync.sync_flow_factory.get_resource_by_id") + def test_create_sync_flow(self, get_resource_by_id_mock): + factory = self.create_factory() + + sync_flow = MagicMock() + resource_identifier = MagicMock() + get_resource_by_id = MagicMock() + get_resource_by_id_mock.return_value = get_resource_by_id + generator_mock = MagicMock() + generator_mock.return_value = sync_flow + + get_generator_function_mock = MagicMock() + get_generator_function_mock.return_value = generator_mock + factory._get_generator_function = get_generator_function_mock + + result = factory.create_sync_flow(resource_identifier) + + self.assertEqual(result, sync_flow) + 
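+        # The factory should hand itself, the identifier, and the resolved resource dict straight to the generator.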
generator_mock.assert_called_once_with(factory, resource_identifier, get_resource_by_id) + + @patch("samcli.lib.sync.sync_flow_factory.get_resource_by_id") + def test_create_unknown_resource_sync_flow(self, get_resource_by_id_mock): + get_resource_by_id_mock.return_value = None + factory = self.create_factory() + self.assertIsNone(factory.create_sync_flow(MagicMock())) + + @patch("samcli.lib.sync.sync_flow_factory.get_resource_by_id") + def test_create_none_generator_sync_flow(self, get_resource_by_id_mock): + factory = self.create_factory() + + resource_identifier = MagicMock() + get_resource_by_id = MagicMock() + get_resource_by_id_mock.return_value = get_resource_by_id + + get_generator_function_mock = MagicMock() + get_generator_function_mock.return_value = None + factory._get_generator_function = get_generator_function_mock + + self.assertIsNone(factory.create_sync_flow(resource_identifier)) diff --git a/tests/unit/lib/sync/test_watch_manager.py b/tests/unit/lib/sync/test_watch_manager.py new file mode 100644 index 0000000000..d91a9d970d --- /dev/null +++ b/tests/unit/lib/sync/test_watch_manager.py @@ -0,0 +1,239 @@ +from unittest.case import TestCase +from unittest.mock import MagicMock, patch, ANY +from samcli.lib.sync.watch_manager import WatchManager +from samcli.lib.providers.exceptions import MissingCodeUri, MissingLocalDefinition +from samcli.lib.sync.exceptions import MissingPhysicalResourceError, SyncFlowException + + +class TestWatchManager(TestCase): + def setUp(self) -> None: + self.template = "template.yaml" + self.path_observer_patch = patch("samcli.lib.sync.watch_manager.HandlerObserver") + self.path_observer_mock = self.path_observer_patch.start() + self.path_observer = self.path_observer_mock.return_value + self.executor_patch = patch("samcli.lib.sync.watch_manager.ContinuousSyncFlowExecutor") + self.executor_mock = self.executor_patch.start() + self.executor = self.executor_mock.return_value + self.colored_patch = patch("samcli.lib.sync.watch_manager.Colored") + self.colored_mock = self.colored_patch.start() + self.colored = self.colored_mock.return_value + self.build_context = MagicMock() + self.package_context = MagicMock() + self.deploy_context = MagicMock() + self.watch_manager = WatchManager( + self.template, self.build_context, self.package_context, self.deploy_context, False + ) + + def tearDown(self) -> None: + self.path_observer_patch.stop() + self.executor_patch.stop() + self.colored_patch.stop() + + def test_queue_infra_sync(self): + self.assertFalse(self.watch_manager._waiting_infra_sync) + self.watch_manager.queue_infra_sync() + self.assertTrue(self.watch_manager._waiting_infra_sync) + + @patch("samcli.lib.sync.watch_manager.SamLocalStackProvider.get_stacks") + @patch("samcli.lib.sync.watch_manager.SyncFlowFactory") + @patch("samcli.lib.sync.watch_manager.CodeTriggerFactory") + def test_update_stacks( + self, trigger_factory_mock: MagicMock, sync_flow_factory_mock: MagicMock, get_stacks_mock: MagicMock + ): + stacks = [MagicMock()] + get_stacks_mock.return_value = [ + stacks, + ] + self.watch_manager._update_stacks() + get_stacks_mock.assert_called_once_with(self.template) + sync_flow_factory_mock.assert_called_once_with(self.build_context, self.deploy_context, stacks, False) + sync_flow_factory_mock.return_value.load_physical_id_mapping.assert_called_once_with() + trigger_factory_mock.assert_called_once_with(stacks) + + @patch("samcli.lib.sync.watch_manager.get_all_resource_ids") + def test_add_code_triggers(self, get_all_resource_ids_mock): + 
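+        # Of the five resources below, two yield triggers, one yields None, and two raise; only the two real triggers get scheduled.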
resource_ids = [MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock()] + get_all_resource_ids_mock.return_value = resource_ids + + trigger_1 = MagicMock() + trigger_2 = MagicMock() + + trigger_factory = MagicMock() + trigger_factory.create_trigger.side_effect = [ + trigger_1, + None, + MissingCodeUri(), + trigger_2, + MissingLocalDefinition(MagicMock(), MagicMock()), + ] + self.watch_manager._stacks = [MagicMock()] + self.watch_manager._trigger_factory = trigger_factory + + on_code_change_wrapper_mock = MagicMock() + self.watch_manager._on_code_change_wrapper = on_code_change_wrapper_mock + + self.watch_manager._add_code_triggers() + + trigger_factory.create_trigger.assert_any_call(resource_ids[0], on_code_change_wrapper_mock.return_value) + trigger_factory.create_trigger.assert_any_call(resource_ids[1], on_code_change_wrapper_mock.return_value) + + on_code_change_wrapper_mock.assert_any_call(resource_ids[0]) + on_code_change_wrapper_mock.assert_any_call(resource_ids[1]) + + self.path_observer.schedule_handlers.assert_any_call(trigger_1.get_path_handlers.return_value) + self.path_observer.schedule_handlers.assert_any_call(trigger_2.get_path_handlers.return_value) + self.assertEqual(self.path_observer.schedule_handlers.call_count, 2) + + @patch("samcli.lib.sync.watch_manager.TemplateTrigger") + def test_add_template_trigger(self, template_trigger_mock): + trigger = template_trigger_mock.return_value + + self.watch_manager._add_template_trigger() + + template_trigger_mock.assert_called_once_with(self.template, ANY) + self.path_observer.schedule_handlers.assert_any_call(trigger.get_path_handlers.return_value) + + def test_execute_infra_sync(self): + self.watch_manager._execute_infra_context() + self.build_context.set_up.assert_called_once_with() + self.build_context.run.assert_called_once_with() + self.package_context.run.assert_called_once_with() + self.deploy_context.run.assert_called_once_with() + + @patch("samcli.lib.sync.watch_manager.threading.Thread") + def test_start_code_sync(self, thread_mock): + self.watch_manager._start_code_sync() + thread = thread_mock.return_value + + self.assertEqual(self.watch_manager._executor_thread, thread) + thread.start.assert_called_once_with() + + def test_stop_code_sync(self): + thread = MagicMock() + thread.is_alive.return_value = True + self.watch_manager._executor_thread = thread + + self.watch_manager._stop_code_sync() + + self.executor.stop.assert_called_once_with() + thread.join.assert_called_once_with() + + def test_start(self): + queue_infra_sync_mock = MagicMock() + _start_mock = MagicMock() + stop_code_sync_mock = MagicMock() + + self.watch_manager.queue_infra_sync = queue_infra_sync_mock + self.watch_manager._start = _start_mock + self.watch_manager._stop_code_sync = stop_code_sync_mock + + _start_mock.side_effect = KeyboardInterrupt() + + self.watch_manager.start() + + self.path_observer.stop.assert_called_once_with() + stop_code_sync_mock.assert_called_once_with() + + @patch("samcli.lib.sync.watch_manager.time.sleep") + def test__start(self, sleep_mock): + sleep_mock.side_effect = KeyboardInterrupt() + + stop_code_sync_mock = MagicMock() + execute_infra_sync_mock = MagicMock() + + update_stacks_mock = MagicMock() + add_template_trigger_mock = MagicMock() + add_code_trigger_mock = MagicMock() + start_code_sync_mock = MagicMock() + + self.watch_manager._stop_code_sync = stop_code_sync_mock + self.watch_manager._execute_infra_context = execute_infra_sync_mock + self.watch_manager._update_stacks = update_stacks_mock + 
self.watch_manager._add_template_trigger = add_template_trigger_mock + self.watch_manager._add_code_triggers = add_code_trigger_mock + self.watch_manager._start_code_sync = start_code_sync_mock + + self.watch_manager._waiting_infra_sync = True + with self.assertRaises(KeyboardInterrupt): + self.watch_manager._start() + + self.path_observer.start.assert_called_once_with() + self.assertFalse(self.watch_manager._waiting_infra_sync) + + stop_code_sync_mock.assert_called_once_with() + execute_infra_sync_mock.assert_called_once_with() + update_stacks_mock.assert_called_once_with() + add_template_trigger_mock.assert_called_once_with() + add_code_trigger_mock.assert_called_once_with() + start_code_sync_mock.assert_called_once_with() + + self.path_observer.unschedule_all.assert_called_once_with() + + self.path_observer.start.assert_called_once_with() + + @patch("samcli.lib.sync.watch_manager.time.sleep") + def test__start_infra_exception(self, sleep_mock): + sleep_mock.side_effect = KeyboardInterrupt() + + stop_code_sync_mock = MagicMock() + execute_infra_sync_mock = MagicMock() + execute_infra_sync_mock.side_effect = Exception() + + update_stacks_mock = MagicMock() + add_template_trigger_mock = MagicMock() + add_code_trigger_mock = MagicMock() + start_code_sync_mock = MagicMock() + + self.watch_manager._stop_code_sync = stop_code_sync_mock + self.watch_manager._execute_infra_context = execute_infra_sync_mock + self.watch_manager._update_stacks = update_stacks_mock + self.watch_manager._add_template_trigger = add_template_trigger_mock + self.watch_manager._add_code_triggers = add_code_trigger_mock + self.watch_manager._start_code_sync = start_code_sync_mock + + self.watch_manager._waiting_infra_sync = True + with self.assertRaises(KeyboardInterrupt): + self.watch_manager._start() + + self.path_observer.start.assert_called_once_with() + self.assertFalse(self.watch_manager._waiting_infra_sync) + + stop_code_sync_mock.assert_called_once_with() + execute_infra_sync_mock.assert_called_once_with() + add_template_trigger_mock.assert_called_once_with() + + update_stacks_mock.assert_not_called() + add_code_trigger_mock.assert_not_called() + start_code_sync_mock.assert_not_called() + + self.path_observer.unschedule_all.assert_called_once_with() + + self.path_observer.start.assert_called_once_with() + + def test_on_code_change_wrapper(self): + flow1 = MagicMock() + resource_id_mock = MagicMock() + factory_mock = MagicMock() + + self.watch_manager._sync_flow_factory = factory_mock + factory_mock.create_sync_flow.return_value = flow1 + + callback = self.watch_manager._on_code_change_wrapper(resource_id_mock) + + callback() + + self.executor.add_delayed_sync_flow.assert_any_call(flow1, dedup=True, wait_time=ANY) + + def test_watch_sync_flow_exception_handler_missing_physical(self): + sync_flow = MagicMock() + sync_flow_exception = MagicMock(spec=SyncFlowException) + exception = MagicMock(spec=MissingPhysicalResourceError) + sync_flow_exception.exception = exception + sync_flow_exception.sync_flow = sync_flow + + queue_infra_sync_mock = MagicMock() + self.watch_manager.queue_infra_sync = queue_infra_sync_mock + + self.watch_manager._watch_sync_flow_exception_handler(sync_flow_exception) + + queue_infra_sync_mock.assert_called_once_with() diff --git a/tests/unit/lib/telemetry/test_metric.py b/tests/unit/lib/telemetry/test_metric.py index 3493b7bed4..414b769675 100644 --- a/tests/unit/lib/telemetry/test_metric.py +++ b/tests/unit/lib/telemetry/test_metric.py @@ -143,6 +143,7 @@ def setUp(self): 
self.context_mock.debug = False self.context_mock.region = "myregion" self.context_mock.command_path = "fakesam local invoke" + self.context_mock.experimental = False # Enable telemetry so we can actually run the tests self.gc_instance_mock.telemetry_enabled = True @@ -297,6 +298,7 @@ def real_fn(): @patch("samcli.lib.telemetry.metric.Context") def test_must_return_value_from_decorated_function(self, ContextMock): + ContextMock.get_current_context.return_value = self.context_mock expected_value = "some return value" def real_fn(): @@ -317,6 +319,8 @@ def real_fn(*args, **kwargs): @patch("samcli.lib.telemetry.metric.Context") def test_must_decorate_functions(self, ContextMock): + ContextMock.get_current_context.return_value = self.context_mock + @track_command def real_fn(a, b=None): return "{} {}".format(a, b) diff --git a/tests/unit/lib/utils/test_boto_utils.py b/tests/unit/lib/utils/test_boto_utils.py new file mode 100644 index 0000000000..2cd37352a7 --- /dev/null +++ b/tests/unit/lib/utils/test_boto_utils.py @@ -0,0 +1,85 @@ +from unittest import TestCase +from unittest.mock import patch, Mock + +from parameterized import parameterized + +from samcli.lib.utils.boto_utils import ( + get_boto_config_with_user_agent, + get_boto_client_provider_with_config, + get_boto_resource_provider_with_config, +) + +TEST_VERSION = "1.0.0" + + +class TestBotoUtils(TestCase): + @parameterized.expand([(True,), (False,)]) + @patch("samcli.lib.utils.boto_utils.GlobalConfig") + @patch("samcli.lib.utils.boto_utils.__version__", TEST_VERSION) + def test_get_boto_config_with_user_agent( + self, + telemetry_enabled, + patched_global_config, + ): + given_global_config_instance = Mock() + patched_global_config.return_value = given_global_config_instance + + given_global_config_instance.telemetry_enabled = telemetry_enabled + given_region_name = "us-west-2" + + config = get_boto_config_with_user_agent(region_name=given_region_name) + + self.assertEqual(given_region_name, config.region_name) + + if telemetry_enabled: + self.assertEqual( + config.user_agent_extra, f"aws-sam-cli/{TEST_VERSION}/{given_global_config_instance.installation_id}" + ) + else: + self.assertEqual(config.user_agent_extra, f"aws-sam-cli/{TEST_VERSION}") + + @patch("samcli.lib.utils.boto_utils.get_boto_config_with_user_agent") + @patch("samcli.lib.utils.boto_utils.boto3") + def test_get_boto_client_provider_with_config(self, patched_boto3, patched_get_config): + given_config = Mock() + patched_get_config.return_value = given_config + + given_config_param = Mock() + given_profile = Mock() + given_region = Mock() + client_generator = get_boto_client_provider_with_config( + region=given_region, profile=given_profile, param=given_config_param + ) + + given_service_client = Mock() + patched_boto3.session.Session().client.return_value = given_service_client + + client = client_generator("service") + + self.assertEqual(client, given_service_client) + patched_get_config.assert_called_with(param=given_config_param) + patched_boto3.session.Session.assert_called_with(region_name=given_region, profile_name=given_profile) + patched_boto3.session.Session().client.assert_called_with("service", config=given_config) + + @patch("samcli.lib.utils.boto_utils.get_boto_config_with_user_agent") + @patch("samcli.lib.utils.boto_utils.boto3") + def test_get_boto_resource_provider_with_config(self, patched_boto3, patched_get_config): + given_config = Mock() + patched_get_config.return_value = given_config + + given_config_param = Mock() + given_profile = Mock() + 
given_region = Mock() + client_generator = get_boto_resource_provider_with_config( + region=given_region, profile=given_profile, param=given_config_param + ) + + given_service_client = Mock() + patched_boto3.session.Session().resource.return_value = given_service_client + + client = client_generator("service") + + self.assertEqual(client, given_service_client) + patched_get_config.assert_called_with(param=given_config_param) + patched_boto3.session.Session.assert_called_with(region_name=given_region, profile_name=given_profile) + patched_boto3.session.Session().resource.assert_called_with("service", config=given_config) diff --git a/tests/unit/lib/utils/test_cloudformation.py b/tests/unit/lib/utils/test_cloudformation.py new file mode 100644 index 0000000000..f925e295ba --- /dev/null +++ b/tests/unit/lib/utils/test_cloudformation.py @@ -0,0 +1,119 @@ +from unittest import TestCase +from unittest.mock import patch, Mock, ANY + +from botocore.exceptions import ClientError + +from samcli.lib.utils.cloudformation import ( + CloudFormationResourceSummary, + get_physical_id_mapping, + get_resource_summaries, + get_resource_summary, +) + + +class TestCloudFormationResourceSummary(TestCase): + def test_cfn_resource_summary(self): + given_type = "type" + given_logical_id = "logical_id" + given_physical_id = "physical_id" + + resource_summary = CloudFormationResourceSummary(given_type, given_logical_id, given_physical_id) + + self.assertEqual(given_type, resource_summary.resource_type) + self.assertEqual(given_logical_id, resource_summary.logical_resource_id) + self.assertEqual(given_physical_id, resource_summary.physical_resource_id) + + +class TestCloudformationUtils(TestCase): + @patch("samcli.lib.utils.cloudformation.get_resource_summaries") + def test_get_physical_id_mapping(self, patched_get_resource_summaries): + patched_get_resource_summaries.return_value = [ + CloudFormationResourceSummary("", "Logical1", "Physical1"), + CloudFormationResourceSummary("", "Logical2", "Physical2"), + CloudFormationResourceSummary("", "Logical3", "Physical3"), + ] + + given_resource_provider = Mock() + given_resource_types = Mock() + given_stack_name = "stack_name" + physical_id_mapping = get_physical_id_mapping(given_resource_provider, given_stack_name, given_resource_types) + + self.assertEqual( + physical_id_mapping, + { + "Logical1": "Physical1", + "Logical2": "Physical2", + "Logical3": "Physical3", + }, + ) + + patched_get_resource_summaries.assert_called_with( + given_resource_provider, given_stack_name, given_resource_types + ) + + def test_get_resource_summaries(self): + resource_provider_mock = Mock() + given_stack_name = "stack_name" + given_resource_types = {"ResourceType0"} + + given_stack_resource_array = [ + Mock( + physical_resource_id="physical_id_1", logical_resource_id="logical_id_1", resource_type="ResourceType0" + ), + Mock( + physical_resource_id="physical_id_2", logical_resource_id="logical_id_2", resource_type="ResourceType0" + ), + Mock( + physical_resource_id="physical_id_3", logical_resource_id="logical_id_3", resource_type="ResourceType1" + ), + ] + + resource_provider_mock(ANY).Stack(ANY).resource_summaries.all.return_value = given_stack_resource_array + + resource_summaries = get_resource_summaries(resource_provider_mock, given_stack_name, given_resource_types) + + self.assertEqual(len(resource_summaries), 2) + self.assertEqual( + resource_summaries, + [ + CloudFormationResourceSummary("ResourceType0", "logical_id_1", "physical_id_1"), + 
CloudFormationResourceSummary("ResourceType0", "logical_id_2", "physical_id_2"), + ], + ) + + resource_provider_mock.assert_called_with("cloudformation") + resource_provider_mock(ANY).Stack.assert_called_with(given_stack_name) + resource_provider_mock(ANY).Stack(ANY).resource_summaries.all.assert_called_once() + + def test_get_resource_summary(self): + resource_provider_mock = Mock() + given_stack_name = "stack_name" + given_resource_logical_id = "logical_id_1" + + given_resource_type = "ResourceType0" + given_physical_id = "physical_id_1" + resource_provider_mock(ANY).StackResource.return_value = Mock( + physical_resource_id=given_physical_id, + logical_resource_id=given_resource_logical_id, + resource_type=given_resource_type, + ) + + resource_summary = get_resource_summary(resource_provider_mock, given_stack_name, given_resource_logical_id) + + self.assertEqual(resource_summary.resource_type, given_resource_type) + self.assertEqual(resource_summary.logical_resource_id, given_resource_logical_id) + self.assertEqual(resource_summary.physical_resource_id, given_physical_id) + + resource_provider_mock.assert_called_with("cloudformation") + resource_provider_mock(ANY).StackResource.assert_called_with(given_stack_name, given_resource_logical_id) + + def test_get_resource_summary_fail(self): + resource_provider_mock = Mock() + given_stack_name = "stack_name" + given_resource_logical_id = "logical_id_1" + + resource_provider_mock(ANY).StackResource.side_effect = ClientError({}, "operation") + + resource_summary = get_resource_summary(resource_provider_mock, given_stack_name, given_resource_logical_id) + + self.assertIsNone(resource_summary) diff --git a/tests/unit/lib/utils/test_code_trigger_factory.py b/tests/unit/lib/utils/test_code_trigger_factory.py new file mode 100644 index 0000000000..fc250aae85 --- /dev/null +++ b/tests/unit/lib/utils/test_code_trigger_factory.py @@ -0,0 +1,72 @@ +from parameterized import parameterized +from unittest.case import TestCase +from unittest.mock import MagicMock, patch, ANY +from samcli.lib.utils.code_trigger_factory import CodeTriggerFactory +from samcli.lib.providers.provider import ResourceIdentifier + + +class TestCodeTriggerFactory(TestCase): + def setUp(self): + self.stacks = [MagicMock(), MagicMock()] + self.factory = CodeTriggerFactory(self.stacks) + + @patch("samcli.lib.utils.code_trigger_factory.LambdaZipCodeTrigger") + def test_create_zip_function_trigger(self, trigger_mock): + on_code_change_mock = MagicMock() + resource_identifier = ResourceIdentifier("Function1") + resource = {"Properties": {"PackageType": "Zip"}} + result = self.factory._create_lambda_trigger(resource_identifier, "Type", resource, on_code_change_mock) + self.assertEqual(result, trigger_mock.return_value) + trigger_mock.assert_called_once_with(resource_identifier, self.stacks, on_code_change_mock) + + @patch("samcli.lib.utils.code_trigger_factory.LambdaImageCodeTrigger") + def test_create_image_function_trigger(self, trigger_mock): + on_code_change_mock = MagicMock() + resource_identifier = ResourceIdentifier("Function1") + resource = {"Properties": {"PackageType": "Image"}} + result = self.factory._create_lambda_trigger(resource_identifier, "Type", resource, on_code_change_mock) + self.assertEqual(result, trigger_mock.return_value) + trigger_mock.assert_called_once_with(resource_identifier, self.stacks, on_code_change_mock) + + @patch("samcli.lib.utils.code_trigger_factory.LambdaLayerCodeTrigger") + def test_create_layer_trigger(self, trigger_mock): + on_code_change_mock = 
MagicMock() + resource_identifier = ResourceIdentifier("Layer1") + result = self.factory._create_layer_trigger(resource_identifier, "Type", {}, on_code_change_mock) + self.assertEqual(result, trigger_mock.return_value) + trigger_mock.assert_called_once_with(resource_identifier, self.stacks, on_code_change_mock) + + @patch("samcli.lib.utils.code_trigger_factory.DefinitionCodeTrigger") + def test_create_definition_trigger(self, trigger_mock): + on_code_change_mock = MagicMock() + resource_identifier = ResourceIdentifier("API1") + resource_type = "AWS::Serverless::Api" + result = self.factory._create_definition_code_trigger( + resource_identifier, resource_type, {}, on_code_change_mock + ) + self.assertEqual(result, trigger_mock.return_value) + trigger_mock.assert_called_once_with(resource_identifier, resource_type, self.stacks, on_code_change_mock) + + @patch("samcli.lib.utils.code_trigger_factory.get_resource_by_id") + @patch("samcli.lib.utils.resource_type_based_factory.get_resource_by_id") + def test_create_trigger(self, get_resource_by_id_mock, parent_get_resource_by_id_mock): + code_trigger = MagicMock() + resource_identifier = MagicMock() + get_resource_by_id = {"Type": "AWS::Serverless::Api"} + get_resource_by_id_mock.return_value = get_resource_by_id + parent_get_resource_by_id_mock.return_value = get_resource_by_id + generator_mock = MagicMock() + generator_mock.return_value = code_trigger + + on_code_change_mock = MagicMock() + + get_generator_function_mock = MagicMock() + get_generator_function_mock.return_value = generator_mock + self.factory._get_generator_function = get_generator_function_mock + + result = self.factory.create_trigger(resource_identifier, on_code_change_mock) + + self.assertEqual(result, code_trigger) + generator_mock.assert_called_once_with( + self.factory, resource_identifier, "AWS::Serverless::Api", get_resource_by_id, on_code_change_mock + ) diff --git a/tests/unit/lib/utils/test_definition_validator.py b/tests/unit/lib/utils/test_definition_validator.py new file mode 100644 index 0000000000..726c33307b --- /dev/null +++ b/tests/unit/lib/utils/test_definition_validator.py @@ -0,0 +1,61 @@ +from parameterized import parameterized +from unittest.case import TestCase +from unittest.mock import MagicMock, patch, ANY +from samcli.lib.utils.definition_validator import DefinitionValidator + + +class TestDefinitionValidator(TestCase): + def setUp(self) -> None: + self.path = MagicMock() + + @patch("samcli.lib.utils.definition_validator.parse_yaml_file") + def test_invalid_path(self, parse_yaml_file_mock): + parse_yaml_file_mock.side_effect = [{"A": 1}, {"A": 1}] + self.path.exists.return_value = False + + validator = DefinitionValidator(self.path, detect_change=False, initialize_data=False) + self.assertFalse(validator.validate()) + self.assertFalse(validator.validate()) + + @patch("samcli.lib.utils.definition_validator.parse_yaml_file") + def test_no_detect_change_valid(self, parse_yaml_file_mock): + parse_yaml_file_mock.side_effect = [{"A": 1}, {"A": 1}] + + validator = DefinitionValidator(self.path, detect_change=False, initialize_data=False) + self.assertTrue(validator.validate()) + self.assertTrue(validator.validate()) + + @patch("samcli.lib.utils.definition_validator.parse_yaml_file") + def test_no_detect_change_invalid(self, parse_yaml_file_mock): + parse_yaml_file_mock.side_effect = [ValueError(), {"A": 1}] + + validator = DefinitionValidator(self.path, detect_change=False, initialize_data=False) + self.assertFalse(validator.validate()) + 
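+        # A ValueError from parsing fails validation once; the next well-formed document validates again.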
self.assertTrue(validator.validate()) + + @patch("samcli.lib.utils.definition_validator.parse_yaml_file") + def test_detect_change_valid(self, parse_yaml_file_mock): + parse_yaml_file_mock.side_effect = [{"A": 1}, {"B": 1}] + + validator = DefinitionValidator(self.path, detect_change=True, initialize_data=False) + self.assertTrue(validator.validate()) + self.assertTrue(validator.validate()) + + @patch("samcli.lib.utils.definition_validator.parse_yaml_file") + def test_detect_change_invalid(self, parse_yaml_file_mock): + parse_yaml_file_mock.side_effect = [{"A": 1}, {"A": 1}, ValueError(), {"B": 1}] + + validator = DefinitionValidator(self.path, detect_change=True, initialize_data=False) + self.assertTrue(validator.validate()) + self.assertFalse(validator.validate()) + self.assertFalse(validator.validate()) + self.assertTrue(validator.validate()) + + @patch("samcli.lib.utils.definition_validator.parse_yaml_file") + def test_detect_change_initialize(self, parse_yaml_file_mock): + parse_yaml_file_mock.side_effect = [{"A": 1}, {"A": 1}, ValueError(), {"B": 1}] + + validator = DefinitionValidator(self.path, detect_change=True, initialize_data=True) + self.assertFalse(validator.validate()) + self.assertFalse(validator.validate()) + self.assertTrue(validator.validate()) diff --git a/tests/unit/lib/utils/test_handler_observer.py b/tests/unit/lib/utils/test_handler_observer.py new file mode 100644 index 0000000000..5d7041b6db --- /dev/null +++ b/tests/unit/lib/utils/test_handler_observer.py @@ -0,0 +1,145 @@ +import re +from unittest.case import TestCase +from unittest.mock import MagicMock, patch, ANY +from samcli.lib.utils.path_observer import HandlerObserver, PathHandler, StaticFolderWrapper + + +class TestPathHandler(TestCase): + def test_init(self): + handler_mock = MagicMock() + path_mock = MagicMock() + create_mock = MagicMock() + delete_mock = MagicMock() + bundle = PathHandler(handler_mock, path_mock, True, True, create_mock, delete_mock) + + self.assertEqual(bundle.event_handler, handler_mock) + self.assertEqual(bundle.path, path_mock) + self.assertEqual(bundle.self_create, create_mock) + self.assertEqual(bundle.self_delete, delete_mock) + self.assertTrue(bundle.recursive) + self.assertTrue(bundle.static_folder) + + +class TestStaticFolderWrapper(TestCase): + def setUp(self) -> None: + self.observer = MagicMock() + self.path_handler = MagicMock() + self.initial_watch = MagicMock() + self.wrapper = StaticFolderWrapper(self.observer, self.initial_watch, self.path_handler) + + def test_on_parent_change_on_delete(self): + watch_mock = MagicMock() + self.wrapper._watch = watch_mock + self.wrapper._path_handler.path.exists.return_value = False + + self.wrapper._on_parent_change(MagicMock()) + + self.path_handler.self_delete.assert_called_once_with() + self.observer.unschedule.assert_called_once_with(watch_mock) + self.assertIsNone(self.wrapper._watch) + + def test_on_parent_change_on_create(self): + watch_mock = MagicMock() + self.observer.schedule_handler.return_value = watch_mock + + self.wrapper._watch = None + self.wrapper._path_handler.path.exists.return_value = True + + self.wrapper._on_parent_change(MagicMock()) + + self.path_handler.self_create.assert_called_once_with() + self.observer.schedule_handler.assert_called_once_with(self.wrapper._path_handler) + self.assertEqual(self.wrapper._watch, watch_mock) + + @patch("samcli.lib.utils.path_observer.RegexMatchingEventHandler") + @patch("samcli.lib.utils.path_observer.PathHandler") + def test_get_dir_parent_path_handler(self, 
path_handler_mock, event_handler_mock): + path_mock = MagicMock() + path_mock.resolve.return_value.parent = "/parent/" + path_mock.resolve.return_value.__str__.return_value = "/parent/dir/" + self.path_handler.path = path_mock + + event_handler = MagicMock() + event_handler_mock.return_value = event_handler + path_handler = MagicMock() + path_handler_mock.return_value = path_handler + result = self.wrapper.get_dir_parent_path_handler() + + self.assertEqual(result, path_handler) + path_handler_mock.assert_called_once_with(path="/parent/", event_handler=event_handler) + escaped_path = re.escape("/parent/dir/") + event_handler_mock.assert_called_once_with( + regexes=[f"^{escaped_path}$"], ignore_regexes=[], ignore_directories=False, case_sensitive=True + ) + + +class TestHandlerObserver(TestCase): + def setUp(self) -> None: + self.observer = HandlerObserver() + + def test_schedule_handlers(self): + bundle_1 = MagicMock() + bundle_2 = MagicMock() + watch_1 = MagicMock() + watch_2 = MagicMock() + + schedule_handler_mock = MagicMock() + schedule_handler_mock.side_effect = [watch_1, watch_2] + self.observer.schedule_handler = schedule_handler_mock + result = self.observer.schedule_handlers([bundle_1, bundle_2]) + self.assertEqual(result, [watch_1, watch_2]) + schedule_handler_mock.assert_any_call(bundle_1) + schedule_handler_mock.assert_any_call(bundle_2) + + @patch("samcli.lib.utils.path_observer.StaticFolderWrapper") + def test_schedule_handler_not_static(self, wrapper_mock: MagicMock): + bundle = MagicMock() + event_handler = MagicMock() + bundle.event_handler = event_handler + bundle.path = "dir" + bundle.recursive = True + bundle.static_folder = False + watch = MagicMock() + + schedule_mock = MagicMock() + schedule_mock.return_value = watch + self.observer.schedule = schedule_mock + + result = self.observer.schedule_handler(bundle) + + self.assertEqual(result, watch) + schedule_mock.assert_any_call(bundle.event_handler, "dir", True) + wrapper_mock.assert_not_called() + + @patch("samcli.lib.utils.path_observer.StaticFolderWrapper") + def test_schedule_handler_static(self, wrapper_mock: MagicMock): + bundle = MagicMock() + event_handler = MagicMock() + bundle.event_handler = event_handler + bundle.path = "dir" + bundle.recursive = True + bundle.static_folder = True + watch = MagicMock() + + parent_bundle = MagicMock() + event_handler = MagicMock() + parent_bundle.event_handler = event_handler + parent_bundle.path = "parent" + parent_bundle.recursive = False + parent_bundle.static_folder = False + parent_watch = MagicMock() + + schedule_mock = MagicMock() + schedule_mock.side_effect = [watch, parent_watch] + self.observer.schedule = schedule_mock + + wrapper = MagicMock() + wrapper_mock.return_value = wrapper + wrapper.get_dir_parent_path_handler.return_value = parent_bundle + + result = self.observer.schedule_handler(bundle) + + self.assertEqual(result, parent_watch) + schedule_mock.assert_any_call(bundle.event_handler, "dir", True) + schedule_mock.assert_any_call(parent_bundle.event_handler, "parent", False) + wrapper_mock.assert_called_once_with(self.observer, watch, bundle) diff --git a/tests/unit/lib/utils/test_hash.py b/tests/unit/lib/utils/test_hash.py index 388b3c96da..6aa29e6da6 100644 --- a/tests/unit/lib/utils/test_hash.py +++ b/tests/unit/lib/utils/test_hash.py @@ -1,3 +1,4 @@ +import hashlib import os import shutil import tempfile @@ -94,6 +95,16 @@ def test_dir_hash_with_ignore_list(self): checksum_after_with_ignore_list = dir_checksum(os.path.dirname(_file.name), 
ignore_list=[".aws-sam"]) + self.assertEqual(checksum_before, checksum_after_with_ignore_list) + + def test_hashing_method(self): + _file = tempfile.NamedTemporaryFile(delete=False, dir=self.temp_dir) + _file.write(b"Testfile") + _file.close() + checksum_sha256 = dir_checksum(os.path.dirname(_file.name), hash_generator=hashlib.sha256()) + checksum_md5 = dir_checksum(os.path.dirname(_file.name), hash_generator=hashlib.md5()) + checksum_default = dir_checksum(os.path.dirname(_file.name)) + self.assertEqual(checksum_default, checksum_md5) + self.assertNotEqual(checksum_md5, checksum_sha256) + def test_dir_cyclic_links(self): + _file = tempfile.NamedTemporaryFile(delete=False, dir=self.temp_dir) + _file.write(b"Testfile") diff --git a/tests/unit/lib/utils/test_lock_distributor.py b/tests/unit/lib/utils/test_lock_distributor.py new file mode 100644 index 0000000000..f57ba4e1ed --- /dev/null +++ b/tests/unit/lib/utils/test_lock_distributor.py @@ -0,0 +1,103 @@ +from unittest import TestCase +from unittest.mock import MagicMock, call, patch +from samcli.lib.utils.lock_distributor import LockChain, LockDistributor, LockDistributorType + + +class TestLockChain(TestCase): + def test_acquire_order(self): + locks = {"A": MagicMock(), "B": MagicMock(), "C": MagicMock()} + call_mock = MagicMock() + call_mock.a = locks["A"] + call_mock.b = locks["B"] + call_mock.c = locks["C"] + lock_chain = LockChain(locks) + lock_chain.acquire() + call_mock.assert_has_calls([call.a.acquire(), call.b.acquire(), call.c.acquire()]) + + def test_acquire_order_shuffled(self): + locks = {"A": MagicMock(), "C": MagicMock(), "B": MagicMock()} + call_mock = MagicMock() + call_mock.a = locks["A"] + call_mock.b = locks["B"] + call_mock.c = locks["C"] + lock_chain = LockChain(locks) + lock_chain.acquire() + call_mock.assert_has_calls([call.a.acquire(), call.b.acquire(), call.c.acquire()]) + + def test_release_order(self): + locks = {"A": MagicMock(), "B": MagicMock(), "C": MagicMock()} + call_mock = MagicMock() + call_mock.a = locks["A"] + call_mock.b = locks["B"] + call_mock.c = locks["C"] + lock_chain = LockChain(locks) + lock_chain.release() + call_mock.assert_has_calls([call.a.release(), call.b.release(), call.c.release()]) + + def test_release_order_shuffled(self): + locks = {"A": MagicMock(), "C": MagicMock(), "B": MagicMock()} + call_mock = MagicMock() + call_mock.a = locks["A"] + call_mock.b = locks["B"] + call_mock.c = locks["C"] + lock_chain = LockChain(locks) + lock_chain.release() + call_mock.assert_has_calls([call.a.release(), call.b.release(), call.c.release()]) + + def test_with(self): + locks = {"A": MagicMock(), "C": MagicMock(), "B": MagicMock()} + call_mock = MagicMock() + call_mock.a = locks["A"] + call_mock.b = locks["B"] + call_mock.c = locks["C"] + with LockChain(locks) as _: + call_mock.assert_has_calls([call.a.acquire(), call.b.acquire(), call.c.acquire()]) + call_mock.assert_has_calls( + [call.a.acquire(), call.b.acquire(), call.c.acquire(), call.a.release(), call.b.release(), call.c.release()] + ) + + +class TestLockDistributor(TestCase): + @patch("samcli.lib.utils.lock_distributor.threading.Lock") + @patch("samcli.lib.utils.lock_distributor.multiprocessing.Lock") + def test_thread_get_locks(self, process_lock_mock, thread_lock_mock): + locks = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] + thread_lock_mock.side_effect = locks + distributor = LockDistributor(LockDistributorType.THREAD, None) + keys = ["A", "B", "C"] + result = distributor.get_locks(keys) + + self.assertEqual(result["A"], locks[1]) +
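+        # locks[0] is presumably consumed by the distributor's internal bookkeeping lock, so keys map to locks[1:].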
self.assertEqual(result["B"], locks[2]) + self.assertEqual(result["C"], locks[3]) + self.assertEqual(distributor.get_locks(keys)["A"], locks[1]) + + @patch("samcli.lib.utils.lock_distributor.threading.Lock") + @patch("samcli.lib.utils.lock_distributor.multiprocessing.Lock") + def test_process_get_locks(self, process_lock_mock, thread_lock_mock): + locks = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] + process_lock_mock.side_effect = locks + distributor = LockDistributor(LockDistributorType.PROCESS, None) + keys = ["A", "B", "C"] + result = distributor.get_locks(keys) + + self.assertEqual(result["A"], locks[1]) + self.assertEqual(result["B"], locks[2]) + self.assertEqual(result["C"], locks[3]) + self.assertEqual(distributor.get_locks(keys)["A"], locks[1]) + + @patch("samcli.lib.utils.lock_distributor.threading.Lock") + @patch("samcli.lib.utils.lock_distributor.multiprocessing.Lock") + def test_process_manager_get_locks(self, process_lock_mock, thread_lock_mock): + manager_mock = MagicMock() + locks = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] + manager_mock.dict.return_value = dict() + manager_mock.Lock.side_effect = locks + distributor = LockDistributor(LockDistributorType.PROCESS, manager_mock) + keys = ["A", "B", "C"] + result = distributor.get_locks(keys) + + self.assertEqual(result["A"], locks[1]) + self.assertEqual(result["B"], locks[2]) + self.assertEqual(result["C"], locks[3]) + self.assertEqual(distributor.get_locks(keys)["A"], locks[1]) diff --git a/tests/unit/lib/utils/test_resource_trigger.py b/tests/unit/lib/utils/test_resource_trigger.py new file mode 100644 index 0000000000..8feff30b71 --- /dev/null +++ b/tests/unit/lib/utils/test_resource_trigger.py @@ -0,0 +1,258 @@ +import re +from parameterized import parameterized +from unittest.case import TestCase +from unittest.mock import MagicMock, patch, ANY +from samcli.lib.utils.resource_trigger import ( + CodeResourceTrigger, + DefinitionCodeTrigger, + LambdaFunctionCodeTrigger, + LambdaImageCodeTrigger, + LambdaLayerCodeTrigger, + LambdaZipCodeTrigger, + ResourceTrigger, + TemplateTrigger, +) +from samcli.local.lambdafn.exceptions import FunctionNotFound, ResourceNotFound +from samcli.lib.providers.exceptions import MissingLocalDefinition +from samcli.lib.providers.provider import ResourceIdentifier + + +class TestResourceTrigger(TestCase): + @patch("samcli.lib.utils.resource_trigger.PathHandler") + @patch("samcli.lib.utils.resource_trigger.RegexMatchingEventHandler") + @patch("samcli.lib.utils.resource_trigger.Path") + def test_single_file_path_handler(self, path_mock, handler_mock, bundle_mock): + path = MagicMock() + path_mock.return_value = path + file_path = MagicMock() + file_path.__str__.return_value = "/parent/file" + + parent_path = MagicMock() + parent_path.__str__.return_value = "/parent/" + + file_path.parent = parent_path + + path.resolve.return_value = file_path + + ResourceTrigger.get_single_file_path_handler("/parent/file") + + path_mock.assert_called_once_with("/parent/file") + escaped_path = re.escape("/parent/file") + handler_mock.assert_called_once_with( + regexes=[f"^{escaped_path}$"], ignore_regexes=[], ignore_directories=True, case_sensitive=True + ) + bundle_mock.assert_called_once_with(path=parent_path, event_handler=handler_mock.return_value, recursive=False) + + @patch("samcli.lib.utils.resource_trigger.PathHandler") + @patch("samcli.lib.utils.resource_trigger.PatternMatchingEventHandler") + @patch("samcli.lib.utils.resource_trigger.Path") + def test_dir_path_handler(self, path_mock, 
diff --git a/tests/unit/lib/utils/test_resource_trigger.py b/tests/unit/lib/utils/test_resource_trigger.py
new file mode 100644
index 0000000000..8feff30b71
--- /dev/null
+++ b/tests/unit/lib/utils/test_resource_trigger.py
@@ -0,0 +1,258 @@
+import re
+from parameterized import parameterized
+from unittest.case import TestCase
+from unittest.mock import MagicMock, patch, ANY
+from samcli.lib.utils.resource_trigger import (
+    CodeResourceTrigger,
+    DefinitionCodeTrigger,
+    LambdaFunctionCodeTrigger,
+    LambdaImageCodeTrigger,
+    LambdaLayerCodeTrigger,
+    LambdaZipCodeTrigger,
+    ResourceTrigger,
+    TemplateTrigger,
+)
+from samcli.local.lambdafn.exceptions import FunctionNotFound, ResourceNotFound
+from samcli.lib.providers.exceptions import MissingLocalDefinition
+from samcli.lib.providers.provider import ResourceIdentifier
+
+
+class TestResourceTrigger(TestCase):
+    @patch("samcli.lib.utils.resource_trigger.PathHandler")
+    @patch("samcli.lib.utils.resource_trigger.RegexMatchingEventHandler")
+    @patch("samcli.lib.utils.resource_trigger.Path")
+    def test_single_file_path_handler(self, path_mock, handler_mock, bundle_mock):
+        path = MagicMock()
+        path_mock.return_value = path
+        file_path = MagicMock()
+        file_path.__str__.return_value = "/parent/file"
+
+        parent_path = MagicMock()
+        parent_path.__str__.return_value = "/parent/"
+
+        file_path.parent = parent_path
+
+        path.resolve.return_value = file_path
+
+        ResourceTrigger.get_single_file_path_handler("/parent/file")
+
+        path_mock.assert_called_once_with("/parent/file")
+        escaped_path = re.escape("/parent/file")
+        handler_mock.assert_called_once_with(
+            regexes=[f"^{escaped_path}$"], ignore_regexes=[], ignore_directories=True, case_sensitive=True
+        )
+        bundle_mock.assert_called_once_with(path=parent_path, event_handler=handler_mock.return_value, recursive=False)
+
+    @patch("samcli.lib.utils.resource_trigger.PathHandler")
+    @patch("samcli.lib.utils.resource_trigger.PatternMatchingEventHandler")
+    @patch("samcli.lib.utils.resource_trigger.Path")
+    def test_dir_path_handler(self, path_mock, handler_mock, bundle_mock):
+        path = MagicMock()
+        path_mock.return_value = path
+        folder_path = MagicMock()
+
+        path.resolve.return_value = folder_path
+
+        ResourceTrigger.get_dir_path_handler("/parent/folder/")
+
+        path_mock.assert_called_once_with("/parent/folder/")
+        handler_mock.assert_called_once_with(
+            patterns=["*"], ignore_patterns=[], ignore_directories=False, case_sensitive=True
+        )
+        bundle_mock.assert_called_once_with(
+            path=folder_path, event_handler=handler_mock.return_value, recursive=True, static_folder=True
+        )
+
+
+class TestTemplateTrigger(TestCase):
+    @patch("samcli.lib.utils.resource_trigger.DefinitionValidator")
+    @patch("samcli.lib.utils.resource_trigger.Path")
+    @patch("samcli.lib.utils.resource_trigger.ResourceTrigger.get_single_file_path_handler")
+    def test_get_path_handler(self, single_file_handler_mock, path_mock, validator_mock):
+        trigger = TemplateTrigger("template.yaml", MagicMock())
+        result = trigger.get_path_handlers()
+        self.assertEqual(result, [single_file_handler_mock.return_value])
+        self.assertEqual(single_file_handler_mock.return_value.event_handler.on_any_event, trigger._validator_wrapper)
+
+    @patch("samcli.lib.utils.resource_trigger.DefinitionValidator")
+    @patch("samcli.lib.utils.resource_trigger.Path")
+    def test_validator_wrapper(self, path_mock, validator_mock):
+        on_template_change_mock = MagicMock()
+        event_mock = MagicMock()
+        validator_mock.return_value.validate.return_value = True
+        trigger = TemplateTrigger("template.yaml", on_template_change_mock)
+        trigger._validator_wrapper(event_mock)
+        on_template_change_mock.assert_called_once_with(event_mock)
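For readers unfamiliar with the watchdog wiring these assertions encode, this is roughly what get_single_file_path_handler appears to assemble — a sketch built only from the mocked calls above, not a copy of the implementation: an exact-match regex handler observing the file's resolved parent directory, non-recursively.

```python
import re
from pathlib import Path

from watchdog.events import RegexMatchingEventHandler
from watchdog.observers import Observer

# Resolve the template to an absolute path and match only that one file.
file_path = Path("template.yaml").resolve()
handler = RegexMatchingEventHandler(
    regexes=[f"^{re.escape(str(file_path))}$"],
    ignore_regexes=[],
    ignore_directories=True,
    case_sensitive=True,
)
# The tests above assign the trigger's validator wrapper to on_any_event;
# here a plain callback stands in for it.
handler.on_any_event = lambda event: print(f"template changed: {event.src_path}")

# PathHandler bundles path + handler; a watch manager would later schedule
# them on an observer like this and start()/stop() it.
observer = Observer()
observer.schedule(handler, str(file_path.parent), recursive=False)
```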
@patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def test_init_invalid(self, get_resource_by_id_mock, function_provider_mock): + stacks = [MagicMock(), MagicMock()] + on_code_change_mock = MagicMock() + function_provider_mock.return_value.get.return_value = None + + code_uri_mock = MagicMock() + LambdaFunctionCodeTrigger._get_code_uri = code_uri_mock + + with self.assertRaises(FunctionNotFound): + LambdaFunctionCodeTrigger(ResourceIdentifier("A"), stacks, on_code_change_mock) + + @patch.multiple(LambdaFunctionCodeTrigger, __abstractmethods__=set()) + @patch("samcli.lib.utils.resource_trigger.ResourceTrigger.get_dir_path_handler") + @patch("samcli.lib.utils.resource_trigger.SamFunctionProvider") + @patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def test_get_path_handlers(self, get_resource_by_id_mock, function_provider_mock, get_dir_path_handler_mock): + stacks = [MagicMock(), MagicMock()] + on_code_change_mock = MagicMock() + function_mock = function_provider_mock.return_value.get.return_value + + code_uri_mock = MagicMock() + LambdaFunctionCodeTrigger._get_code_uri = code_uri_mock + + bundle = MagicMock() + get_dir_path_handler_mock.return_value = bundle + + trigger = LambdaFunctionCodeTrigger(ResourceIdentifier("A"), stacks, on_code_change_mock) + result = trigger.get_path_handlers() + + self.assertEqual(result, [bundle]) + self.assertEqual(bundle.self_create, on_code_change_mock) + self.assertEqual(bundle.self_delete, on_code_change_mock) + self.assertEqual(bundle.event_handler.on_any_event, on_code_change_mock) + + +class TestLambdaZipCodeTrigger(TestCase): + @patch("samcli.lib.utils.resource_trigger.SamFunctionProvider") + @patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def test_get_code_uri(self, get_resource_by_id_mock, function_provider_mock): + stacks = [MagicMock(), MagicMock()] + on_code_change_mock = MagicMock() + function_mock = function_provider_mock.return_value.get.return_value + trigger = LambdaZipCodeTrigger(ResourceIdentifier("A"), stacks, on_code_change_mock) + result = trigger._get_code_uri() + self.assertEqual(result, function_mock.codeuri) + + +class TestLambdaImageCodeTrigger(TestCase): + @patch("samcli.lib.utils.resource_trigger.SamFunctionProvider") + @patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def test_get_code_uri(self, get_resource_by_id_mock, function_provider_mock): + stacks = [MagicMock(), MagicMock()] + on_code_change_mock = MagicMock() + function_mock = function_provider_mock.return_value.get.return_value + trigger = LambdaImageCodeTrigger(ResourceIdentifier("A"), stacks, on_code_change_mock) + result = trigger._get_code_uri() + self.assertEqual(result, function_mock.metadata.get.return_value) + + +class TestLambdaLayerCodeTrigger(TestCase): + @patch("samcli.lib.utils.resource_trigger.SamLayerProvider") + @patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def test_init(self, get_resource_by_id_mock, layer_provider_mock): + stacks = [MagicMock(), MagicMock()] + on_code_change_mock = MagicMock() + layer_mock = layer_provider_mock.return_value.get.return_value + + trigger = LambdaLayerCodeTrigger(ResourceIdentifier("A"), stacks, on_code_change_mock) + self.assertEqual(trigger._layer, layer_mock) + self.assertEqual(trigger._code_uri, layer_mock.codeuri) + + @patch("samcli.lib.utils.resource_trigger.ResourceTrigger.get_dir_path_handler") + @patch("samcli.lib.utils.resource_trigger.SamLayerProvider") + @patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def 
+
+
+class TestLambdaLayerCodeTrigger(TestCase):
+    @patch("samcli.lib.utils.resource_trigger.SamLayerProvider")
+    @patch("samcli.lib.utils.resource_trigger.get_resource_by_id")
+    def test_init(self, get_resource_by_id_mock, layer_provider_mock):
+        stacks = [MagicMock(), MagicMock()]
+        on_code_change_mock = MagicMock()
+        layer_mock = layer_provider_mock.return_value.get.return_value
+
+        trigger = LambdaLayerCodeTrigger(ResourceIdentifier("A"), stacks, on_code_change_mock)
+        self.assertEqual(trigger._layer, layer_mock)
+        self.assertEqual(trigger._code_uri, layer_mock.codeuri)
+
+    @patch("samcli.lib.utils.resource_trigger.ResourceTrigger.get_dir_path_handler")
+    @patch("samcli.lib.utils.resource_trigger.SamLayerProvider")
+    @patch("samcli.lib.utils.resource_trigger.get_resource_by_id")
+    def test_get_path_handlers(self, get_resource_by_id_mock, layer_provider_mock, get_dir_path_handler_mock):
+        stacks = [MagicMock(), MagicMock()]
+        on_code_change_mock = MagicMock()
+        layer_mock = layer_provider_mock.return_value.get.return_value
+
+        bundle = MagicMock()
+        get_dir_path_handler_mock.return_value = bundle
+
+        trigger = LambdaLayerCodeTrigger(ResourceIdentifier("A"), stacks, on_code_change_mock)
+        result = trigger.get_path_handlers()
+
+        self.assertEqual(result, [bundle])
+        self.assertEqual(bundle.self_create, on_code_change_mock)
+        self.assertEqual(bundle.self_delete, on_code_change_mock)
+        self.assertEqual(bundle.event_handler.on_any_event, on_code_change_mock)
+
+
+class TestDefinitionCodeTrigger(TestCase):
+    @patch("samcli.lib.utils.resource_trigger.DefinitionValidator")
+    @patch("samcli.lib.utils.resource_trigger.Path")
+    @patch("samcli.lib.utils.resource_trigger.ResourceTrigger.get_single_file_path_handler")
+    @patch("samcli.lib.utils.resource_trigger.get_resource_by_id")
+    def test_get_path_handler(self, get_resource_by_id_mock, single_file_handler_mock, path_mock, validator_mock):
+        stacks = [MagicMock(), MagicMock()]
+        resource = {"Properties": {"DefinitionUri": "abc"}}
+        get_resource_by_id_mock.return_value = resource
+        trigger = DefinitionCodeTrigger("TestApi", "AWS::Serverless::Api", stacks, MagicMock())
+        result = trigger.get_path_handlers()
+        self.assertEqual(result, [single_file_handler_mock.return_value])
+        self.assertEqual(single_file_handler_mock.return_value.event_handler.on_any_event, trigger._validator_wrapper)
+
+    @patch("samcli.lib.utils.resource_trigger.DefinitionValidator")
+    @patch("samcli.lib.utils.resource_trigger.Path")
+    @patch("samcli.lib.utils.resource_trigger.ResourceTrigger.get_single_file_path_handler")
+    @patch("samcli.lib.utils.resource_trigger.get_resource_by_id")
+    def test_get_path_handler_missing_definition(
+        self, get_resource_by_id_mock, single_file_handler_mock, path_mock, validator_mock
+    ):
+        stacks = [MagicMock(), MagicMock()]
+        resource = {"Properties": {"Field": "abc"}}
+        get_resource_by_id_mock.return_value = resource
+        with self.assertRaises(MissingLocalDefinition):
+            trigger = DefinitionCodeTrigger("TestApi", "AWS::Serverless::Api", stacks, MagicMock())
+
+    @patch("samcli.lib.utils.resource_trigger.DefinitionValidator")
+    @patch("samcli.lib.utils.resource_trigger.Path")
+    @patch("samcli.lib.utils.resource_trigger.get_resource_by_id")
+    def test_validator_wrapper(self, get_resource_by_id_mock, path_mock, validator_mock):
+        stacks = [MagicMock(), MagicMock()]
+        on_definition_change_mock = MagicMock()
+        event_mock = MagicMock()
+        validator_mock.return_value.validate.return_value = True
+        resource = {"Properties": {"DefinitionUri": "abc"}}
+        get_resource_by_id_mock.return_value = resource
+        trigger = DefinitionCodeTrigger("TestApi", "AWS::Serverless::Api", stacks, on_definition_change_mock)
+        trigger._validator_wrapper(event_mock)
+        on_definition_change_mock.assert_called_once_with(event_mock)
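Both TemplateTrigger and DefinitionCodeTrigger route file events through a validator before invoking the callback. A sketch of the guard these _validator_wrapper tests encode — an assumed shape, inferred from the mocks: the change callback fires only while the watched template or definition still validates, so a half-saved or broken JSON/YAML file does not kick off a sync.

```python
from typing import Any, Callable


class ValidatorWrapper:
    """Hypothetical standalone equivalent of the trigger's _validator_wrapper."""

    def __init__(self, validator, on_change: Callable[[Any], None]):
        # validator is expected to expose validate() -> bool, as the
        # DefinitionValidator mocks above suggest.
        self._validator = validator
        self._on_change = on_change

    def __call__(self, event: Any) -> None:
        # Swallow events for files that no longer parse; forward the rest.
        if self._validator.validate():
            self._on_change(event)
```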
diff --git a/tests/unit/lib/utils/test_resource_type_based_factory.py b/tests/unit/lib/utils/test_resource_type_based_factory.py
new file mode 100644
index 0000000000..302e91f4d6
--- /dev/null
+++ b/tests/unit/lib/utils/test_resource_type_based_factory.py
@@ -0,0 +1,48 @@
+from samcli.lib.providers.provider import ResourceIdentifier
+from samcli.lib.utils.resource_type_based_factory import ResourceTypeBasedFactory
+from unittest import TestCase
+from unittest.mock import ANY, MagicMock, call, patch
+
+
+class TestResourceTypeBasedFactory(TestCase):
+    def setUp(self):
+        self.abstract_method_patch = patch.multiple(ResourceTypeBasedFactory, __abstractmethods__=set())
+        self.abstract_method_patch.start()
+        self.stacks = [MagicMock(), MagicMock()]
+        self.factory = ResourceTypeBasedFactory(self.stacks)
+        self.function_generator_mock = MagicMock()
+        self.layer_generator_mock = MagicMock()
+        self.factory._get_generator_mapping = MagicMock()
+        self.factory._get_generator_mapping.return_value = {
+            "AWS::Lambda::Function": self.function_generator_mock,
+            "AWS::Lambda::LayerVersion": self.layer_generator_mock,
+        }
+
+    def tearDown(self):
+        self.abstract_method_patch.stop()
+
+    @patch("samcli.lib.utils.resource_type_based_factory.get_resource_by_id")
+    def test_get_generator_function_valid(self, get_resource_by_id_mock):
+        resource = {"Type": "AWS::Lambda::Function"}
+        get_resource_by_id_mock.return_value = resource
+
+        generator = self.factory._get_generator_function(ResourceIdentifier("Resource1"))
+        self.assertEqual(generator, self.function_generator_mock)
+
+    @patch("samcli.lib.utils.resource_type_based_factory.get_resource_by_id")
+    def test_get_generator_function_unknown_type(self, get_resource_by_id_mock):
+        resource = {"Type": "AWS::Unknown::Type"}
+        get_resource_by_id_mock.return_value = resource
+
+        generator = self.factory._get_generator_function(ResourceIdentifier("Resource1"))
+
+        self.assertEqual(None, generator)
+
+    @patch("samcli.lib.utils.resource_type_based_factory.get_resource_by_id")
+    def test_get_generator_function_no_type(self, get_resource_by_id_mock):
+        resource = {"Properties": {}}
+        get_resource_by_id_mock.return_value = resource
+
+        generator = self.factory._get_generator_function(ResourceIdentifier("Resource1"))
+
+        self.assertEqual(None, generator)
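These tests fix the ResourceTypeBasedFactory contract: subclasses supply a CloudFormation-type-to-generator mapping, and lookups for unknown or untyped resources come back as None rather than raising. A hypothetical concrete subclass, sketched under the assumption that _get_generator_mapping is the lone abstract hook and the constructor takes the stack list, as setUp suggests:

```python
from typing import Callable, Dict

from samcli.lib.utils.resource_type_based_factory import ResourceTypeBasedFactory


class EchoTriggerFactory(ResourceTypeBasedFactory):
    """Illustrative subclass; names and generator bodies are invented."""

    def _get_generator_mapping(self) -> Dict[str, Callable]:
        # The base class dispatches on the resource's "Type" field and
        # returns None when the type is absent or not in this mapping.
        return {
            "AWS::Lambda::Function": self._make_function_trigger,
            "AWS::Lambda::LayerVersion": self._make_layer_trigger,
        }

    def _make_function_trigger(self, resource_id):
        return f"function trigger for {resource_id}"

    def _make_layer_trigger(self, resource_id):
        return f"layer trigger for {resource_id}"
```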
diff --git a/tests/unit/lib/utils/test_version_checker.py b/tests/unit/lib/utils/test_version_checker.py
index 9318d0fc79..4a76b2ed41 100644
--- a/tests/unit/lib/utils/test_version_checker.py
+++ b/tests/unit/lib/utils/test_version_checker.py
@@ -107,27 +107,18 @@ def test_fetch_and_compare_versions_different(self, mock_click, get_mock):
             ]
         )
 
-    @patch("samcli.cli.global_config.GlobalConfig._set_value")
-    @patch("samcli.cli.global_config.GlobalConfig._get_value")
-    def test_update_last_check_time(self, mock_gc_get_value, mock_gc_set_value):
-        mock_gc_get_value.return_value = None
-        global_config = GlobalConfig()
-        self.assertIsNone(global_config.last_version_check)
-
-        update_last_check_time(global_config)
-        self.assertIsNotNone(global_config.last_version_check)
-
-        mock_gc_set_value.assert_has_calls([call("lastVersionCheck", ANY)])
-
-    @patch("samcli.cli.global_config.GlobalConfig._set_value")
-    @patch("samcli.cli.global_config.GlobalConfig._get_value")
+    @patch("samcli.lib.utils.version_checker.GlobalConfig")
+    @patch("samcli.lib.utils.version_checker.datetime")
+    def test_update_last_check_time(self, mock_datetime, mock_gc):
+        mock_datetime.utcnow.return_value.timestamp.return_value = 12345
+        update_last_check_time()
+        self.assertEqual(mock_gc.return_value.last_version_check, 12345)
+
+    @patch("samcli.cli.global_config.GlobalConfig.set_value")
+    @patch("samcli.cli.global_config.GlobalConfig.get_value")
     def test_update_last_check_time_should_return_when_exception_is_raised(self, mock_gc_get_value, mock_gc_set_value):
         mock_gc_set_value.side_effect = Exception()
-        global_config = GlobalConfig()
-        update_last_check_time(global_config)
-
-    def test_update_last_check_time_should_return_when_global_config_is_none(self):
-        update_last_check_time(None)
+        update_last_check_time()
 
     def test_last_check_time_none_should_return_true(self):
         self.assertTrue(is_version_check_overdue(None))
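The rewritten version-checker tests describe the new shape of update_last_check_time. A sketch of what they imply — reconstructed from the mocks, not copied from the patch: the function no longer takes a GlobalConfig argument, builds its own, stores the current UTC timestamp through the last_version_check property, and swallows persistence errors.

```python
from datetime import datetime

from samcli.cli.global_config import GlobalConfig


def update_last_check_time() -> None:
    try:
        # The property setter persists the value; the exception test above
        # shows failures are absorbed (the real code may log instead).
        GlobalConfig().last_version_check = datetime.utcnow().timestamp()
    except Exception:
        pass
```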