Skip to content

Commit

Permalink
Minor updates to scaffold.py (getmoto#4213)
Browse files Browse the repository at this point in the history
Co-authored-by: Karri Balk <[email protected]>
  • Loading branch information
kbalk and kbalk authored Aug 24, 2021
1 parent d278fd6 commit 180a487
Show file tree
Hide file tree
Showing 9 changed files with 106 additions and 90 deletions.
141 changes: 75 additions & 66 deletions scripts/scaffold.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
#!/usr/bin/env python
"""This script generates template codes and response body for specified boto3's operation and apply to appropriate files.
"""Generates template code and response body for specified boto3's operation.
You only have to select service and operation that you want to add.
This script looks at the botocore's definition file of specified service and operation, and auto-generates codes and reponses.
Basically, this script supports almost all services, as long as its protocol is `query`, `json` or `rest-json`.
Event if aws adds new services, this script will work as long as the protocol is known.
This script looks at the botocore's definition file of specified service and
operation, and auto-generates codes and reponses.
Basically, this script supports almost all services, as long as its
protocol is `query`, `json` or `rest-json`. Even if aws adds new
services, this script will work as long as the protocol is known.
TODO:
- This scripts don't generates functions in `responses.py` for `rest-json`, because I don't know the rule of it. want someone fix this.
- In some services's operations, this scripts might crash. Make new issue on github then.
- This scripts don't generates functions in `responses.py` for
`rest-json`, because I don't know the rule of it. want someone fix this.
- In some services's operations, this scripts might crash. Make new
issue on github then.
"""
import os
import re
Expand All @@ -19,16 +26,15 @@
import jinja2
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.shortcuts import print_formatted_text

from botocore import xform_name
from botocore.session import Session
import boto3

from moto.core.responses import BaseResponse
from moto.core import BaseBackend
from implementation_coverage import get_moto_implementation
from inflection import singularize
from implementation_coverage import get_moto_implementation

TEMPLATE_DIR = os.path.join(os.path.dirname(__file__), "./template")

Expand All @@ -37,16 +43,16 @@


def print_progress(title, body, color):
click.secho(u"\t{}\t".format(title), fg=color, nl=False)
click.secho("\t{}\t".format(title), fg=color, nl=False)
click.echo(body)


def select_service_and_operation():
service_names = Session().get_available_services()
service_completer = WordCompleter(service_names)
service_name = prompt(u"Select service: ", completer=service_completer)
service_name = prompt("Select service: ", completer=service_completer)
if service_name not in service_names:
click.secho(u"{} is not valid service".format(service_name), fg="red")
click.secho("{} is not valid service".format(service_name), fg="red")
raise click.Abort()
moto_client = get_moto_implementation(service_name)
real_client = boto3.client(service_name, region_name="us-east-1")
Expand All @@ -56,19 +62,19 @@ def select_service_and_operation():
operation_names = [
xform_name(op) for op in real_client.meta.service_model.operation_names
]
for op in operation_names:
if moto_client and op in dir(moto_client):
implemented.append(op)
for operation in operation_names:
if moto_client and operation in dir(moto_client):
implemented.append(operation)
else:
not_implemented.append(op)
not_implemented.append(operation)
operation_completer = WordCompleter(operation_names)

click.echo("==Current Implementation Status==")
for operation_name in operation_names:
check = "X" if operation_name in implemented else " "
click.secho("[{}] {}".format(check, operation_name))
click.echo("=================================")
operation_name = prompt(u"Select Operation: ", completer=operation_completer)
operation_name = prompt("Select Operation: ", completer=operation_completer)

if operation_name not in operation_names:
click.secho("{} is not valid operation".format(operation_name), fg="red")
Expand All @@ -93,7 +99,7 @@ def get_test_dir(service):


def render_template(tmpl_dir, tmpl_filename, context, service, alt_filename=None):
is_test = True if "test" in tmpl_dir else False
is_test = "test" in tmpl_dir
rendered = (
jinja2.Environment(loader=jinja2.FileSystemLoader(tmpl_dir))
.get_template(tmpl_filename)
Expand All @@ -108,14 +114,14 @@ def render_template(tmpl_dir, tmpl_filename, context, service, alt_filename=None
print_progress("skip creating", filepath, "yellow")
else:
print_progress("creating", filepath, "green")
with open(filepath, "w") as f:
f.write(rendered)
with open(filepath, "w") as fhandle:
fhandle.write(rendered)


def append_mock_to_init_py(service):
path = os.path.join(os.path.dirname(__file__), "..", "moto", "__init__.py")
with open(path) as f:
lines = [_.replace("\n", "") for _ in f.readlines()]
with open(path) as fhandle:
lines = [_.replace("\n", "") for _ in fhandle.readlines()]

if any(_ for _ in lines if re.match("^mock_{}.*lazy_load(.*)$".format(service), _)):
return
Expand All @@ -130,20 +136,16 @@ def append_mock_to_init_py(service):
lines.insert(last_import_line_index + 1, new_line)

body = "\n".join(lines) + "\n"
with open(path, "w") as f:
f.write(body)
with open(path, "w") as fhandle:
fhandle.write(body)


def append_mock_dict_to_backends_py(service):
path = os.path.join(os.path.dirname(__file__), "..", "moto", "backends.py")
with open(path) as f:
lines = [_.replace("\n", "") for _ in f.readlines()]

if any(
_
for _ in lines
if re.match('.*"{}": {}_backends.*'.format(service, service), _)
):
with open(path) as fhandle:
lines = [_.replace("\n", "") for _ in fhandle.readlines()]

if any(_ for _ in lines if re.match(f'.*"{service}": {service}_backends.*', _)):
return
filtered_lines = [_ for _ in lines if re.match('.*".*":.*_backends.*', _)]
last_elem_line_index = lines.index(filtered_lines[-1])
Expand All @@ -157,11 +159,11 @@ def append_mock_dict_to_backends_py(service):
lines.insert(last_elem_line_index + 1, new_line)

body = "\n".join(lines) + "\n"
with open(path, "w") as f:
f.write(body)
with open(path, "w") as fhandle:
fhandle.write(body)


def initialize_service(service, operation, api_protocol):
def initialize_service(service, api_protocol):
"""create lib and test dirs if not exist"""
lib_dir = get_lib_dir(service)
test_dir = get_test_dir(service)
Expand Down Expand Up @@ -211,18 +213,18 @@ def initialize_service(service, operation, api_protocol):
append_mock_dict_to_backends_py(service)


def to_upper_camel_case(s):
return "".join([_.title() for _ in s.split("_")])
def to_upper_camel_case(string):
return "".join([_.title() for _ in string.split("_")])


def to_lower_camel_case(s):
words = s.split("_")
def to_lower_camel_case(string):
words = string.split("_")
return "".join(words[:1] + [_.title() for _ in words[1:]])


def to_snake_case(s):
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", s)
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
def to_snake_case(string):
new_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", string)
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", new_string).lower()


def get_operation_name_in_keys(operation_name, operation_keys):
Expand Down Expand Up @@ -274,15 +276,15 @@ def get_function_in_responses(service, operation, protocol):
get_escaped_service(service), operation
)
for input_name in input_names:
body += " {}={},\n".format(input_name, input_name)
body += f" {input_name}={input_name},\n"

body += " )\n"
if protocol == "query":
body += " template = self.response_template({}_TEMPLATE)\n".format(
operation.upper()
)
body += " return template.render({})\n".format(
", ".join(["{}={}".format(_, _) for _ in output_names])
", ".join([f"{n}={n}" for n in output_names])
)
elif protocol in ["json", "rest-json"]:
body += " # TODO: adjust response\n"
Expand Down Expand Up @@ -324,20 +326,24 @@ def get_function_in_models(service, operation):
return body


def _get_subtree(name, shape, replace_list, name_prefix=[]):
def _get_subtree(name, shape, replace_list, name_prefix=None):
if not name_prefix:
name_prefix = []

class_name = shape.__class__.__name__
if class_name in ("StringShape", "Shape"):
t = etree.Element(name)
tree = etree.Element(name)
if name_prefix:
t.text = "{{ %s.%s }}" % (name_prefix[-1], to_snake_case(name))
tree.text = "{{ %s.%s }}" % (name_prefix[-1], to_snake_case(name))
else:
t.text = "{{ %s }}" % to_snake_case(name)
return t
elif class_name in ("ListShape",):
tree.text = "{{ %s }}" % to_snake_case(name)
return tree

if class_name in ("ListShape",):
replace_list.append((name, name_prefix))
t = etree.Element(name)
tree = etree.Element(name)
t_member = etree.Element("member")
t.append(t_member)
tree.append(t_member)
for nested_name, nested_shape in shape.member.members.items():
t_member.append(
_get_subtree(
Expand All @@ -347,7 +353,7 @@ def _get_subtree(name, shape, replace_list, name_prefix=[]):
name_prefix + [singularize(name.lower())],
)
)
return t
return tree
raise ValueError("Not supported Shape")


Expand Down Expand Up @@ -417,8 +423,8 @@ def get_response_query_template(service, operation):


def insert_code_to_class(path, base_class, new_code):
with open(path) as f:
lines = [_.replace("\n", "") for _ in f.readlines()]
with open(path) as fhandle:
lines = [_.replace("\n", "") for _ in fhandle.readlines()]
mod_path = os.path.splitext(path)[0].replace("/", ".")
mod = importlib.import_module(mod_path)
clsmembers = inspect.getmembers(mod, inspect.isclass)
Expand All @@ -436,8 +442,8 @@ def insert_code_to_class(path, base_class, new_code):
lines = lines[:end_line_no] + func_lines + lines[end_line_no:]

body = "\n".join(lines) + "\n"
with open(path, "w") as f:
f.write(body)
with open(path, "w") as fhandle:
fhandle.write(body)


def insert_url(service, operation, api_protocol):
Expand All @@ -452,8 +458,8 @@ def insert_url(service, operation, api_protocol):
path = os.path.join(
os.path.dirname(__file__), "..", "moto", get_escaped_service(service), "urls.py"
)
with open(path) as f:
lines = [_.replace("\n", "") for _ in f.readlines()]
with open(path) as fhandle:
lines = [_.replace("\n", "") for _ in fhandle.readlines()]

if any(_ for _ in lines if re.match(uri, _)):
return
Expand All @@ -480,8 +486,8 @@ def insert_url(service, operation, api_protocol):
lines.insert(last_elem_line_index + 1, new_line)

body = "\n".join(lines) + "\n"
with open(path, "w") as f:
f.write(body)
with open(path, "w") as fhandle:
fhandle.write(body)


def insert_codes(service, operation, api_protocol):
Expand All @@ -495,11 +501,11 @@ def insert_codes(service, operation, api_protocol):
# insert template
if api_protocol == "query":
template = get_response_query_template(service, operation)
with open(responses_path) as f:
lines = [_[:-1] for _ in f.readlines()]
with open(responses_path) as fhandle:
lines = [_[:-1] for _ in fhandle.readlines()]
lines += template.splitlines()
with open(responses_path, "w") as f:
f.write("\n".join(lines))
with open(responses_path, "w") as fhandle:
fhandle.write("\n".join(lines))

# edit models.py
models_path = "moto/{}/models.py".format(get_escaped_service(service))
Expand All @@ -514,7 +520,7 @@ def insert_codes(service, operation, api_protocol):
def main():
service, operation = select_service_and_operation()
api_protocol = boto3.client(service)._service_model.metadata["protocol"]
initialize_service(service, operation, api_protocol)
initialize_service(service, api_protocol)

if api_protocol in ["query", "json", "rest-json"]:
insert_codes(service, operation, api_protocol)
Expand All @@ -525,7 +531,10 @@ def main():
"yellow",
)

click.echo('You will still need to add the mock into "__init__.py"'.format(service))
click.echo(
'You will still need to add the mock into "docs/index.rst" and '
'"IMPLEMENTATION_COVERAGE.md"'
)


if __name__ == "__main__":
Expand Down
4 changes: 1 addition & 3 deletions scripts/template/lib/__init__.py.j2
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import unicode_literals
"""{{ escaped_service }} module initialization; sets value for base decorator."""
from .models import {{ escaped_service }}_backends
from ..core.models import base_decorator

{{ escaped_service }}_backend = {{ escaped_service }}_backends['us-east-1']
mock_{{ escaped_service }} = base_decorator({{ escaped_service }}_backends)

4 changes: 2 additions & 2 deletions scripts/template/lib/exceptions.py.j2
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import unicode_literals
from moto.core.exceptions import RESTError
"""Exceptions raised by the {{ escaped_service }} service."""
from moto.core.exceptions import JsonRESTError


24 changes: 16 additions & 8 deletions scripts/template/lib/models.py.j2
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from __future__ import unicode_literals
"""{{ service_class }}Backend class with methods for supported APIs."""
from boto3 import Session

from moto.core import BaseBackend, BaseModel


class {{ service_class }}Backend(BaseBackend):

"""Implementation of {{ service_class }} APIs."""

def __init__(self, region_name=None):
super({{ service_class }}Backend, self).__init__()
self.region_name = region_name

def reset(self):
"""Re-initialize all attributes for this instance."""
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
Expand All @@ -17,9 +21,13 @@ class {{ service_class }}Backend(BaseBackend):


{{ escaped_service }}_backends = {}
for region in Session().get_available_regions("{{ service }}"):
{{ escaped_service }}_backends[region] = {{ service_class }}Backend()
for region in Session().get_available_regions("{{ service }}", partition_name="aws-us-gov"):
{{ escaped_service }}_backends[region] = {{ service_class }}Backend()
for region in Session().get_available_regions("{{ service }}", partition_name="aws-cn"):
{{ escaped_service }}_backends[region] = {{ service_class }}Backend()
for available_region in Session().get_available_regions("{{ service }}"):
{{ escaped_service }}_backends[available_region] = {{ service_class }}Backend()
for available_region in Session().get_available_regions(
"{{ service }}", partition_name="aws-us-gov"
):
{{ escaped_service }}_backends[available_region] = {{ service_class }}Backend()
for available_region in Session().get_available_regions(
"{{ service }}", partition_name="aws-cn"
):
{{ escaped_service }}_backends[available_region] = {{ service_class }}Backend()
Loading

0 comments on commit 180a487

Please sign in to comment.