Skip to content

Commit

Permalink
shrink mypy whitelist for other modules (cookiecutter#2054)
Browse files Browse the repository at this point in the history
* shrink mypy whitelist for other modules

---------

Co-authored-by: Jens W. Klein <[email protected]>
  • Loading branch information
danieleades and jensens authored Apr 5, 2024
1 parent bd9206b commit c43c3c0
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 22 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ repos:
hooks:
- id: mypy
additional_dependencies:
- rich
- jinja2
- click
- types-python-slugify
Expand Down
9 changes: 6 additions & 3 deletions cookiecutter/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
import copy
import logging
import os
from typing import Any
from typing import TYPE_CHECKING, Any

import yaml

from cookiecutter.exceptions import ConfigDoesNotExistException, InvalidConfiguration

if TYPE_CHECKING:
from pathlib import Path

logger = logging.getLogger(__name__)

USER_CONFIG_PATH = os.path.expanduser('~/.cookiecutterrc')
Expand All @@ -37,7 +40,7 @@ def _expand_path(path: str) -> str:
return path


def merge_configs(default, overwrite):
def merge_configs(default: dict[str, Any], overwrite: dict[str, Any]) -> dict[str, Any]:
"""Recursively update a dict with the key/value pair of another.
Dict values that are dictionaries themselves will be updated, whilst
Expand All @@ -56,7 +59,7 @@ def merge_configs(default, overwrite):
return new_config


def get_config(config_path):
def get_config(config_path: Path | str) -> dict[str, Any]:
"""Retrieve the config from the specified path, returning a config dict."""
if not os.path.exists(config_path):
raise ConfigDoesNotExistException(f'Config file {config_path} does not exist.')
Expand Down
51 changes: 38 additions & 13 deletions cookiecutter/prompt.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
"""Functions for prompting the user for project info."""

from __future__ import annotations

import json
import os
import re
import sys
from collections import OrderedDict
from itertools import starmap
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Union

from jinja2 import Environment
from jinja2.exceptions import UndefinedError
from rich.prompt import Confirm, InvalidResponse, Prompt, PromptBase
from typing_extensions import TypeAlias

from cookiecutter.exceptions import UndefinedVariableInTemplate
from cookiecutter.utils import create_env_with_context, rmtree

if TYPE_CHECKING:
from jinja2 import Environment


def read_user_variable(var_name: str, default_value, prompts=None, prefix: str = ""):
"""Prompt user for variable and return the entered value or given default.
Expand Down Expand Up @@ -83,7 +89,7 @@ def read_repo_password(question: str) -> str:
return Prompt.ask(question, password=True)


def read_user_choice(var_name: str, options, prompts=None, prefix=""):
def read_user_choice(var_name: str, options, prompts=None, prefix: str = ""):
"""Prompt the user to choose from several options for the given variable.
The first item will be returned if no input happens.
Expand All @@ -103,7 +109,7 @@ def read_user_choice(var_name: str, options, prompts=None, prefix=""):

question = f"Select {var_name}"

choice_lines = starmap(
choice_lines: Iterator[str] = starmap(
" [bold magenta]{}[/] - [bold]{}[/]".format, choice_map.items()
)

Expand Down Expand Up @@ -162,12 +168,12 @@ class JsonPrompt(PromptBase[dict]):
validate_error_message = "[prompt.invalid] Please enter a valid JSON string"

@staticmethod
def process_response(value: str) -> dict:
def process_response(value: str) -> dict[str, Any]:
"""Convert choices to a dict."""
return process_json(value)


def read_user_dict(var_name: str, default_value, prompts=None, prefix=""):
def read_user_dict(var_name: str, default_value, prompts=None, prefix: str = ""):
"""Prompt the user to provide a dictionary of data.
:param var_name: Variable as specified in the context
Expand All @@ -190,7 +196,14 @@ def read_user_dict(var_name: str, default_value, prompts=None, prefix=""):
return user_value


def render_variable(env: Environment, raw, cookiecutter_dict):
_Raw: TypeAlias = Union[bool, Dict["_Raw", "_Raw"], List["_Raw"], str, None]


def render_variable(
env: Environment,
raw: _Raw,
cookiecutter_dict: dict[str, Any],
) -> str:
"""Render the next variable to be displayed in the user prompt.
Inside the prompting taken from the cookiecutter.json file, this renders
Expand Down Expand Up @@ -237,7 +250,9 @@ def _prompts_from_options(options: dict) -> dict:
return prompts


def prompt_choice_for_template(key, options, no_input):
def prompt_choice_for_template(
key: str, options: dict, no_input: bool
) -> OrderedDict[str, Any]:
"""Prompt user with a set of options to choose from.
:param no_input: Do not prompt for user input and return the first available option.
Expand All @@ -248,8 +263,14 @@ def prompt_choice_for_template(key, options, no_input):


def prompt_choice_for_config(
cookiecutter_dict, env, key, options, no_input: bool, prompts=None, prefix: str = ""
):
cookiecutter_dict: dict[str, Any],
env: Environment,
key: str,
options,
no_input: bool,
prompts=None,
prefix: str = "",
) -> OrderedDict[str, Any] | str:
"""Prompt user with a set of options to choose from.
:param no_input: Do not prompt for user input and return the first available option.
Expand All @@ -260,7 +281,9 @@ def prompt_choice_for_config(
return read_user_choice(key, rendered_options, prompts, prefix)


def prompt_for_config(context, no_input=False):
def prompt_for_config(
context: dict[str, Any], no_input: bool = False
) -> OrderedDict[str, Any]:
"""Prompt user to enter a new config.
:param dict context: Source for field names and sample values.
Expand Down Expand Up @@ -340,15 +363,17 @@ def prompt_for_config(context, no_input=False):
return cookiecutter_dict


def choose_nested_template(context: dict, repo_dir: str, no_input: bool = False) -> str:
def choose_nested_template(
context: dict[str, Any], repo_dir: Path | str, no_input: bool = False
) -> str:
"""Prompt user to select the nested template to use.
:param context: Source for field names and sample values.
:param repo_dir: Repository directory.
:param no_input: Do not prompt for user input and use only values from context.
:returns: Path to the selected template.
"""
cookiecutter_dict = OrderedDict([])
cookiecutter_dict: OrderedDict[str, Any] = OrderedDict([])
env = create_env_with_context(context)
prefix = ""
prompts = context['cookiecutter'].pop('__prompts__', {})
Expand Down Expand Up @@ -377,7 +402,7 @@ def choose_nested_template(context: dict, repo_dir: str, no_input: bool = False)
return f"{template_path}"


def prompt_and_delete(path, no_input=False) -> bool:
def prompt_and_delete(path: Path | str, no_input: bool = False) -> bool:
"""
Ask user if it's okay to delete the previously-downloaded file/directory.
Expand Down
6 changes: 0 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,9 @@ show_error_codes = true
files = "cookiecutter"
no_implicit_reexport = true

[[tool.mypy.overrides]]
module = [
"cookiecutter.exceptions",
]
disallow_untyped_defs = false

[[tool.mypy.overrides]]
module = [
"cookiecutter.config",
"cookiecutter.environment",
"cookiecutter.extensions",
"cookiecutter.main",
Expand Down

0 comments on commit c43c3c0

Please sign in to comment.