Skip to content

Commit

Permalink
Clean up, add update command (huggingface#853)
Browse files Browse the repository at this point in the history
* Clean up, add update command

* Use args for all but default_config

* Call explicitly with args

* Update CLI docs
  • Loading branch information
muellerzr authored Nov 15, 2022
1 parent 71660af commit dd8f205
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 74 deletions.
24 changes: 23 additions & 1 deletion docs/source/package_reference/cli.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Create a default config file for Accelerate with only a few flags set.
**Usage**:

```bash
accelerate default-config [arguments]
accelerate config default [arguments]
```

**Optional Arguments**:
Expand All @@ -57,6 +57,28 @@ accelerate default-config [arguments]
* `-h`, `--help` (`bool`) -- Show a help message and exit
* `--mixed_precision {no,fp16,bf16}` (`str`) -- Whether or not to use mixed precision training. Choose between FP16 and BF16 (bfloat16) training. BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.

## accelerate config update

**Command**:

`accelerate config update` or `accelerate-config update`

Update an existing config file with the latest defaults while maintaining the old configuration.

**Usage**:

```bash
accelerate config update [arguments]
```

**Optional Arguments**:
* `--config_file CONFIG_FILE` (`str`) -- The path to the config file to update. Will default to a file named default_config.yaml in the cache location, which is the content
of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have such an environment variable, your cache directory
(`~/.cache` or the content of `XDG_CACHE_HOME`) suffixed with `huggingface`.

* `-h`, `--help` (`bool`) -- Show a help message and exit


## accelerate env

**Command**:
Expand Down
32 changes: 14 additions & 18 deletions src/accelerate/commands/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,40 +16,36 @@

import argparse

from .config import config_command, config_command_parser
from .config import config_command_parser
from .config_args import default_config_file, load_config_from_file # noqa: F401
from .default import default_command_parser, default_config_command


def filter_command_args(args: dict, args_prefix: str):
"Filters args while only keeping ones that are prefixed with `{args_prefix}.`"
new_args = argparse.Namespace()
for key, value in vars(args).items():
if key.startswith(args_prefix):
setattr(new_args, key.replace(f"{args_prefix}.", ""), value)
return new_args
from .default import default_command_parser
from .update import update_command_parser


def get_config_parser(subparsers=None):
parent_parser = argparse.ArgumentParser(add_help=False)
# The main config parser
config_parser = config_command_parser(subparsers)
# The subparser to add commands to
subcommands = config_parser.add_subparsers(title="subcommands", dest="subcommand")

# Then add other parsers with the parent parser
default_parser = default_command_parser(config_parser, parents=[parent_parser]) # noqa: F841
default_command_parser(subcommands, parents=[parent_parser])
update_command_parser(subcommands, parents=[parent_parser])

return config_parser


def main():
config_parser = get_config_parser()
args = config_parser.parse_args()
if not args.default:
args = filter_command_args(args, "config_args")
config_command(args)
elif args.default:
args = filter_command_args(args, "default_args")
default_config_command(args)

if not hasattr(args, "func"):
config_parser.print_help()
exit(1)

# Run
args.func(args)


if __name__ == "__main__":
Expand Down
18 changes: 4 additions & 14 deletions src/accelerate/commands/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,7 @@

from .cluster import get_cluster_input
from .config_args import cache_dir, default_config_file, default_yaml_config_file, load_config_from_file # noqa: F401
from .config_utils import ( # noqa: F401
GroupedAction,
SubcommandHelpFormatter,
_ask_field,
_ask_options,
_convert_compute_environment,
)
from .config_utils import _ask_field, _ask_options, _convert_compute_environment # noqa: F401
from .sagemaker import get_sagemaker_input


Expand All @@ -49,18 +43,13 @@ def get_user_input():

def config_command_parser(subparsers=None):
if subparsers is not None:
parser = subparsers.add_parser("config", description=description, formatter_class=SubcommandHelpFormatter)
parser = subparsers.add_parser("config", description=description)
else:
parser = argparse.ArgumentParser(
"Accelerate config command", description=description, formatter_class=SubcommandHelpFormatter
)
parser = argparse.ArgumentParser("Accelerate config command", description=description)

parser.add_argument(
"--config_file",
default=None,
dest="config_args.config_file",
metavar="CONFIG_FILE",
action=GroupedAction,
help=(
"The path to use to store the config file. Will default to a file named default_config.yaml in the cache "
"location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
Expand All @@ -87,6 +76,7 @@ def config_command(args):
config.to_json_file(config_file)
else:
config.to_yaml_file(config_file)
print(f"accelerate configuration saved at {config_file}")


def main():
Expand Down
18 changes: 0 additions & 18 deletions src/accelerate/commands/config/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,29 +88,11 @@ def _convert_yes_no_to_bool(value):
return {"yes": True, "no": False}[value.lower()]


class GroupedAction(argparse.Action):
"""
Filters arguments into seperate namespace groups based on the first part of the argument name.
"""

def __call__(self, parser, namespace, values, option_string=None):
group, dest = self.dest.split(".", 2)
groupspace = getattr(namespace, group, argparse.Namespace())
setattr(groupspace, dest, values)
setattr(namespace, group, groupspace)


class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter):
"""
A custom formatter that will remove the usage line from the help message for subcommands.
"""

def _format_action(self, action):
parts = super()._format_action(action)
if action.nargs == argparse.PARSER:
parts = "\n".join(parts.split("\n")[1:])
return parts

def _format_usage(self, usage, actions, groups, prefix):
usage = super()._format_usage(usage, actions, groups, prefix)
usage = usage.replace("<command> [<args>] ", "")
Expand Down
37 changes: 14 additions & 23 deletions src/accelerate/commands/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from pathlib import Path

import torch

from .config_args import ClusterConfig, default_json_config_file
from .config_utils import GroupedAction
from .config_utils import SubcommandHelpFormatter


description = "Create a default config file for Accelerate with only a few flags set."


def write_basic_config(mixed_precision="no", save_location: str = default_json_config_file, dynamo_backend="no"):
Expand All @@ -42,7 +44,7 @@ def write_basic_config(mixed_precision="no", save_location: str = default_json_c
print(
f"Configuration already exists at {save_location}, will not override. Run `accelerate config` manually or pass a different `save_location`."
)
return
return False
mixed_precision = mixed_precision.lower()
if mixed_precision not in ["no", "fp16", "bf16"]:
raise ValueError(f"`mixed_precision` should be one of 'no', 'fp16', or 'bf16'. Received {mixed_precision}")
Expand All @@ -64,20 +66,13 @@ def write_basic_config(mixed_precision="no", save_location: str = default_json_c
config["use_cpu"] = True
config["num_processes"] = 1
config["distributed_type"] = "NO"
if not path.exists():
config = ClusterConfig(**config)
config.to_json_file(path)
config = ClusterConfig(**config)
config.to_json_file(path)
return path


description = "Create a default config file for Accelerate with only a few flags set."


def default_command_parser(parser=None, parents=None):
if parser is None and parents is None:
parser = argparse.ArgumentParser(description=description)
else:
default_parser = parser.add_subparsers(title="subcommand {default}", dest="default", description=description)
parser = default_parser.add_parser("default", parents=parents)
def default_command_parser(parser, parents):
parser = parser.add_parser("default", parents=parents, help=description, formatter_class=SubcommandHelpFormatter)
parser.add_argument(
"--config_file",
default=default_json_config_file,
Expand All @@ -87,9 +82,7 @@ def default_command_parser(parser=None, parents=None):
"such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
"with 'huggingface'."
),
dest="default_args.save_location",
metavar="CONFIG_FILE",
action=GroupedAction,
dest="save_location",
)

parser.add_argument(
Expand All @@ -100,14 +93,12 @@ def default_command_parser(parser=None, parents=None):
"Choose between FP16 and BF16 (bfloat16) training. "
"BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.",
default="no",
dest="default_args.mixed_precision",
action=GroupedAction,
)
parser.set_defaults(func=default_config_command)
return parser


def default_config_command(args):
args = vars(args)
args.pop("func", None)
write_basic_config(**args)
config_file = write_basic_config(args.mixed_precision, args.save_location)
if config_file:
print(f"accelerate configuration saved at {config_file}")
63 changes: 63 additions & 0 deletions src/accelerate/commands/config/update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/usr/bin/env python

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

from .config_args import default_config_file, load_config_from_file
from .config_utils import SubcommandHelpFormatter


description = "Update an existing config file with the latest defaults while maintaining the old configuration."


def update_config(args):
"""
Update an existing config file with the latest defaults while maintaining the old configuration.
"""
config_file = args.config_file
if config_file is None and Path(default_config_file).exists():
config_file = default_config_file
elif not Path(config_file).exists():
raise ValueError(f"The passed config file located at {config_file} doesn't exist.")
config = load_config_from_file(config_file)

if config_file.endswith(".json"):
config.to_json_file(config_file)
else:
config.to_yaml_file(config_file)
return config_file


def update_command_parser(parser, parents):
parser = parser.add_parser("update", parents=parents, help=description, formatter_class=SubcommandHelpFormatter)
parser.add_argument(
"--config_file",
default=None,
help=(
"The path to the config file to update. Will default to a file named default_config.yaml in the cache "
"location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
"such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
"with 'huggingface'."
),
)

parser.set_defaults(func=update_config_command)
return parser


def update_config_command(args):
config_file = update_config(args)
print(f"Sucessfully updated the configuration file at {config_file}.")

0 comments on commit dd8f205

Please sign in to comment.