diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..485dee6 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea diff --git a/README.md b/README.md index 62fec9e..ea9d1d7 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,11 @@ -# LOMO -LOMO: LOw-Memory Optimization +# LOMO: LOw-Memory Optimization +This is the implementation for [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](). + +--- +## Run the code +```shell +bash run.sh +``` + +## Reproduce our results +We provide the sampled datasets used in our experiments [here](https://drive.google.com/drive/folders/1zV7sXvU7YHKWyS3fYV0yyi7FyTjIpEuO?usp=sharing). Due to limited computational resources, we report the highest results obtained from experiments conducted with the same random seed (`42`). We acknowledge this limitation in our work and plan to conduct repeated experiments in the next version to address it. \ No newline at end of file diff --git a/config/ds_config.json b/config/ds_config.json new file mode 100644 index 0000000..0b63fde --- /dev/null +++ b/config/ds_config.json @@ -0,0 +1,27 @@ +{ + + "bf16": { + "enabled": false + }, + "fp16": { + "enabled": true + }, + "zero_allow_untested_optimizer": true, + "zero_force_ds_cpu_optimizer": false, + + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e8, + "stage3_max_live_parameters": 1e8, + "stage3_max_reuse_distance": 1e8, + "stage3_gather_16bit_weights_on_model_save": true + }, + + + "gradient_accumulation_steps": 1, + "steps_per_print": 2000, + "train_micro_batch_size_per_gpu": 2, + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/config/ds_config_lora.json b/config/ds_config_lora.json new file mode 100644 index 0000000..667083d --- /dev/null +++ b/config/ds_config_lora.json @@ -0,0 +1,27 @@ +{ + + "bf16": { + "enabled": true + }, + "fp16": { + "enabled": false + }, + "zero_allow_untested_optimizer": true, + "zero_force_ds_cpu_optimizer": false, + + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e8, + "stage3_max_live_parameters": 1e8, + "stage3_max_reuse_distance": 1e8, + "stage3_gather_16bit_weights_on_model_save": true + }, + + + "gradient_accumulation_steps": 1, + "steps_per_print": 2000, + "train_micro_batch_size_per_gpu": 2, + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/config/hf_args_zero.yaml b/config/hf_args_zero.yaml new file mode 100644 index 0000000..0c7ed8b --- /dev/null +++ b/config/hf_args_zero.yaml @@ -0,0 +1,40 @@ +# model +model_name_or_path: '/home/ubuntu/projects/tunelite/cache/llama-7b' +# data +dataset_name: 'wsc' +refresh: false +data_tag: 'base' +train_on_inputs: false +data_max_length: 1024 +# training +# trainer +tag: 'inplace-sgd' +output_dir: 'outputs' +overwrite_output_dir: true +deepspeed: 'config/ds_config.json' +do_train: true +do_eval: true +evaluation_strategy: 'epoch' +per_device_train_batch_size: 2 +per_device_eval_batch_size: 2 +learning_rate: 0.05 +weight_decay: 0 +num_train_epochs: 10 +lr_scheduler_type: 'linear' +warmup: 0.1 +clip_grad_norm: 1.0 +log_level: 'info' +logging_steps: 1 +save_strategy: 'no' +save_total_limit: 0 +seed: 42 +#bf16: true +remove_unused_columns: false +load_best_model_at_end: false +metric_for_best_model: 'acc' +optim: 'sgd' +group_by_length: false +#report_to: 'wandb' +dataloader_pin_memory: false +gradient_checkpointing: true +predict_with_generate: true \ No newline at end of file
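A config like `hf_args_zero.yaml` is typically consumed through `HfArgumentParser`; the entry point `src/train_zero.py` is not part of this diff, so the wiring below is a minimal sketch under that assumption, using the dataclasses from `src/arguments.py` (introduced later in this diff):

```python
import sys
from transformers import HfArgumentParser

from src.arguments import ModelArguments, DataArguments, MyTrainingArguments

# Parse the YAML config passed on the command line into the three
# argument dataclasses defined in src/arguments.py.
parser = HfArgumentParser((ModelArguments, DataArguments, MyTrainingArguments))
model_args, data_args, training_args = parser.parse_yaml_file(sys.argv[1])
```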
diff --git a/config/hf_args_zero_lora.yaml b/config/hf_args_zero_lora.yaml new file mode 100644 index 0000000..c1c22c6 --- /dev/null +++ b/config/hf_args_zero_lora.yaml @@ -0,0 +1,49 @@ +# model +model_name_or_path: '/home/ubuntu/projects/tunelite/cache/llama-7b' +# data +dataset_name: 'wsc' +refresh: false +data_tag: 'base' +train_on_inputs: false +data_max_length: 1024 +# training +# trainer +peft_type: 'lora' +lora_only: false +hf_learning_rate: 0.0005 +hf_weight_decay: 0 +hf_lr_scheduler_type: 'linear' +hf_warmup: 0.05 +tag: 'lora-qv-r2-inplace-sgd' +output_dir: 'outputs' +overwrite_output_dir: true +deepspeed: 'config/ds_config_lora.json' +do_train: true +do_eval: true +evaluation_strategy: 'epoch' +per_device_train_batch_size: 2 +per_device_eval_batch_size: 2 +learning_rate: 0.005 +weight_decay: 0 +num_train_epochs: 10 +lr_scheduler_type: 'linear' +warmup: 0.05 +clip_grad_norm: 1.0 +#clip_grad_value: 1.0 +#clip_loss_value: 5.0 +log_level: 'info' +logging_steps: 1 +save_strategy: 'no' +save_total_limit: 0 +seed: 42 +#bf16: true +remove_unused_columns: false +load_best_model_at_end: false +metric_for_best_model: 'acc' +optim: 'sgd' +group_by_length: false +#report_to: 'wandb' +dataloader_pin_memory: false +gradient_checkpointing: true +predict_with_generate: true +lora_r: 2 \ No newline at end of file diff --git a/log/__init__.py b/log/__init__.py new file mode 100644 index 0000000..d1d95f2 --- /dev/null +++ b/log/__init__.py @@ -0,0 +1,8 @@ +__all__ = [ + 'logger', + "print" +] + +from .logger import logger +from .print import print + diff --git a/log/handler.py b/log/handler.py new file mode 100644 index 0000000..40931c2 --- /dev/null +++ b/log/handler.py @@ -0,0 +1,90 @@ +import logging +import sys +from logging import getLevelName + +try: + from tqdm.auto import tqdm +except ImportError: + tqdm = None + +__all__ = [] + +if tqdm is not None: + class TqdmLoggingHandler(logging.Handler): + def __init__(self, level=logging.INFO): + super().__init__(level) + + def emit(self, record): + try: + msg = self.format(record) + tqdm.write(msg) + self.flush() + except (KeyboardInterrupt, SystemExit): + raise + except Exception: + self.handleError(record) +else: + class TqdmLoggingHandler(logging.StreamHandler): + def __init__(self, level=logging.INFO): + super().__init__(sys.stdout) + self.setLevel(level) + + +class StdoutStreamHandler(logging.StreamHandler): + """ + Override StreamHandler so that it keeps working when sys.stdout is replaced. + + """ + def __init__(self): + super(StdoutStreamHandler, self).__init__() + + def flush(self): + """ + Flushes the stream. + """ + self.acquire() + try: + sys.stdout.flush() + finally: + self.release() + + def emit(self, record): + """ + Emit a record. + + If a formatter is specified, it is used to format the record. + The record is then written to the stream with a trailing newline. If + exception information is present, it is formatted using + traceback.print_exception and appended to the stream. If the stream + has an 'encoding' attribute, it is used to determine how to do the + output to the stream. + """ + try: + msg = self.format(record) + stream = sys.stdout + # issue 35046: merged two stream.writes into one. + stream.write(msg + self.terminator) + self.flush() + except RecursionError: # See issue 36272 + raise + except Exception: + self.handleError(record) + + def setStream(self, stream): + """ + Overridden to reject stream replacement: this handler always writes + to the current ``sys.stdout`` and raises ``RuntimeError`` if a new + stream is set.
+ """ + raise RuntimeError("Cannot set the stream of FStreamHandler.") + + def __repr__(self): + level = getLevelName(self.level) + name = getattr(sys.stdout, 'name', '') + # bpo-36015: name can be an int + name = str(name) + if name: + name += ' ' + return '<%s %s(%s)>' % (self.__class__.__name__, name, level) diff --git a/log/highlighter.py b/log/highlighter.py new file mode 100644 index 0000000..eb36376 --- /dev/null +++ b/log/highlighter.py @@ -0,0 +1,10 @@ +from rich.highlighter import Highlighter + +__all__ = [] + +class ColorHighlighter(Highlighter): + def __init__(self, color='black'): + self.color = color + + def highlight(self, text): + text.stylize(self.color) \ No newline at end of file diff --git a/log/logger.py b/log/logger.py new file mode 100644 index 0000000..b7a710a --- /dev/null +++ b/log/logger.py @@ -0,0 +1,376 @@ +r""" +:class:`Logger` 是记录日志的模块,**logger** 封装了 logging 模块的 Logger, +具体使用方式与直接使用 :class:`logging.Logger` 相同,同时也新增一些简单好用的API + +使用方式:: + + # logger 可以和 logging.Logger 一样使用 + logger.info('your msg') + logger.error('your msg') + + # logger 新增的API + # 将日志输出到文件,以及输出的日志等级 + logger.add_file('/path/to/log', level='INFO') + # 定义在命令行中的显示格式和日志等级 + logger.set_stdout('tqdm', level='WARN') + # 仅警告一次 + logger.warning_once('your msg') + # 分布式训练下,仅在 rank 0 输出警告 + logger.rank_zero_warning('your msg') + +""" + + +import logging +import logging.config +from logging import DEBUG, ERROR, INFO, WARNING, CRITICAL, raiseExceptions +import os +import sys +import warnings +from pathlib import Path +from typing import Optional, Union +from rich.logging import RichHandler +import datetime +import torch + +__all__ = [ + 'logger' +] + +from .handler import StdoutStreamHandler, TqdmLoggingHandler + + +ROOT_NAME = 'LOMO' + + +class LoggerSingleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(LoggerSingleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +class LOMOLogger(logging.Logger, metaclass=LoggerSingleton): + def __init__(self, name): + super().__init__(name) + self._warning_msgs = set() + + def add_file(self, path: Optional[Union[str, Path]] = None, level='AUTO', remove_other_handlers: bool = False, + mode: str = "w"): + """ + 将日志输出到 path 中。 + + :param path: 若 path 为文件路径(通过 path 是否包含后缀判定 path 是否表示文件名,例如 output.log 会被认为是文件,而 + output 则认为是文件夹)则直接写入到给定文件中;如果判定为文件夹,则是在该文件夹下以 时间戳 创建一个日志文件。 + :param level: 可选 ['INFO', 'WARNING', 'DEBUG', 'ERROR', 'AUTO'], 其中AUTO表示根据环境变量"LOMO_LOG_LEVEL'进行 + 设置。 + :param remove_other_handlers: 是否移除其它 handler ,如果移除,则terminal中将不会有 log 输出。 + :param mode: 可选为['w', 'a'],如果传入的 path 是存在的文件,'w' 会覆盖原有内容 'a' 则会在文件结尾处继续添加。 + :return: + """ + r"""添加日志输出文件和输出级别""" + if level == 'AUTO': + level = parse_level() + return _add_file_handler(self, path, level, remove_other_handlers, mode) + + def set_stdout(self, stdout: str = 'raw', level: str = 'AUTO'): + """ + 设置 log 的 terminal 输出形式。 + + :param stdout: 可选['rich', 'naive', 'raw', 'none']。 + :param level: 可选 ['INFO', 'WARNING', 'DEBUG', 'ERROR', 'AUTO'], 其中AUTO表示根据环境变量"LOMO_LOG_LEVEL'进行 + 设置。 + :return: + """ + r"""设置标准输出格式和输出级别""" + if level == 'AUTO': + level = parse_level() + return _set_stdout_handler(self, stdout, level) + + def debug(self, msg, *args, **kwargs): + """ + Delegate a debug call to the underlying log. + """ + if self.isEnabledFor(DEBUG): + kwargs = self._add_rank_info(kwargs) + self._log(DEBUG, msg, args, **kwargs) + + def info(self, msg, *args, **kwargs): + """ + Delegate an info call to the underlying log. 
+ """ + if self.isEnabledFor(INFO): + kwargs = self._add_rank_info(kwargs) + self._log(INFO, msg, args, **kwargs) + + def warning(self, msg, *args, **kwargs): + """ + Delegate a warning call to the underlying log. + """ + if self.isEnabledFor(WARNING): + kwargs = self._add_rank_info(kwargs) + self._log(WARNING, msg, args, **kwargs) + + def warning_once(self, msg, *args, **kwargs): + """ + 相同的 warning 内容只会 warning 一次 + + :param msg: + :param args: + :param kwargs: + :return: + """ + if msg not in self._warning_msgs: + if self.isEnabledFor(WARNING): + kwargs = self._add_rank_info(kwargs) + self._log(WARNING, msg, args, **kwargs) + self._warning_msgs.add(msg) + + def rank_zero_warning(self, msg, *args, once=False, **kwargs): + """ + 只在 rank 0 上 warning 。 + + :param msg: + :param args: + :param once: 是否只 warning 一次 + :param kwargs: + :return: + """ + if os.environ.get('LOCAL_RANK', 0) == 0: + if once: + if msg in self._warning_msgs: + return + self._warning_msgs.add(msg) + + if self.isEnabledFor(WARNING): + kwargs = self._add_rank_info(kwargs) + self._log(WARNING, msg, args, **kwargs) + + def warn(self, msg, *args, **kwargs): + if self.isEnabledFor(WARNING): + kwargs = self._add_rank_info(kwargs) + self._log(WARNING, msg, args, **kwargs) + + def error(self, msg, *args, **kwargs): + """ + Delegate an error call to the underlying log. + """ + if self.isEnabledFor(ERROR): + kwargs = self._add_rank_info(kwargs) + self._log(ERROR, msg, args, **kwargs) + + def exception(self, msg, *args, exc_info=True, **kwargs): + """ + Delegate an exception call to the underlying log. + """ + kwargs = self._add_rank_info(kwargs) + self.error(msg, *args, exc_info=exc_info, **kwargs) + + def critical(self, msg, *args, **kwargs): + """ + Delegate a critical call to the underlying log. + """ + if self.isEnabledFor(CRITICAL): + kwargs = self._add_rank_info(kwargs) + self._log(CRITICAL, msg, args, **kwargs) + + def log(self, level, msg, *args, **kwargs): + """ + Delegate a log call to the underlying log, after adding + contextual information from this adapter instance. 
+ """ + if not isinstance(level, int): + if raiseExceptions: + raise TypeError("level must be an integer") + else: + return + if self.isEnabledFor(level): + kwargs = self._add_rank_info(kwargs) + self._log(level, msg, args, **kwargs) + + def _add_rank_info(self, kwargs): + if torch.distributed.is_initialized(): + extra = kwargs.get('extra', {}) + extra.update({"rank": int(os.environ.get('LOCAL_RANK', 0))}) + kwargs["extra"] = extra + return kwargs + + def setLevel(self, level) -> None: + """ + 设置当前 logger 以及其 handler 的 log 级别 + + :param level: + :return: + """ + if isinstance(level, str): + level = level.upper() + super().setLevel(level) + for handler in self.handlers: + handler.setLevel(level) + + def _set_distributed(self): + """ + 在 LOMO 拉起进程的时候,调用一下这个方法,使得能够输出 rank 信息 + + :return: + """ + for handler in self.handlers: + if isinstance(handler, logging.FileHandler): + formatter = logging.Formatter(fmt='Rank: %(rank)s - %(asctime)s - %(module)s - [%(levelname)s] - %(message)s', + datefmt='%Y/%m/%d %H:%M:%S') + else: + formatter = logging.Formatter('Rank: %(rank)s - %(message)s') + handler.setFormatter(formatter) + + +def _get_level(level): + if not isinstance(level, int): + level = level.lower() + level = {'info': logging.INFO, 'debug': logging.DEBUG, + 'warn': logging.WARN, 'warning': logging.WARNING, + 'error': logging.ERROR}[level] + return level + + +def _add_file_handler(_logger: logging.Logger, path: Optional[Union[str, Path]] = None, level: str = 'INFO', + remove_other_handlers: bool = False, mode: str = "w"): + if path is None: + path = Path.cwd() + if isinstance(path, str): + path = Path(path) + if not isinstance(path, Path): + raise TypeError("Parameter `path` can only be `str` or `pathlib.Path` type.") + if not path.exists(): + head, tail = os.path.splitext(path) + if tail == '': # 说明没有后缀,理解为是一个folder + path.mkdir(parents=True, exist_ok=True) + else: + # 主进程会帮助我们创建文件夹,但是由于主从进程几乎是同步的,因此到这里时子进程也会尝试创建文件夹,即使主进程会做这件事情; + dirname = os.path.dirname(path) + os.makedirs(dirname, exist_ok=True) + if path.is_dir(): + path = path.joinpath(os.environ.get('LOGGING_TIME', f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%f')}") + '.log') + + if not isinstance(remove_other_handlers, bool): + raise TypeError("Parameter `remove_other_handlers` can only be `bool` type.") + + if not isinstance(mode, str): + raise TypeError("Parameter 'evaluate_fn' can only be `str` type.") + if mode not in {"w", "a"}: + raise ValueError("Parameter `evaluate_fn` can only be one of these values: ('w', 'a').") + + for h in _logger.handlers: + if isinstance(h, logging.FileHandler): + if os.path.abspath(path) == h.baseFilename: + # file path already added + return + + # File Handler + if int(os.environ.get('LOCAL_RANK', 0)) == 0: + if os.path.exists(path): + assert os.path.isfile(path) + warnings.warn('log already exists in {}'.format(path)) + + dirname = os.path.abspath(os.path.dirname(path)) + os.makedirs(dirname, exist_ok=True) + + # 这里只要检测到是分布式训练,我们就将 evaluate_fn 改为 "a";这样会导致的一个问题在于,如果第二次训练也是分布式训练,logger记录的log不会重新 + # 覆盖掉原文件,而是会接着上一次的 log 继续添加; + # 这样做主要是为了解决这样的情形所导致的问题:在分布式训练中,进程 1 比 进程 0 先运行到这里,然后使得进程 0 将进程 1 的 log 覆盖掉; + # if torch.distributed.is_initialized():# and int(os.environ.get(LOMO_GLOBAL_RANK, 0)) != 0: + # mode = "a" + + file_handler = logging.FileHandler(path, mode=mode) + logger.info(f"Writing log to file:{os.path.abspath(path)}") + file_handler.setLevel(_get_level(level)) + + if torch.distributed.is_initialized(): + file_formatter = logging.Formatter(fmt='Rank: %(rank)s - %(asctime)s - 
%(module)s - [%(levelname)s] - %(message)s', + datefmt='%Y/%m/%d %H:%M:%S') + else: + file_formatter = logging.Formatter(fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s', + datefmt='%Y/%m/%d %H:%M:%S') + + file_handler.setFormatter(file_formatter) + _logger.addHandler(file_handler) + + if remove_other_handlers: + _need_remove_handlers = [] + for i, h in enumerate(_logger.handlers): + if not isinstance(h, logging.FileHandler): + _need_remove_handlers.append(h) + for handler in _need_remove_handlers: + _logger.removeHandler(handler) + + return file_handler + + +def _set_stdout_handler(_logger, stdout='raw', level='INFO'): + level = _get_level(level) + supported_stdout = ['none', 'raw', 'tqdm', 'naive', 'rich'] + if stdout not in supported_stdout: + raise ValueError('stdout must in one of {}'.format(supported_stdout)) + # make sure to initialize _logger only once + stream_handler = None + _handlers = (logging.StreamHandler, TqdmLoggingHandler, StdoutStreamHandler, RichHandler) + for i, h in enumerate(_logger.handlers): + if isinstance(h, _handlers): + stream_handler = h + break + if stream_handler is not None: + _logger.removeHandler(stream_handler) + del stream_handler + + # Stream Handler + if stdout == 'raw': + stream_handler = StdoutStreamHandler() + elif stdout == 'rich': + stream_handler = RichHandler(level=level, log_time_format="[%X]") + elif stdout == 'naive': + stream_handler = logging.StreamHandler(sys.stdout) + elif stdout == 'tqdm': + stream_handler = TqdmLoggingHandler(level) + else: + stream_handler = None + + if stream_handler is not None: + if torch.distributed.is_initialized(): + stream_formatter = logging.Formatter('Rank: %(rank)s - %(message)s') + else: + stream_formatter = logging.Formatter('%(message)s') + stream_handler.setLevel(level) + stream_handler.setFormatter(stream_formatter) + _logger.addHandler(stream_handler) + + return stream_handler + + +def _init_logger(path=None, stdout='rich', level='INFO'): + r"""initialize _logger""" + level = _get_level(level) + + logger = LOMOLogger(ROOT_NAME) + + logger.propagate = False + + _set_stdout_handler(logger, stdout, level) + + # File Handler + if path is not None: + _add_file_handler(logger, path, level) + + logger.setLevel(level) + + return logger + + +def parse_level(): + level = 'WARNING' if int(os.environ.get('LOCAL_RANK', 0)) != 0 else "INFO" + return level + + +logger = _init_logger(path=None, stdout='rich', level=parse_level()) +logger.debug("The environment variables are as follows:") +logger.debug(os.environ) diff --git a/log/print.py b/log/print.py new file mode 100644 index 0000000..40dc7b5 --- /dev/null +++ b/log/print.py @@ -0,0 +1,22 @@ +__all__ = [ + 'print' +] +from logging import INFO +from .logger import logger + + +def print(*args, sep=' ', end='\n', file=None, flush=False): + """ + Redirects the built-in print function to logger.info. + + :param args: the objects to print + :param sep: separator used between multiple inputs. + :param end: ignored in this setting, since ``'\n'`` is always appended at the end. + :param file: ignored. + :param flush: ignored. + :return: + """ + line = sep.join(map(str, args)) + if logger.isEnabledFor(INFO): + kwargs = logger._add_rank_info({}) + logger._log(INFO, line, None, **kwargs) diff --git a/run.sh b/run.sh new file mode 100644 index 0000000..7bbd573 --- /dev/null +++ b/run.sh @@ -0,0 +1,8 @@ +set -x +port=$(shuf -i25000-30000 -n1) + +# for full parameter fine-tuning using LOMO +deepspeed --master_port "$port" --include localhost:0,1,2,3,4,5,6,7 src/train_zero.py config/hf_args_zero.yaml + +# for LoRA + LOMO +#deepspeed --master_port "$port" --include localhost:0 src/train_zero_lora.py config/hf_args_zero_lora.yaml
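The launcher draws a random master port with `shuf` to avoid collisions between concurrent jobs; to pin the port or train on fewer GPUs, the same flags can be adjusted, e.g. (illustrative values):

```shell
deepspeed --master_port 29500 --include localhost:0,1 src/train_zero.py config/hf_args_zero.yaml
```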
diff --git a/src/arguments.py b/src/arguments.py new file mode 100644 index 0000000..f5ae710 --- /dev/null +++ b/src/arguments.py @@ -0,0 +1,69 @@ +from dataclasses import dataclass, field +from typing import Optional +from transformers import Seq2SeqTrainingArguments + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="llama-7B") + cache_dir: Optional[str] = field(default='../llama/checkpoint') + # llama_dir: Optional[str] = field(default='/remote-home/klv/exps/MossOn3090/llama') + + +@dataclass +class DataArguments: + data_dir: str = field(default='data') + dataset_name: str = field(default='openbookqa') + refresh: bool = field(default=False, metadata={"help": "Whether to refresh the data."}) + + data_tag: str = field(default='src') + prompt_type: str = field(default='natural', metadata={"help": "The type of prompt, including [natural, brown]."}) + train_on_inputs: bool = field(default=False, metadata={"help": "Whether to train on input."}) + data_max_length: int = field(default=1024) + few_shot_size: int = field(default=-1) + in_context_learning: bool = field(default=False, metadata={"help": "Whether to use in-context learning."}) + + +@dataclass +class MyTrainingArguments(Seq2SeqTrainingArguments): + tag: str = field(default=None, metadata={"help": "Tag for the experiment."}) + + clip_grad_norm: float = field(default=None, metadata={ + "help": "Maximum gradient norm (for gradient clipping)."}) # recommend 1.0 + clip_grad_value: float = field(default=None, metadata={"help": "Maximum gradient value (for gradient clipping)."}) + clip_loss_value: float = field(default=None, + metadata={"help": "Maximum loss value (for token loss clipping)."}) # recommend 5.0 + warmup: float = field(default=0.0, + metadata={"help": "The number of warmup steps (int) or the warmup ratio (float)."}) + + max_length: int = field(default=20, metadata={"help": "The maximum length of the sequence to be generated."}) + max_new_tokens: int = field(default=None, metadata={ + "help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}) + do_sample: bool = field(default=False, + metadata={"help": "Whether or not to use sampling; use greedy decoding otherwise."}) + temperature: float = field(default=1.0, + metadata={"help": "The value used to modulate the next token probabilities."}) + top_k: int = field(default=50, metadata={ + "help": "If set to int > 0, only the top k tokens with the highest probability will be considered for generation."}) + top_p: float = field(default=1.0, metadata={ + "help": "If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation."}) + typical_p: float = field(default=1.0, metadata={ + "help": "If set to float < 1, only the smallest set of most locally typical tokens with probabilities that add up to typical_p or higher are kept for generation."}) + repetition_penalty: float = field(default=1.0, metadata={ + "help": "The parameter for repetition penalty. 1.0 means no penalty. See this paper for more details: https://arxiv.org/pdf/1909.05858.pdf"}) + + length_normalization: bool = field(default=True, metadata={"help": "Whether to normalize the loss by the length of the input."}) + unconditional_normalization: bool = field(default=False, metadata={"help": "Whether to apply unconditional normalization to the loss."}) + + hf_learning_rate: float = field(default=5e-4, metadata={"help": "The learning rate for the HF optimizer."}) + hf_weight_decay: float = field(default=0.0, metadata={"help": "The weight decay for the HF optimizer."}) + hf_lr_scheduler_type: str = field(default='linear', metadata={"help": "The lr scheduler type for the HF optimizer."}) + hf_warmup: float = field(default=0, metadata={"help": "The number of warmup steps (int) or the warmup ratio (float) for the HF optimizer."}) + + # lora hyperparams + peft_type: str = field(default=None, metadata={ + "help": "The type of PEFT, including [lora, prefix-tuning, prompt-tuning, p-tuning]."}) + lora_r: int = field(default=8, metadata={"help": "Lora attention dimension."}) + lora_alpha: int = field(default=16, metadata={"help": "The alpha parameter for Lora scaling."}) + lora_dropout: float = field(default=0.05, metadata={"help": "The dropout probability for Lora layers."}) + lora_only: bool = field(default=False, metadata={"help": "Whether to use LoRA without inplace SGD"})
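The trainer introduced next performs SGD updates inside gradient hooks so that full gradients never need to be stored. A minimal single-GPU sketch of that idea, without the ZeRO-3 partitioning, loss scaling, and clipping the real implementation adds (illustrative only; `attach_inplace_sgd` is not part of this diff):

```python
import torch

def attach_inplace_sgd(model: torch.nn.Module, lr: float):
    # Sweep all parameters and fold any gradient that is already
    # materialized into its parameter, then free it immediately.
    def update_ready_grads(grad):
        with torch.no_grad():
            for p in model.parameters():
                if p.grad is not None:
                    p.data.add_(p.grad, alpha=-lr)  # in-place SGD step
                    p.grad = None                   # release the gradient right away
        return grad

    for p in model.parameters():
        if p.requires_grad:
            p.register_hook(update_ready_grads)
    # As in InplaceZeroTrainer, call update_ready_grads(...) once more after
    # loss.backward(), since no hook fires after the last gradient is ready.
    return update_ready_grads
```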
diff --git a/src/inplace_zero_trainer.py b/src/inplace_zero_trainer.py new file mode 100644 index 0000000..df6764e --- /dev/null +++ b/src/inplace_zero_trainer.py @@ -0,0 +1,516 @@ +import os +import sys +import operator +from collections import OrderedDict +from itertools import chain +from pathlib import Path +import shutil + +import tqdm +import numpy as np +import torch +from torch.nn import CrossEntropyLoss +from torch.utils.data import DistributedSampler, DataLoader +from transformers.trainer_pt_utils import DistributedLengthGroupedSampler, SequentialDistributedSampler +from transformers.trainer_utils import has_length, seed_worker +from transformers import GenerationConfig + +try: + import deepspeed + from deepspeed import comm as dist + from deepspeed.accelerator import get_accelerator +except ImportError: + pass + +from src.utils import LearningRateScheduler, WandbLogger, DynamicLossScaler +from log import print + + +class InplaceZeroTrainer: + def __init__( + self, + model, + training_args, + data_collator, + train_dataset, + eval_dataset, + tokenizer, + compute_metrics, + ): + self.training_args = training_args + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self.tokenizer = tokenizer + self.wandb = WandbLogger(training_args) + self.allow_print = self.training_args.local_rank in [-1, 0] + if self.training_args.do_eval: + self.metrics = {} + self.compute_metrics = compute_metrics + + if 'deepspeed' not in sys.modules: + raise ModuleNotFoundError( + "Detected DeepSpeed is not installed. See https://github.com/microsoft/DeepSpeed") + + # Initialize deepspeed engine + self.model, _, _, _ = deepspeed.initialize( + config=training_args.deepspeed, + model=model, + ) + + # get train_dataloader and eval_dataloader + if isinstance(data_collator, dict): + assert 'train' in data_collator and 'eval' in data_collator, "data_collator should be a dict with keys 'train' and 'eval'."
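+ # For example (hypothetical): data_collator = {'train': train_collator, 'eval': eval_collator}, + # where each collator batches a list of samples into input_ids / attention_mask / labels tensors.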
+ self.train_data_collator = data_collator['train'] + if self.training_args.do_eval: + self.eval_data_collator = data_collator['eval'] + else: + self.train_data_collator = self.eval_data_collator = data_collator + self.train_dataloader = self.get_train_dataloader() + if self.training_args.do_eval: + if isinstance(self.eval_dataset, dict): + self.eval_dataloader = {} + for prefix in self.eval_dataset.keys(): + self.eval_dataloader[prefix] = self.get_eval_dataloader(self.eval_dataset[prefix]) + else: + self.eval_dataloader = self.get_eval_dataloader() + + # setup learning rate + self.num_steps_per_epoch = len(self.train_dataloader) + self.global_step = 1 + self.n_steps = self.num_steps_per_epoch * self.training_args.num_train_epochs + self.lr_scheduler = LearningRateScheduler(learning_rate=self.training_args.learning_rate, + warmup=self.training_args.warmup, + schedule=self.training_args.lr_scheduler_type, + n_steps=self.n_steps) + self.lr = 0 + # for grad norm + self.gather_norm = False + self.grad_norms = [] + self.clip_coef = None + # register inplace grad hook + self.grad_func = self.inplace_grad() + for n, p in model.named_parameters(): + if p.requires_grad: + p.register_hook(self.grad_func) + + get_accelerator().empty_cache() + + # loss scaler + self.loss_scaler = DynamicLossScaler( + init_scale=2 ** 16, + ) # TODO: add args + + def inplace_grad(self): + # An approximation of in-place grad update under zero3 of deepspeed + def func(x): + with torch.no_grad(): + for n, p in self.model.named_parameters(): + if p.grad is not None: + torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.AVG, async_op=False) + if self.loss_scaler.has_overflow_serial or self.loss_scaler._has_inf_or_nan(p.grad): + p.grad = None + self.loss_scaler.has_overflow_serial = True + break + # p.grad.div_(self.loss_scaler.loss_scale) + if self.gather_norm: + grad_fp32 = p.grad.detach().clone().to(torch.float32) + grad_fp32.div_(self.loss_scaler.loss_scale) + self.grad_norms.append(torch.norm(grad_fp32, 2.0)) + p.grad = None + else: + one_dim_grad = p.grad.view(-1) + partition_size = p.ds_tensor.numel() + start = partition_size * self.training_args.local_rank + end = start + partition_size + + if end > p.grad.numel(): + partitioned_grad = one_dim_grad.narrow(0, start, p.grad.numel() - start) + # partitioned_grad = torch.cat([partitioned_grad, torch.zeros(end - p.grad.numel()).cuda()]) + partitioned_p = p.ds_tensor.narrow(0, 0, p.grad.numel() - start) + partitioned_grad_fp32 = partitioned_grad.detach().clone().to(torch.float32) + partitioned_grad_fp32.div_(self.loss_scaler.loss_scale) + partitioned_p_fp32 = partitioned_p.detach().clone().to(torch.float32) + if self.training_args.clip_grad_value is not None: + # Gradients are modified in-place. + partitioned_grad_fp32.clamp_(min=-self.training_args.clip_grad_value, + max=self.training_args.clip_grad_value) + if self.training_args.clip_grad_norm is not None and self.training_args.clip_grad_norm > 0 and self.clip_coef is not None: + partitioned_grad_fp32.mul_(self.clip_coef) + partitioned_p_fp32.add_(partitioned_grad_fp32, alpha=-self.lr) + partitioned_p.copy_(partitioned_p_fp32) + else: + partitioned_grad = one_dim_grad.narrow(0, start, partition_size) + partitioned_grad_fp32 = partitioned_grad.detach().clone().to(torch.float32) + partitioned_grad_fp32.div_(self.loss_scaler.loss_scale) + if self.training_args.clip_grad_value is not None: + # Gradients are modified in-place. 
+ partitioned_grad_fp32.clamp_(min=-self.training_args.clip_grad_value, + max=self.training_args.clip_grad_value) + if self.training_args.clip_grad_norm is not None and self.training_args.clip_grad_norm > 0 and self.clip_coef is not None: + partitioned_grad_fp32.mul_(self.clip_coef) + ds_tensor_fp32 = p.ds_tensor.detach().clone().to(torch.float32) + ds_tensor_fp32.add_(partitioned_grad_fp32, alpha=-self.lr) + p.ds_tensor.copy_(ds_tensor_fp32) + p.grad = None + return x + + return func + + def train(self): + for epoch in range(self.training_args.num_train_epochs): + print(f"***** Running Training *****") + print(f" Num examples: {len(self.train_dataset)}") + print(f" Num Epochs: {self.training_args.num_train_epochs}") + print(f" Current Epoch: {epoch}") + print(f" Batch Size: {self.training_args.per_device_train_batch_size}") + if self.allow_print: + self.wandb.log({'train/epoch': epoch}, step=self.global_step) + self.train_dataloader.sampler.set_epoch(epoch) + + with tqdm.tqdm(self.train_dataloader, disable=not self.allow_print) as tqb: + for step, batch in enumerate(tqb, start=1): + self.model.train() + outs = self.model( + input_ids=batch['input_ids'].cuda(), + attention_mask=batch['attention_mask'].cuda(), + ) + # Shift so that tokens < n predict n + shift_logits = outs.logits[..., :-1, :].contiguous() + shift_labels = batch['labels'][:, 1:].contiguous() + # Flatten the tokens + if self.training_args.clip_loss_value is not None: + loss_fct = CrossEntropyLoss(reduction='none') + loss = loss_fct(shift_logits.view(shift_labels.shape[0] * shift_labels.shape[1], -1), + shift_labels.view(-1).cuda()) + loss.data.clamp_(min=-self.training_args.clip_loss_value, max=self.training_args.clip_loss_value) + loss = loss.mean() + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(shift_labels.shape[0] * shift_labels.shape[1], -1), + shift_labels.view(-1).cuda()) + + # update the learning rate + self.global_step = self.num_steps_per_epoch * epoch + step + self.lr = self.lr_scheduler.step(self.global_step) + if self.training_args.clip_grad_norm is not None and self.training_args.clip_grad_norm > 0: + self.gather_norm = True + self.grad_norms = [] + self.loss_scaler.has_overflow_serial = False + scaled_loss = loss * self.loss_scaler.loss_scale + + scaled_loss.backward() + # update the last one since the hook function will not be called for the last parameter + self.grad_func(0) + + if self.loss_scaler.has_overflow_serial: + print(f"Gradient overflow, skipping step {self.global_step}") + self.loss_scaler.update_scale(overflow=True) + with torch.no_grad(): + for n, p in self.model.named_parameters(): + p.grad = None + self.model.optimizer.get_param_coordinator(training=True).reset_step() + tqb.set_postfix({'loss': loss.item()}) + if self.allow_print: + self.wandb.log( + { + 'train/loss': loss.item(), + 'train/learning_rate': self.lr, + 'train/global_step': self.global_step, + }, + step=self.global_step + ) + continue + + with torch.no_grad(): + # The norm is computed over all gradients together, as if they were + # concatenated into a single vector. Gradients are modified in-place. 
+ self.grad_norms = torch.stack(self.grad_norms) + # device = torch.device(f"cuda:{self.training_args.local_rank}") + # all_grad_norms = torch.zeros(self.training_args.world_size * self.grad_norms.shape[0], dtype=self.grad_norms.dtype, device=device) + # torch.distributed.all_gather_into_tensor(all_grad_norms, self.grad_norms) + + # total_norm = torch.norm(all_grad_norms, 2.0) / self.training_args.world_size + total_norm = torch.norm(self.grad_norms, 2.0) + self.clip_coef = float(self.training_args.clip_grad_norm) / (total_norm + 1e-6) + self.clip_coef = torch.clamp(self.clip_coef, max=1.0) + self.gather_norm = False + + self.model.optimizer.get_param_coordinator(training=True).reset_step() + # second forward pass + outs = self.model( + input_ids=batch['input_ids'].cuda(), + attention_mask=batch['attention_mask'].cuda(), + ) + # Shift so that tokens < n predict n + shift_logits = outs.logits[..., :-1, :].contiguous() + shift_labels = batch['labels'][:, 1:].contiguous() + # Flatten the tokens + if self.training_args.clip_loss_value is not None: + loss_fct = CrossEntropyLoss(reduction='none') + loss = loss_fct(shift_logits.view(shift_labels.shape[0] * shift_labels.shape[1], -1), + shift_labels.view(-1).cuda()) + loss.data.clamp_(min=-self.training_args.clip_loss_value, + max=self.training_args.clip_loss_value) + loss = loss.mean() + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(shift_labels.shape[0] * shift_labels.shape[1], -1), + shift_labels.view(-1).cuda()) + + scaled_loss = loss * self.loss_scaler.loss_scale + + scaled_loss.backward() + # update the last one since the hook function will not be called for the last parameter + self.grad_func(0) + self.loss_scaler.update_scale(overflow=False) + self.model.optimizer.get_param_coordinator(training=True).reset_step() + + tqb.set_postfix({'loss': loss.item()}) + if self.allow_print: + self.wandb.log( + { + 'train/loss': loss.item(), + 'train/learning_rate': self.lr, + 'train/global_step': self.global_step, + }, + step=self.global_step + ) + + if self.training_args.save_strategy == 'steps' and self.global_step % self.training_args.save_steps == 0: + self.save_model(self.global_step) + + if self.training_args.do_eval and self.training_args.evaluation_strategy == 'steps' and \ + self.global_step % self.training_args.eval_steps == 0: + if isinstance(self.eval_dataset, dict): + for prefix in self.eval_dataset.keys(): + assert prefix in self.eval_dataloader.keys(), "eval_dataset and eval_dataloader should have the same keys." + self.eval(self.global_step, epoch, self.eval_dataset[prefix], + self.eval_dataloader[prefix], prefix) + else: + self.eval(self.global_step, epoch, self.eval_dataset, self.eval_dataloader, 'eval') + + if self.training_args.save_strategy == 'epoch': + self.save_model(epoch) + + if self.training_args.do_eval and self.training_args.evaluation_strategy == 'epoch': + if isinstance(self.eval_dataset, dict): + for prefix in self.eval_dataset.keys(): + assert prefix in self.eval_dataloader.keys(), "eval_dataset and eval_dataloader should have the same keys." + self.eval(self.global_step, epoch, self.eval_dataset[prefix], self.eval_dataloader[prefix], + prefix) + else: + self.eval(self.global_step, epoch, self.eval_dataset, self.eval_dataloader, 'eval') + + def eval( + self, + step: int, + epoch: int, + dataset: torch.utils.data.Dataset, + dataloader: DataLoader, + eval_prefix: str + ): + r""" + Shared by both eval (validation) and predict (test). + This method will be called by the trainer to evaluate the model.
+ """ + print(f"***** Running {eval_prefix} *****") + print(f" Num examples: {len(dataset)}") + print(f" Current Epoch: {epoch}") + print(f" Batch size: {self.training_args.per_device_eval_batch_size}") + + with tqdm.tqdm(dataloader, disable=not self.allow_print) as tqb: + all_preds = None + self.model.eval() + for batch in tqb: + with torch.no_grad(): + pred = self.eval_step(batch) + all_preds = pred if all_preds is None else all_preds + pred + + all_preds_gather = [None for _ in range(self.training_args.world_size)] + torch.distributed.all_gather_object(all_preds_gather, all_preds) + all_pred_merged = list(chain(*all_preds_gather)) + + result = self.compute_metrics(all_pred_merged, dataset, eval_prefix) + result = {f"{eval_prefix}/{k}": v for k, v in result.items()} + prefix_metric_for_best_model = f'{eval_prefix}/{self.training_args.metric_for_best_model}' + result_value = result[prefix_metric_for_best_model] + + if self.allow_print: + print(f'epoch: {epoch}, step: {step}, {self.training_args.metric_for_best_model}: {result_value}') + self.wandb.log(result, step=step) + + if self.is_better(result, prefix_metric_for_best_model): + self.wandb.set_summary(f'{eval_prefix}/best_{self.training_args.metric_for_best_model}', result_value) + self.wandb.set_summary(f'{eval_prefix}/best_epoch', epoch) + self.wandb.set_summary(f'{eval_prefix}/best_step', step) + self.metrics[prefix_metric_for_best_model] = result_value + + def eval_step(self, batch): + self.model.eval() + generation_config = GenerationConfig(max_length=self.training_args.max_length, + max_new_tokens=self.training_args.max_new_tokens, + do_sample=self.training_args.do_sample, + temperature=self.training_args.temperature, + top_k=self.training_args.top_k, + top_p=self.training_args.top_p, + typical_p=self.training_args.typical_p, + repetition_penalty=self.training_args.repetition_penalty, ) + logits = self.model.generate( + inputs=batch['input_ids'].cuda(), + generation_config=generation_config + ) + predictions = logits.detach().cpu().numpy() + pred_texts = self.tokenizer.batch_decode(predictions, skip_special_tokens=True) + return pred_texts + + def is_better(self, result_dict, key): + """ + 判断 ``result`` 是否更好。 + + :param result: + """ + op = operator.gt if self.training_args.greater_is_better else operator.lt + return ( + key not in self.metrics or op(result_dict[key], self.metrics[key]) + ) + + def get_train_sampler(self): + if self.train_dataset is None or not has_length(self.train_dataset): + return None + + # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with + # `self.training_args.seed`) if data_seed isn't provided. + # Further on in this method, we default to `self.training_args.seed` instead. + seed = self.training_args.data_seed if self.training_args.data_seed is not None else self.training_args.seed + + if self.training_args.group_by_length: + return DistributedLengthGroupedSampler( + self.training_args.per_device_train_batch_size * self.training_args.gradient_accumulation_steps, + dataset=self.train_dataset, + num_replicas=self.training_args.world_size, + rank=self.training_args.local_rank, + lengths=None, + model_input_name="input_ids", + seed=seed, + ) + else: + return DistributedSampler( + self.train_dataset, + num_replicas=self.training_args.world_size, + rank=self.training_args.local_rank, + seed=seed + ) + + def get_train_dataloader(self): + """ + Returns the training [`~torch.utils.data.DataLoader`]. 
+ Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. + Subclass and override this method if you want to inject some custom behavior. + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + data_collator = self.train_data_collator + train_sampler = self.get_train_sampler() + + return DataLoader( + self.train_dataset, + batch_size=self.training_args.per_device_train_batch_size, + sampler=train_sampler, + collate_fn=data_collator, + drop_last=self.training_args.dataloader_drop_last, + num_workers=self.training_args.dataloader_num_workers, + pin_memory=self.training_args.dataloader_pin_memory, + worker_init_fn=seed_worker, + ) + + def get_eval_sampler(self, eval_dataset): + return SequentialDistributedSampler( + eval_dataset, + num_replicas=self.training_args.world_size, + rank=self.training_args.local_rank, + # batch_size=self.training_args.per_device_eval_batch_size + ) + + def get_eval_dataloader(self, eval_dataset=None): + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + data_collator = self.eval_data_collator + + eval_sampler = self.get_eval_sampler(eval_dataset) + + return DataLoader( + eval_dataset, + sampler=eval_sampler, + batch_size=self.training_args.per_device_eval_batch_size, + collate_fn=data_collator, + drop_last=self.training_args.dataloader_drop_last, + num_workers=self.training_args.dataloader_num_workers, + pin_memory=self.training_args.dataloader_pin_memory, + ) + + def save_model(self, index): + if self.training_args.local_rank in [-1, 0]: + checkpoint_dir = sorted(Path(self.training_args.output_dir).glob("checkpoint-*")) + if len(checkpoint_dir) >= self.training_args.save_total_limit: + shutil.rmtree(checkpoint_dir[0], ignore_errors=True) + torch.distributed.barrier() + + output_dir = os.path.join(self.training_args.output_dir, f"checkpoint-{index}") + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + state_dict = OrderedDict() + for n, p in self.model.module.named_parameters(): + state_dict[n] = (p.ds_tensor.detach().cpu(), p.ds_numel, p.ds_shape) + # save model shards + if self.training_args.local_rank not in [-1, 0]: + with open(os.path.join(output_dir, f'pytorch_model-{self.training_args.local_rank}.bin'), 'wb') as f: + torch.save(state_dict, f) + torch.distributed.barrier() + # merge model shards + if self.training_args.local_rank in [-1, 0]: + # save config + self.model.module.config.save_pretrained(output_dir) + for rank in range(1, self.training_args.world_size): + with open(os.path.join(output_dir, f'pytorch_model-{rank}.bin'), 'rb') as f: + state_dict_rank = torch.load(f) + for n in state_dict_rank: + state_dict[n] = ( + torch.cat([state_dict[n][0], state_dict_rank[n][0]], dim=0), + state_dict[n][1], + state_dict[n][2] + ) + # remove shard files + os.remove(os.path.join(output_dir, f'pytorch_model-{rank}.bin')) + # reshape to 
original shape + for n in state_dict: + numel = state_dict[n][1] + shape = state_dict[n][2] + state_dict[n] = state_dict[n][0][:numel].view(shape) + + # save inv_freq for llama + if self.model.module.config.model_type == "llama": + num_layers = self.model.module.config.num_hidden_layers + head_dim = self.model.module.config.hidden_size // self.model.module.config.num_attention_heads + base = 10000.0 + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + for layer in range(num_layers): + state_dict[f'model.layers.{layer}.self_attn.rotary_emb.inv_freq'] = inv_freq + + + with open(os.path.join(output_dir, f'pytorch_model.bin'), 'wb') as f: + torch.save(state_dict, f) + print(f"Save model to {output_dir}.") + torch.distributed.barrier() diff --git a/src/inplace_zero_trainer_lora.py b/src/inplace_zero_trainer_lora.py new file mode 100644 index 0000000..7b8414d --- /dev/null +++ b/src/inplace_zero_trainer_lora.py @@ -0,0 +1,526 @@ +import os +import sys +import operator +import shutil +from collections import OrderedDict +from itertools import chain +from pathlib import Path + +import tqdm +import numpy as np +import torch +from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload +from torch.nn import CrossEntropyLoss +from torch.utils.data import DistributedSampler, DataLoader +from transformers.trainer_pt_utils import DistributedLengthGroupedSampler, SequentialDistributedSampler +from transformers.trainer_utils import has_length, seed_worker +from transformers import GenerationConfig +from transformers.optimization import AdamW, get_scheduler + +try: + import deepspeed + from deepspeed import comm as dist + from deepspeed.accelerator import get_accelerator +except ImportError: + pass + +from src.utils import LearningRateScheduler, WandbLogger +from log import print +from peft import get_peft_model_state_dict + + +class InplaceZeroTrainer: + def __init__( + self, + model, + training_args, + data_collator, + train_dataset, + eval_dataset, + tokenizer, + compute_metrics, + optimizers=None, + ): + self.training_args = training_args + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self.tokenizer = tokenizer + self.wandb = WandbLogger(training_args) + self.allow_print = self.training_args.local_rank in [0, -1] + if self.training_args.do_eval: + self.metrics = {} + self.compute_metrics = compute_metrics + + # get train_dataloader and eval_dataloader + if isinstance(data_collator, dict): + assert 'train' in data_collator and 'eval' in data_collator, "data_collator should be a dict with keys 'train' and 'eval'."
+ self.train_data_collator = data_collator['train'] + if self.training_args.do_eval: + self.eval_data_collator = data_collator['eval'] + else: + self.train_data_collator = self.eval_data_collator = data_collator + self.train_dataloader = self.get_train_dataloader() + if self.training_args.do_eval: + if isinstance(self.eval_dataset, dict): + self.eval_dataloader = {} + for prefix in self.eval_dataset.keys(): + self.eval_dataloader[prefix] = self.get_eval_dataloader(self.eval_dataset[prefix]) + else: + self.eval_dataloader = self.get_eval_dataloader() + + # setup learning rate + self.num_steps_per_epoch = len(self.train_dataloader) + self.global_step = 1 + self.n_steps = self.num_steps_per_epoch * self.training_args.num_train_epochs + self.lr_scheduler = LearningRateScheduler(learning_rate=self.training_args.learning_rate, + warmup=self.training_args.warmup, + schedule=self.training_args.lr_scheduler_type, + n_steps=self.n_steps) + self.lr = 0 + # for grad norm + self.gather_norm = False + self.grad_norms = [] + self.clip_coef = None + + hf_optimizer = AdamW(optimizers['model_parameters'], lr=training_args.hf_learning_rate, + weight_decay=training_args.hf_weight_decay) + hf_lr_scheduler = get_scheduler(training_args.hf_lr_scheduler_type, + optimizer=hf_optimizer, + num_warmup_steps=training_args.hf_warmup * self.n_steps if training_args.hf_warmup < 1 else training_args.hf_warmup, + num_training_steps=self.n_steps) + + if 'deepspeed' not in sys.modules: + raise ModuleNotFoundError( + "Detected DeepSpeed is not installed. See https://github.com/microsoft/DeepSpeed") + + # Initialize deepspeed engine + self.model, self.peft_optimizer, _, self.peft_lr_scheduler = deepspeed.initialize( + config=training_args.deepspeed, + model=model, + model_parameters=optimizers['model_parameters'], + optimizer=hf_optimizer, + lr_scheduler=hf_lr_scheduler + ) + + if not self.training_args.lora_only: + # register inplace grad hook + self.grad_func = self.inplace_grad() + for n, p in model.named_parameters(): + if "lora_" not in n and p.requires_grad: + p.register_hook(self.grad_func) + + # self.dummy_optimizer = DeepSpeedZeRoOffload( + # self.model.module, + # timers=self.model.timers if self.model.wall_clock_breakdown() else None, + # ds_config=self.model.config, + # overlap_comm=self.model.zero_overlap_comm(), + # prefetch_bucket_size=self.model.zero_prefetch_bucket_size(), + # max_reuse_distance=self.model.zero_max_reuse_distance(), + # max_live_parameters=self.model.zero_max_live_parameters(), + # param_persistence_threshold=self.model.zero_param_persistence_threshold(), + # model_persistence_threshold=self.model.zero_model_persistence_threshold(), + # offload_param_config=self.model.zero_offload_param(), + # mpu=self.model.mpu + # ) + + get_accelerator().empty_cache() + + def inplace_grad(self): + # An approximation of in-place grad update under zero3 of deepspeed + def func(x): + with torch.no_grad(): + for n, p in self.model.named_parameters(): + if "lora_" in n: + continue + + if p.grad is not None: + torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.AVG, async_op=False) + if self.gather_norm: + grad_fp32 = p.grad.detach().clone().to(torch.float32) + self.grad_norms.append(torch.norm(grad_fp32, 2.0)) + p.grad = None + else: + one_dim_grad = p.grad.view(-1) + partition_size = p.ds_tensor.numel() + start = partition_size * self.training_args.local_rank + end = start + partition_size + + if end > p.grad.numel(): + partitioned_grad = one_dim_grad.narrow(0, start, p.grad.numel() - start) + # 
partitioned_grad = torch.cat([partitioned_grad, torch.zeros(end - p.grad.numel()).cuda()]) + partitioned_p = p.ds_tensor.narrow(0, 0, p.grad.numel() - start) + partitioned_grad_fp32 = partitioned_grad.detach().clone().to(torch.float32) + partitioned_p_fp32 = partitioned_p.detach().clone().to(torch.float32) + if self.training_args.clip_grad_value is not None: + # Gradients are modified in-place. + partitioned_grad_fp32.clamp_(min=-self.training_args.clip_grad_value, + max=self.training_args.clip_grad_value) + if self.training_args.clip_grad_norm is not None and self.training_args.clip_grad_norm > 0 and self.clip_coef is not None: + partitioned_grad_fp32.mul_(self.clip_coef) + partitioned_p_fp32.add_(partitioned_grad_fp32, alpha=-self.lr) + partitioned_p.copy_(partitioned_p_fp32) + else: + partitioned_grad = one_dim_grad.narrow(0, start, partition_size) + partitioned_grad_fp32 = partitioned_grad.detach().clone().to(torch.float32) + if self.training_args.clip_grad_value is not None: + # Gradients are modified in-place. + partitioned_grad_fp32.clamp_(min=-self.training_args.clip_grad_value, + max=self.training_args.clip_grad_value) + if self.training_args.clip_grad_norm is not None and self.training_args.clip_grad_norm > 0 and self.clip_coef is not None: + partitioned_grad_fp32.mul_(self.clip_coef) + ds_tensor_fp32 = p.ds_tensor.detach().clone().to(torch.float32) + ds_tensor_fp32.add_(partitioned_grad_fp32, alpha=-self.lr) + p.ds_tensor.copy_(ds_tensor_fp32) + p.grad = None + return x + + return func + + def train(self): + for epoch in range(self.training_args.num_train_epochs): + print(f"***** Running Training *****") + print(f" Num examples: {len(self.train_dataset)}") + print(f" Num Epochs: {self.training_args.num_train_epochs}") + print(f" Current Epoch: {epoch}") + print(f" Batch Size: {self.training_args.per_device_train_batch_size}") + if self.allow_print: + self.wandb.log({'train/epoch': epoch}, step=self.global_step) + self.train_dataloader.sampler.set_epoch(epoch) + + with tqdm.tqdm(self.train_dataloader, disable=not self.allow_print) as tqb: + for step, batch in enumerate(tqb, start=1): + self.model.train() + outs = self.model( + input_ids=batch['input_ids'].cuda(), + attention_mask=batch['attention_mask'].cuda(), + ) + # Shift so that tokens < n predict n + shift_logits = outs.logits[..., :-1, :].contiguous() + shift_labels = batch['labels'][:, 1:].contiguous() + # Flatten the tokens + if self.training_args.clip_loss_value is not None: + loss_fct = CrossEntropyLoss(reduction='none') + loss = loss_fct(shift_logits.view(shift_labels.shape[0] * shift_labels.shape[1], -1), + shift_labels.view(-1).cuda()) + loss.data.clamp_(min=-self.training_args.clip_loss_value, max=self.training_args.clip_loss_value) + loss = loss.mean() + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(shift_labels.shape[0] * shift_labels.shape[1], -1), + shift_labels.view(-1).cuda()) + + # update the learning rate + if self.training_args.lora_only: + self.global_step = self.num_steps_per_epoch * epoch + step + + loss = self.model.backward(loss) + self.model.step() # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps + else: + self.global_step = self.num_steps_per_epoch * epoch + step + self.lr = self.lr_scheduler.step(self.global_step) + + if self.training_args.clip_grad_norm is not None and self.training_args.clip_grad_norm > 0: + self.gather_norm = True + self.grad_norms = [] + + self.model.backward(loss) + # update the last 
one since the hook function will not be called for the last parameter + self.grad_func(0) + # self.model.optimizer._get_param_coordinator(training=True).reset_step() + # self.dummy_optimizer.get_param_coordinator(training=True).reset_step() + + with torch.no_grad(): + # The norm is computed over all gradients together, as if they were + # concatenated into a single vector. Gradients are modified in-place. + self.grad_norms = torch.stack(self.grad_norms) + device = torch.device(f"cuda:{self.training_args.local_rank}") + all_grad_norms = torch.zeros(self.training_args.world_size * self.grad_norms.shape[0], + dtype=self.grad_norms.dtype, device=device) + torch.distributed.all_gather_into_tensor(all_grad_norms, self.grad_norms) + + total_norm = torch.norm(all_grad_norms, 2.0) + self.clip_coef = float(self.training_args.clip_grad_norm) / (total_norm + 1e-6) + self.clip_coef = torch.clamp(self.clip_coef, max=1.0) + self.gather_norm = False + + # second forward pass + outs = self.model( + input_ids=batch['input_ids'].cuda(), + attention_mask=batch['attention_mask'].cuda(), + ) + # Shift so that tokens < n predict n + shift_logits = outs.logits[..., :-1, :].contiguous() + shift_labels = batch['labels'][:, 1:].contiguous() + # Flatten the tokens + if self.training_args.clip_loss_value is not None: + loss_fct = CrossEntropyLoss(reduction='none') + loss = loss_fct(shift_logits.view(shift_labels.shape[0] * shift_labels.shape[1], -1), + shift_labels.view(-1).cuda()) + loss.data.clamp_(min=-self.training_args.clip_loss_value, + max=self.training_args.clip_loss_value) + loss = loss.mean() + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(shift_labels.shape[0] * shift_labels.shape[1], -1), + shift_labels.view(-1).cuda()) + + # update peft params + loss = self.model.backward(loss) + self.grad_func(0) + self.model.step() # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps + + tqb.set_postfix({'loss': loss.item()}) + if self.allow_print: + self.wandb.log( + { + 'train/loss': loss.item(), + 'train/learning_rate': self.lr, + 'train/hf_learning_rate': self.model.get_lr()[0], + 'train/global_step': self.global_step, + }, + step=self.global_step + ) + + if self.training_args.save_strategy == 'steps' and self.global_step % self.training_args.save_steps == 0: + self.save_model(self.global_step) + + if self.training_args.do_eval and self.training_args.evaluation_strategy == 'steps' and \ + self.global_step % self.training_args.eval_steps == 0: + if isinstance(self.eval_dataset, dict): + for prefix in self.eval_dataset.keys(): + assert prefix in self.eval_dataloader.keys(), "eval_dataset and eval_dataloader should have the same keys." + self.eval(self.global_step, epoch, self.eval_dataset[prefix], + self.eval_dataloader[prefix], prefix) + else: + self.eval(self.global_step, epoch, self.eval_dataset, self.eval_dataloader, 'eval') + + if self.training_args.save_strategy == 'epoch': + self.save_model(epoch) + + if self.training_args.do_eval and self.training_args.evaluation_strategy == 'epoch': + if isinstance(self.eval_dataset, dict): + for prefix in self.eval_dataset.keys(): + assert prefix in self.eval_dataloader.keys(), "eval_dataset and eval_dataloader should have the same keys." + self.eval(self.global_step, epoch, self.eval_dataset[prefix], self.eval_dataloader[prefix], + prefix) + else: + self.eval(self.global_step, epoch, self.eval_dataset, self.eval_dataloader, 'eval')
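Two update paths run side by side in this trainer; a sketch of how parameters are split between them (illustrative only, assuming `model` is the PEFT-wrapped model the caller passes in):

```python
# Updated in place by the gradient hooks (LOMO), at self.lr:
lomo_params = [n for n, p in model.named_parameters()
               if "lora_" not in n and p.requires_grad]

# Updated by the DeepSpeed-wrapped AdamW, at hf_learning_rate; the caller is
# expected to pass these as optimizers['model_parameters'] (presumably the
# LoRA adapter weights):
adamw_params = [n for n, p in model.named_parameters() if "lora_" in n]
```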
+ def eval( + self, + step: int, + epoch: int, + dataset: torch.utils.data.Dataset, + dataloader: DataLoader, + eval_prefix: str + ): + r""" + Shared by both eval (validation) and predict (test). + This method will be called by the trainer to evaluate the model. + """ + print(f"***** Running {eval_prefix} *****") + print(f" Num examples: {len(dataset)}") + print(f" Current Epoch: {epoch}") + print(f" Batch size: {self.training_args.per_device_eval_batch_size}") + + with tqdm.tqdm(dataloader, disable=not self.allow_print) as tqb: + all_preds = None + self.model.eval() + for batch in tqb: + with torch.no_grad(): + pred = self.eval_step(batch) + all_preds = pred if all_preds is None else all_preds + pred + + all_preds_gather = [None for _ in range(self.training_args.world_size)] + torch.distributed.all_gather_object(all_preds_gather, all_preds) + all_pred_merged = list(chain(*all_preds_gather)) + + result = self.compute_metrics(all_pred_merged, dataset, eval_prefix) + result = {f"{eval_prefix}/{k}": v for k, v in result.items()} + prefix_metric_for_best_model = f'{eval_prefix}/{self.training_args.metric_for_best_model}' + result_value = result[prefix_metric_for_best_model] + + if self.allow_print: + print(f'epoch: {epoch}, step: {step}, {self.training_args.metric_for_best_model}: {result_value}') + self.wandb.log(result, step=step) + + if self.is_better(result, prefix_metric_for_best_model): + self.wandb.set_summary(f'{eval_prefix}/best_{self.training_args.metric_for_best_model}', result_value) + self.wandb.set_summary(f'{eval_prefix}/best_epoch', epoch) + self.wandb.set_summary(f'{eval_prefix}/best_step', step) + self.metrics[prefix_metric_for_best_model] = result_value + + def eval_step(self, batch): + self.model.eval() + generation_config = GenerationConfig(max_length=self.training_args.max_length, + max_new_tokens=self.training_args.max_new_tokens, + do_sample=self.training_args.do_sample, + temperature=self.training_args.temperature, + top_k=self.training_args.top_k, + top_p=self.training_args.top_p, + typical_p=self.training_args.typical_p, + repetition_penalty=self.training_args.repetition_penalty,) + logits = self.model.generate( + inputs=batch['input_ids'].cuda(), + generation_config=generation_config + ) + predictions = logits.detach().cpu().numpy() + pred_texts = self.tokenizer.batch_decode(predictions, skip_special_tokens=True) + return pred_texts + + def is_better(self, result_dict, key): + """ + Return whether ``result_dict[key]`` improves on the best metric recorded so far. + + :param result_dict: + :param key: + """ + op = operator.gt if self.training_args.greater_is_better else operator.lt + return ( + key not in self.metrics or op(result_dict[key], self.metrics[key]) + ) + + def get_train_sampler(self): + if self.train_dataset is None or not has_length(self.train_dataset): + return None + + # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with + # `self.training_args.seed`) if data_seed isn't provided. + # Further on in this method, we default to `self.training_args.seed` instead.
+ seed = self.training_args.data_seed if self.training_args.data_seed is not None else self.training_args.seed + + if self.training_args.group_by_length: + return DistributedLengthGroupedSampler( + self.training_args.per_device_train_batch_size * self.training_args.gradient_accumulation_steps, + dataset=self.train_dataset, + num_replicas=self.training_args.world_size, + rank=self.training_args.local_rank, + lengths=None, + model_input_name="input_ids", + seed=seed, + ) + else: + return DistributedSampler( + self.train_dataset, + num_replicas=self.training_args.world_size, + rank=self.training_args.local_rank, + seed=seed + ) + + def get_train_dataloader(self): + """ + Returns the training [`~torch.utils.data.DataLoader`]. + Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. + Subclass and override this method if you want to inject some custom behavior. + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + data_collator = self.train_data_collator + train_sampler = self.get_train_sampler() + + return DataLoader( + self.train_dataset, + batch_size=self.training_args.per_device_train_batch_size, + sampler=train_sampler, + collate_fn=data_collator, + drop_last=self.training_args.dataloader_drop_last, + num_workers=self.training_args.dataloader_num_workers, + pin_memory=self.training_args.dataloader_pin_memory, + worker_init_fn=seed_worker, + ) + + def get_eval_sampler(self, eval_dataset): + return SequentialDistributedSampler( + eval_dataset, + num_replicas=self.training_args.world_size, + rank=self.training_args.local_rank, + # batch_size=self.training_args.per_device_eval_batch_size + ) + + def get_eval_dataloader(self, eval_dataset=None): + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. 
+ """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + data_collator = self.eval_data_collator + + eval_sampler = self.get_eval_sampler(eval_dataset) + + return DataLoader( + eval_dataset, + sampler=eval_sampler, + batch_size=self.training_args.per_device_eval_batch_size, + collate_fn=data_collator, + drop_last=self.training_args.dataloader_drop_last, + num_workers=self.training_args.dataloader_num_workers, + pin_memory=self.training_args.dataloader_pin_memory, + ) + + def save_model(self, index): + checkpoint_dir = sorted(Path(self.training_args.output_dir).glob("checkpoint-*")) + if len(checkpoint_dir) >= self.training_args.save_total_limit: + os.rmdir(checkpoint_dir[0]) + + output_dir = os.path.join(self.training_args.output_dir, f"checkpoint-{index}") + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + state_dict = OrderedDict() + for n, p in self.model.module.named_parameters(): + state_dict[n] = (p.ds_tensor.detach().cpu(), p.ds_numel, p.ds_shape) + # save model shards + if self.training_args.local_rank != 0: + with open(os.path.join(output_dir, f'pytorch_model-{self.training_args.local_rank}.bin'), 'wb') as f: + torch.save(state_dict, f) + torch.distributed.barrier() + # merge model shards + if self.training_args.local_rank == 0: + # save config + self.model.module.config.save_pretrained(output_dir) + for rank in range(1, self.training_args.world_size): + with open(os.path.join(output_dir, f'pytorch_model-{rank}.bin'), 'rb') as f: + state_dict_rank = torch.load(f) + for n in state_dict_rank: + state_dict[n] = ( + torch.cat([state_dict[n][0], state_dict_rank[n][0]], dim=0), + state_dict[n][1], + state_dict[n][2] + ) + # remove shard files + os.remove(os.path.join(output_dir, f'pytorch_model-{rank}.bin')) + # reshape to original shape + for n in state_dict: + numel = state_dict[n][1] + shape = state_dict[n][2] + state_dict[n] = state_dict[n][0][:numel].view(shape) + + # save inv_freq for llama + if self.model.module.config.model_type == "llama": + num_layers = self.model.module.config.num_hidden_layers + head_dim = self.model.module.config.hidden_size // self.model.module.config.num_attention_heads + base = 10000.0 + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + for layer in range(num_layers): + state_dict[f'model.layers.{layer}.self_attn.rotary_emb.inv_freq'] = inv_freq + + with open(os.path.join(output_dir, f'pytorch_model.bin'), 'wb') as f: + torch.save(state_dict, f) + print(f"Save model to {output_dir}.") + + # save lora + self.model.module.peft_config['default'].save_pretrained(output_dir) + # if state dict is not what you expected, you can use the following code to get the state dict + # engine_state_dict = self.model._zero3_consolidated_16bit_state_dict() + lora_state_dict = get_peft_model_state_dict(self.model.module, state_dict) + torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin")) + print(f"Save adapter model at {output_dir}") + + torch.distributed.barrier() diff --git a/src/mydatasets.py b/src/mydatasets.py new file mode 100644 index 0000000..84c7cf0 --- /dev/null +++ b/src/mydatasets.py @@ -0,0 +1,333 @@ +import os +import copy +import json +import random +from tqdm import tqdm +from typing import Callable, Any + +from datasets import load_dataset +from dataclasses import dataclass +import numpy as np +import torch +from 
+
+from log import print
+from prompts import QuestionPart, Exemplar, idx_to_ltr
+
+IGNORE_INDEX = -100
+REPRODUCIBILITY_SEED = 0
+
+
+class MyDataset(Dataset):
+    def __init__(self, data_args, tokenizer, dataset_info, split):
+        super().__init__()
+        self.data_args = data_args
+        self.tokenizer = tokenizer
+        self.split = split
+        self.sample_size = dataset_info.sample_size
+        self.prompt_type = dataset_info.prompt_type
+
+        save_dir = os.path.join(data_args.data_dir, data_args.dataset_name, data_args.data_tag)
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir, exist_ok=True)
+
+        save_file = os.path.join(save_dir, f'{split}.pt')
+        if data_args.refresh or not os.path.exists(save_file):
+            dataset = load_dataset(dataset_info.path, name=dataset_info.name, split=split)
+            self.data = self.process(dataset_info.extractor, dataset, save_file)
+        else:
+            print('Loading data from', save_file)
+            self.data = torch.load(save_file)
+        print('Data size:', len(self.data))
+        print('Data format:', self.data[0])
+        if self.split == 'train':
+            print('Max length:', max([len(d['input_ids']) for d in self.data]))
+        else:
+            print('Max length:', max([max([len(d) for d in dd['input_ids']]) for dd in self.data]))
+
+    def process(self, extractor, dataset, save_file):
+        data = []
+        for instance in tqdm(dataset):
+            exemplar = Exemplar(**extractor(instance))
+            if self.prompt_type == 'brown':
+                prompt = exemplar.get_brown_prompt()
+            else:
+                prompt = exemplar.get_natural_prompt()
+            source = prompt['source']
+
+            targets = []
+
+            def _tokenize_fn(source, target):
+                targets.append(target)
+                example = f"{source}{target}"
+                example_tokenized = self.tokenizer.encode(example, truncation=True, max_length=self.data_args.data_max_length)
+                example_tokenized = example_tokenized + [self.tokenizer.eos_token_id]
+                source_tokenized = self.tokenizer.encode(source)
+
+                input_ids = example_tokenized
+                labels = copy.deepcopy(input_ids)
+                if not self.data_args.train_on_inputs:
+                    labels = np.array(labels)
+                    # mask out the prompt tokens so the loss is only computed on the target
+                    labels[:len(source_tokenized) - 1] = IGNORE_INDEX
+                return input_ids, labels
+
+            if self.split == 'train':
+                input_ids, labels = _tokenize_fn(source, prompt['target'])
+            else:
+                input_ids = []
+                labels = []
+                for choice in prompt['choices']:
+                    op_input_ids, op_labels = _tokenize_fn(source, choice)
+                    input_ids.append(op_input_ids)
+                    labels.append(op_labels)
+
+            data.append({'input_ids': input_ids,
+                         'labels': labels,
+                         'source': source,
+                         'target': targets,
+                         'answer': exemplar.answer_idx})
+
+        if self.sample_size > 0 and len(data) > self.sample_size:
+            random.seed(REPRODUCIBILITY_SEED)
+            possible_idxs = list(range(len(data)))
+            sampled_idxs = random.sample(possible_idxs, self.sample_size)
+            data = [data[i] for i in sampled_idxs]
+            print(f'Sampled {self.sample_size} examples from {len(possible_idxs)} examples.')
+
+        torch.save(data, save_file)
+        print('Saving data to', save_file)
+        return data
+
+    def concat_exemplars(self, exemplars):
+        exemplar_prompts = [f"{e['source']}{e['target'][0]}" for e in exemplars]
+        exemplars = "\n\n".join(exemplar_prompts)
+        return exemplars
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return {
+            'input_ids': self.data[idx]['input_ids'],
+            'labels': self.data[idx]['labels']
+        }
+
+
+@dataclass
+class DatasetInfo:
+    path: str = None
+    exemplar_split: str = None
+    eval_split: str = None
+    test_split: str = None
+    extractor: Callable = None
+    name: str = None
+    data_dir: str = None
+    sample_size: int = -1
+    prompt_type: str = 'brown'
+
+
+def get_dataset_info(dataset_name):
+    if dataset_name == 'boolq':
+        return DatasetInfo(
+            path="super_glue",
+            name="boolq",
+            exemplar_split="train",
+            eval_split="validation",
+            sample_size=1000,
+            extractor=lambda row: {
+                "parts": [
+                    QuestionPart(
+                        f"{row['passage']} {row['question']}",
+                    ),
+                ],
+                "choices": [
+                    'No', 'Yes'
+                ],
+                "answer_idx": int(row["label"])
+            }
+        )
+    # elif dataset_name == 'cb':
+    #     return DatasetInfo(
+    #         path="super_glue",
+    #         name="cb",
+    #         exemplar_split="train",
+    #         eval_split="validation",
+    #         sample_size=1000,
+    #         extractor=lambda row: {
+    #             "parts": [
+    #                 QuestionPart(
+    #                     f"Suppose {row['premise']} Can we infer that \"{row['hypothesis']}\"? Yes, No, or Maybe?",
+    #                 ),
+    #             ],
+    #             "choices": [
+    #                 'Yes', 'No', 'Maybe'
+    #             ],
+    #             "answer_idx": int(row["label"])
+    #         }
+    #     )
+    elif dataset_name == 'multirc':
+        return DatasetInfo(
+            path="super_glue",
+            name="multirc",
+            exemplar_split="train",
+            eval_split="validation",
+            sample_size=1000,
+            extractor=lambda row: {
+                "parts": [
+                    QuestionPart(
+                        f"{row['paragraph']}",
+                    ),
+                    QuestionPart(
+                        f"{row['question']}",
+                        tag='Question'
+                    ),
+                    QuestionPart(
+                        f'I found this answer "{row["answer"]}". Is that correct? Yes or No?',
+                    ),
+                ],
+                "choices": [
+                    'No', 'Yes'
+                ],
+                "answer_idx": int(row["label"])
+            }
+        )
+    elif dataset_name == 'rte':
+        return DatasetInfo(
+            path="super_glue",
+            name="rte",
+            exemplar_split="train",
+            eval_split="validation",
+            sample_size=1000,
+            extractor=lambda row: {
+                "parts": [
+                    QuestionPart(
+                        f"{row['premise']}\nDoes this mean that \"{row['hypothesis']}\" is true? Yes or No?",
+                    ),
+                ],
+                "choices": [
+                    'Yes', 'No'
+                ],
+                "answer_idx": int(row["label"])
+            }
+        )
+    elif dataset_name == 'wic':
+        return DatasetInfo(
+            path="super_glue",
+            name="wic",
+            exemplar_split="train",
+            eval_split="validation",
+            sample_size=1000,
+            extractor=lambda row: {
+                "parts": [
+                    QuestionPart(
+                        f"Does the word \"{row['word']}\" have the same meaning in these two sentences? Yes, No?\n{row['sentence1']}\n{row['sentence2']}",
+                    ),
+                ],
+                "choices": [
+                    'No', 'Yes'
+                ],
+                "answer_idx": int(row["label"])
+            }
+        )
+    elif dataset_name == 'wsc':
+        return DatasetInfo(
+            path="super_glue",
+            name="wsc",
+            exemplar_split="train",
+            eval_split="validation",
+            sample_size=1000,
+            extractor=lambda row: {
+                "parts": [
+                    QuestionPart(
+                        f"{row['text']}\nIn the previous sentence, does the pronoun \"{row['span2_text']}\" refer to \"{row['span1_text']}\"? Yes or No?",
+                    ),
+                ],
+                "choices": [
+                    'No', 'Yes'
+                ],
+                "answer_idx": int(row["label"])
+            }
+        )
+    elif dataset_name == 'copa':
+        return DatasetInfo(
+            path="super_glue",
+            name="copa",
+            exemplar_split="train",
+            eval_split="validation",
+            sample_size=1000,
+            prompt_type='natural',
+            extractor=lambda row: {
+                "parts": [
+                    QuestionPart(
+                        f"{row['premise']} so " if row['question'] == 'effect' else f"{row['premise']} because ",
+                    ),
+                ],
+                "choices": [
+                    row['choice1'], row['choice2']
+                ],
+                "answer_idx": int(row["label"])
+            }
+        )
+    elif dataset_name == 'record':
+        return DatasetInfo(
+            path="super_glue",
+            name="record",
+            exemplar_split="train",
+            eval_split="validation",
+            sample_size=1000,
+            extractor=process_record
+        )
+    else:
+        raise NotImplementedError
+
+
+def process_record(row):
+    def record_clean_choices(row):
+        if len(row['answers']) == 1:
+            return row['entities'], row['entities'].index(row['answers'][0])
+
+        new_entities = []
+        for entity in row['entities']:
+            if entity in row['answers'][1:]:
+                continue
+            new_entities.append(entity)
+        return new_entities, new_entities.index(row['answers'][0])
+
+    choices, answer_idx = record_clean_choices(row)
+    return {
+        "parts": [
+            QuestionPart(
+                "{}\n{}\nQuestion: What is the \"@placeholder\"?".format(row['passage'].replace('@highlight\n', '- '), row['query']),
+            ),
+        ],
+        "choices": choices,
+        "answer_idx": answer_idx
+    }
+
+
+if __name__ == '__main__':
+    from transformers import HfArgumentParser
+    from arguments import ModelArguments, DataArguments
+    from transformers import AutoTokenizer
+
+    parser = HfArgumentParser((ModelArguments, DataArguments))
+    model_args, data_args = parser.parse_args_into_dataclasses()
+    model_args.model_name_or_path = '/home/klv/llama_hf/7B'
+    data_args.dataset_name = 'record'
+    data_args.refresh = True
+    data_args.data_tag = 'debug'
+    data_args.train_on_inputs = False
+    data_args.data_max_length = 512
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path,
+        use_fast=False,
+        padding_side='left'
+    )
+    tokenizer.pad_token_id = 0
+
+    dataset_info = get_dataset_info(data_args.dataset_name)
+    train_dataset = MyDataset(data_args, tokenizer, dataset_info, split=dataset_info.exemplar_split)
+    eval_dataset = MyDataset(data_args, tokenizer, dataset_info, split=dataset_info.eval_split)
+    # test_dataset = MyDataset(data_args, tokenizer, dataset_info, split=dataset_info.test_split)
diff --git a/src/mytrainer.py b/src/mytrainer.py
new file mode 100644
index 0000000..ea91c8f
--- /dev/null
+++ b/src/mytrainer.py
@@ -0,0 +1,26 @@
+import tqdm
+
+import torch
+from torch.nn import CrossEntropyLoss
+from transformers.trainer_pt_utils import nested_numpify, nested_concat
+from src.inplace_zero_trainer import InplaceZeroTrainer
+
+IGNORE_INDEX = -100
+
+
+class MyInplaceZeroTrainer(InplaceZeroTrainer):
+    def eval_step(self, batch):
+        outs = self.model(batch['input_ids'].cuda(), batch['attention_mask'].cuda())
+        # Shift so that tokens < n predict n
+        shift_logits = outs.logits[..., :-1, :].contiguous()
+        shift_labels = batch['labels'][..., 1:].cuda().contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss(reduction='none')
+        loss = loss_fct(shift_logits.view(shift_labels.shape[0] * shift_labels.shape[1], -1),
+                        shift_labels.view(-1)).view_as(shift_labels)
+        loss = loss.mean(dim=1)
+        group_loss = loss.split(batch['split_size'])
+        # pick, for each instance, the answer choice with the lowest mean token loss
+        preds = torch.stack([torch.argmin(l) for l in group_loss], dim=0)
+
+        preds = nested_numpify(preds)
+        return preds.tolist()
diff --git a/src/mytrainer_lora.py 
b/src/mytrainer_lora.py new file mode 100644 index 0000000..606b13d --- /dev/null +++ b/src/mytrainer_lora.py @@ -0,0 +1,26 @@ +import tqdm + +import torch +from torch.nn import CrossEntropyLoss +from transformers.trainer_pt_utils import nested_numpify, nested_concat +from src.inplace_zero_trainer_lora import InplaceZeroTrainer + +IGNORE_INDEX = -100 + + +class MyInplaceZeroTrainer(InplaceZeroTrainer): + def eval_step(self, batch): + outs = self.model(batch['input_ids'].cuda(), batch['attention_mask'].cuda()) + # Shift so that tokens < n predict n + shift_logits = outs.logits[..., :-1, :].contiguous() + shift_labels = batch['labels'][..., 1:].cuda().contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction='none') + loss = loss_fct(shift_logits.view(shift_labels.shape[0] * shift_labels.shape[1], -1), + shift_labels.view(-1)).view_as(shift_labels) + loss = loss.mean(dim=1) + group_loss = loss.split(batch['split_size']) + preds = torch.stack([torch.argmin(l) for l in group_loss], dim=0) + + preds = nested_numpify(preds) + return preds.tolist() diff --git a/src/prompts.py b/src/prompts.py new file mode 100644 index 0000000..d6214a9 --- /dev/null +++ b/src/prompts.py @@ -0,0 +1,88 @@ +from dataclasses import dataclass + +import random + + +def idx_to_ltr(idx): + return chr(idx + ord("A")) + + +@dataclass +class QuestionPart: + text: str + tag: str = None + + def __str__(self): + if self.tag is not None: + return f"{self.tag}: {self.text}" + else: + return self.text + + +@dataclass +class Question: + parts: list + choices: list + answer_idx: int + task: str = None + + def get_n_choices(self): + return len(self.choices) + + def get_answer_str(self): + return self.choices[self.answer_idx] + + def _get_prompt(self, include_choices): + prompt = "" + for part in self.parts: + prompt += f"{str(part)}\n" + if include_choices: + for i, choice in enumerate(self.choices): + prompt += f"{idx_to_ltr(i)}. 
{choice}\n" + return prompt + + def get_natural_prompt(self): + return self._get_prompt(include_choices=True) + + def get_brown_prompt(self): + return self._get_prompt(include_choices=False) + + def strong_shuffle(self): + # This method shuffles choices such that choosing + # the answer at the originally correct + # index will mean getting the question wrong + + # For degenerate questions where all choices are the same + if len(set(self.choices)) == 1: + return + + answer_idx = self.answer_idx + answer_str = self.get_answer_str() + while self.choices[answer_idx] == answer_str: + random.shuffle(self.choices) + self.answer_idx = self.choices.index(answer_str) + + def permute_choices(self, perm): + self.choices = [self.choices[i] for i in perm] + self.answer_idx = perm.index(self.answer_idx) + + +class Exemplar(Question): + + def get_natural_prompt(self): + prompt = super().get_brown_prompt().strip('\n') + # return f"{prompt} {self.get_answer_str()}" + return { + 'source': f"{prompt}", + 'target': f"{self.get_answer_str()}", + 'choices': self.choices + } + + def get_brown_prompt(self): + prompt = super().get_brown_prompt() + # return f"{prompt} {self.get_answer_str()}" + return { + 'source': f"{prompt}Answer: ", + 'target': f"{self.get_answer_str()}", + 'choices': self.choices + } diff --git a/src/train_zero.py b/src/train_zero.py new file mode 100644 index 0000000..b7579d6 --- /dev/null +++ b/src/train_zero.py @@ -0,0 +1,127 @@ +import copy +import os +import sys + +from random import sample + +import torch +from torch.utils.data import Subset +from transformers import HfArgumentParser +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig +from transformers import set_seed +from dataclasses import asdict +from transformers.deepspeed import HfDeepSpeedConfig +import wandb +# os.environ['WANDB_MODE'] = 'debug' + +python_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +print("PYTHON_PATH", python_path) +sys.path.append(python_path) +from log import print +from arguments import ModelArguments, DataArguments, MyTrainingArguments +from mydatasets import MyDataset, get_dataset_info +from mytrainer import MyInplaceZeroTrainer +from utils import DataCollatorForCauselLM, EvalDataCollatorForCauselLM + + +def compute_metrics(all_pred, eval_dataset, eval_prefix=None): + golds = [ins['answer'] for ins in eval_dataset.data] + preds = all_pred[:len(golds)] + print(len(all_pred)) + print(all_pred[:8]) + print(all_pred[-8:]) + assert len(preds) == len(golds), f"# of predictions {len(preds)} doesn't match # of references {len(golds)}." + + acc = round(sum([int(pred == gold) for pred, gold in zip(preds, golds)]) / len(golds), 6) + result = {'acc': acc} + return result + + +def train(): + # ========== 1. 
logs and args ========== + torch.set_default_dtype(torch.float16) + parser = HfArgumentParser((ModelArguments, DataArguments, MyTrainingArguments)) + if sys.argv[-1].endswith(".yaml"): + model_args, data_args, training_args = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[-1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + set_seed(training_args.seed) + + model_name = model_args.model_name_or_path.split('/')[-1] + tag_name = '_'.join([data_args.dataset_name, model_name, training_args.tag] if training_args.tag else [data_args.dataset_name, model_name]) + hparam_name = 'output' + if training_args.optim != 'sgd': + hparam_name += '_' + training_args.optim + if training_args.learning_rate != 5e-4: + hparam_name += '_lr' + str(training_args.learning_rate) + if training_args.per_device_train_batch_size != 8: + hparam_name += '_bs' + str(training_args.per_device_train_batch_size) + if training_args.lr_scheduler_type != 'linear': + hparam_name += '_' + training_args.lr_scheduler_type + if training_args.warmup != 0: + hparam_name += '_warmup' + str(training_args.warmup) + if training_args.clip_grad_norm and training_args.clip_grad_norm > 0: + hparam_name += '_clipnorm' + str(training_args.clip_grad_norm) + if training_args.clip_grad_value and training_args.clip_grad_value > 0: + hparam_name += '_clipgrad' + str(training_args.clip_grad_value) + if training_args.clip_loss_value and training_args.clip_loss_value > 0: + hparam_name += '_cliploss' + str(training_args.clip_loss_value) + # assert training_args.clip_grad_value is None or training_args.clip_loss_value is None + training_args.output_dir = os.path.join('outputs', tag_name, hparam_name) + + if training_args.tag == 'debug': + os.environ['WANDB_MODE'] = 'offline' + if training_args.local_rank in [-1, 0]: + wandb_config = copy.deepcopy(asdict(training_args)) + wandb_config.update(asdict(model_args)) + wandb_config.update(asdict(data_args)) + wandb.init( + project="collie", + entity='collie_exp', + name=tag_name if hparam_name == 'output' else '_'.join([tag_name, hparam_name.replace('output_', '')]), + config=wandb_config + ) + + # ========== 2. Load pretrained model and tokenizer. ========== + ds_config = training_args.deepspeed + dschf = HfDeepSpeedConfig(ds_config) + config = AutoConfig.from_pretrained(model_args.model_name_or_path) + config.gradient_checkpointing = training_args.gradient_checkpointing + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + local_files_only=True, + config=config, + ) + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=False, + padding_side='left' + ) + tokenizer.pad_token_id = 0 + + # ========== 3. Preprocessing the datasets. ========== + dataset_info = get_dataset_info(data_args.dataset_name) + train_dataset = MyDataset(data_args, tokenizer, dataset_info, split=dataset_info.exemplar_split) + eval_dataset = MyDataset(data_args, tokenizer, dataset_info, split=dataset_info.eval_split) + + # ========== 4. Initialize our Trainer. 
==========
+    trainer = MyInplaceZeroTrainer(
+        model=model,
+        training_args=training_args,
+        data_collator={'train': DataCollatorForCauselLM(tokenizer, max_length=data_args.data_max_length, padding_side='left'),
+                       'eval': EvalDataCollatorForCauselLM(tokenizer, max_length=data_args.data_max_length, padding_side='left')},
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        tokenizer=tokenizer,
+        compute_metrics=compute_metrics,
+    )
+    if training_args.do_train:
+        trainer.train()
+    else:
+        trainer.eval(trainer.global_step, 0, trainer.eval_dataset, trainer.eval_dataloader, 'zero-shot')
+
+
+# run with $torchrun --nproc_per_node 2 src/train_zero.py config/hf_args_zero.yaml
+if __name__ == "__main__":
+    train()
diff --git a/src/train_zero_lora.py b/src/train_zero_lora.py
new file mode 100644
index 0000000..130cc3d
--- /dev/null
+++ b/src/train_zero_lora.py
@@ -0,0 +1,172 @@
+import copy
+import os
+import sys
+
+from random import sample
+
+import torch
+from torch.utils.data import Subset
+from transformers import HfArgumentParser
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+from transformers import set_seed
+from dataclasses import asdict
+from transformers.deepspeed import HfDeepSpeedConfig
+from peft import get_peft_model, TaskType, LoraConfig
+import wandb
+# os.environ['WANDB_MODE'] = 'debug'
+
+python_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+print("PYTHON_PATH", python_path)
+sys.path.append(python_path)
+from log import print
+from arguments import ModelArguments, DataArguments, MyTrainingArguments
+from mydatasets import MyDataset, get_dataset_info
+from mytrainer_lora import MyInplaceZeroTrainer
+from utils import DataCollatorForCauselLM, EvalDataCollatorForCauselLM
+
+
+def compute_metrics(all_pred, eval_dataset, eval_prefix=None):
+    golds = [ins['answer'] for ins in eval_dataset.data]
+    preds = all_pred[:len(golds)]  # TODO: verify that preds and golds stay aligned in the same order
+    # assert len(preds) == len(golds), f"# of predictions {len(preds)} doesn't match # of references {len(golds)}."
+
+    acc = round(sum([int(pred == gold) for pred, gold in zip(preds, golds)]) / len(golds), 6)
+    result = {'acc': acc}
+    return result
+
+
+def train():
+    # ========== 1. 
logs and args ========== + torch.set_default_dtype(torch.bfloat16) + parser = HfArgumentParser((ModelArguments, DataArguments, MyTrainingArguments)) + if sys.argv[-1].endswith(".yaml"): + model_args, data_args, training_args = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[-1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + set_seed(training_args.seed) + + model_name = model_args.model_name_or_path.split('/')[-1] + tag_name = '_'.join([data_args.dataset_name, model_name, training_args.tag] if training_args.tag else [data_args.dataset_name, model_name]) + hparam_name = 'output' + if training_args.optim != 'sgd': + hparam_name += '_' + training_args.optim + if training_args.learning_rate != 5e-4: + hparam_name += '_lr' + str(training_args.learning_rate) + if training_args.per_device_train_batch_size != 8: + hparam_name += '_bs' + str(training_args.per_device_train_batch_size) + if training_args.lr_scheduler_type != 'linear': + hparam_name += '_' + training_args.lr_scheduler_type + if training_args.warmup != 0: + hparam_name += '_warmup' + str(training_args.warmup) + if training_args.clip_grad_norm and training_args.clip_grad_norm > 0: + hparam_name += '_clipnorm' + str(training_args.clip_grad_norm) + if training_args.clip_grad_value and training_args.clip_grad_value > 0: + hparam_name += '_clipgrad' + str(training_args.clip_grad_value) + if training_args.clip_loss_value and training_args.clip_loss_value > 0: + hparam_name += '_cliploss' + str(training_args.clip_loss_value) + # assert training_args.clip_grad_value is None or training_args.clip_loss_value is None + training_args.output_dir = os.path.join('outputs', tag_name, hparam_name) + + if training_args.tag == 'debug': + os.environ['WANDB_MODE'] = 'offline' + if training_args.local_rank in [-1, 0]: + wandb_config = copy.deepcopy(asdict(training_args)) + wandb_config.update(asdict(model_args)) + wandb_config.update(asdict(data_args)) + wandb.init( + project="collie", + entity='collie_exp', + name=tag_name if hparam_name == 'output' else '_'.join([tag_name, hparam_name.replace('output_', '')]), + config=wandb_config + ) + + # ========== 2. Load pretrained model and tokenizer. 
==========
+    ds_config = training_args.deepspeed
+    dschf = HfDeepSpeedConfig(ds_config)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+    config.gradient_checkpointing = training_args.gradient_checkpointing
+    model = AutoModelForCausalLM.from_pretrained(
+        model_args.model_name_or_path,
+        local_files_only=True,
+        config=config,
+    )
+
+    peft_params = []
+    non_peft_names = []
+    non_peft_params = []
+    for name, param in model.named_parameters():
+        if param.requires_grad is False:
+            continue
+        non_peft_names.append(name)
+        non_peft_params.append(param)
+
+    # use peft
+    if training_args.peft_type is not None:
+        print(f'Using peft.{training_args.peft_type}')
+        if training_args.peft_type == 'lora':
+            peft_config = LoraConfig(
+                r=training_args.lora_r,
+                lora_alpha=training_args.lora_alpha,
+                target_modules=["q_proj", "v_proj"],
+                # target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
+                lora_dropout=training_args.lora_dropout,
+                bias="none",
+                task_type=TaskType.CAUSAL_LM
+            )
+            model.enable_input_require_grads()
+        else:
+            raise ValueError(f"Unknown PEFT type: {training_args.peft_type}")
+        model = get_peft_model(model, peft_config)
+        model.print_trainable_parameters()
+
+        # unfreeze the base model
+        # parameter names after peft wrapping: base_model.model.model.layers.23.self_attn.v_proj.weight
+        # parameter names before wrapping:     model.layers.23.self_attn.v_proj.weight
+        for name, param in model.named_parameters():
+            if name.split('base_model.model.')[1] in non_peft_names:
+                if not training_args.lora_only:
+                    param.requires_grad = True
+            if "lora_" in name:
+                peft_params.append(param)
+
+    torch.cuda.empty_cache()
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path,
+        use_fast=False,
+        padding_side='left'
+    )
+    tokenizer.pad_token_id = 0
+
+    # ========== 3. Preprocessing the datasets. ==========
+    dataset_info = get_dataset_info(data_args.dataset_name)
+    train_dataset = MyDataset(data_args, tokenizer, dataset_info, split=dataset_info.exemplar_split)
+    # if data_args.few_shot_size != -1:
+    #     # few_shot_indices = sample(range(len(train_dataset)), data_args.few_shot_size)
+    #     train_dataset = Subset(train_dataset, range(data_args.few_shot_size))
+    eval_dataset = MyDataset(data_args, tokenizer, dataset_info, split=dataset_info.eval_split)
+    if dataset_info.test_split:
+        test_dataset = MyDataset(data_args, tokenizer, dataset_info, split=dataset_info.test_split)
+        eval_dataset = {
+            # 'validation': eval_dataset,
+            'test': test_dataset
+        }
+
+    # ========== 4. Initialize our Trainer. ==========
+    trainer = MyInplaceZeroTrainer(
+        model=model,
+        training_args=training_args,
+        data_collator={'train': DataCollatorForCauselLM(tokenizer, max_length=data_args.data_max_length, padding_side='left'),
+                       'eval': EvalDataCollatorForCauselLM(tokenizer, max_length=data_args.data_max_length, padding_side='left')},
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        tokenizer=tokenizer,
+        compute_metrics=compute_metrics,
+        optimizers={'model_parameters': peft_params},
+    )
+    trainer.train()
+
+
+# run with $torchrun --nproc_per_node 2 src/train_zero_lora.py config/hf_args_zero_lora.yaml
+if __name__ == "__main__":
+    train()
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000..3f4715f
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,364 @@
+import copy
+from dataclasses import dataclass
+
+import numpy as np
+from transformers.utils import PaddingStrategy
+from transformers.trainer import *
+import wandb
+
+
+@dataclass
+class DataCollatorForCauselLM:
+    """
+    Data collator that will dynamically pad the inputs received, as well as the labels.
+    Args:
+        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
+            The tokenizer used for encoding the data.
+        model ([`PreTrainedModel`]):
+            The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
+            prepare the *decoder_input_ids*
+            This is useful when using *label_smoothing* to avoid calculating loss twice.
+        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+            among:
+            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
+              sequence is provided).
+            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+              acceptable input length for the model if that argument is not provided.
+            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
+        max_length (`int`, *optional*):
+            Maximum length of the returned list and optionally padding length (see above).
+        pad_to_multiple_of (`int`, *optional*):
+            If set will pad the sequence to a multiple of the provided value.
+            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+            7.5 (Volta).
+        label_pad_token_id (`int`, *optional*, defaults to -100):
+            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
+        return_tensors (`str`):
+            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
+    """
+
+    tokenizer: Any
+    model: Optional[Any] = None
+    padding: Union[bool, str, PaddingStrategy] = True
+    max_length: Optional[int] = None
+    pad_to_multiple_of: Optional[int] = None
+    label_pad_token_id: int = -100
+    return_tensors: str = "pt"
+    padding_side: str = 'right'
+
+    def __call__(self, features, return_tensors=None):
+        padding_side = self.padding_side
+
+        # if return_tensors is None:
+        #     return_tensors = self.return_tensors
+        labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None
+        # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
+        # same length to return tensors.
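+        # Worked example (hypothetical values), with label_pad_token_id=-100 and
+        # padding_side='left': labels [[5, 6], [7]] become [[5, 6], [-100, 7]],
+        # so every label row reaches max_label_length before tensor conversion.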
+ if labels is not None: + max_label_length = max(len(l) for l in labels) + if self.pad_to_multiple_of is not None: + max_label_length = ( + (max_label_length + self.pad_to_multiple_of - 1) + // self.pad_to_multiple_of + * self.pad_to_multiple_of + ) + + for feature in features: + remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"])) + if isinstance(feature["labels"], list): + feature["labels"] = ( + feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"] + ) + elif padding_side == "right": + feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64) + else: + feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64) + + max_length = max(len(feature['input_ids']) for feature in features) + if padding_side == 'right': + input_ids = [feature['input_ids'] + [self.tokenizer.pad_token_id] * (max_length - len(feature['input_ids'])) + for feature in features] + attention_mask = [[1] * len(feature['input_ids']) + [0] * (max_length - len(feature['input_ids'])) for + feature in features] + elif padding_side == 'left': + input_ids = [[self.tokenizer.pad_token_id] * (max_length - len(feature['input_ids'])) + feature['input_ids'] + for feature in features] + attention_mask = [[0] * (max_length - len(feature['input_ids'])) + [1] * len(feature['input_ids']) for + feature in features] + else: + raise ValueError("Invalid padding strategy:" + str(padding_side)) + + features = { + 'input_ids': torch.tensor(input_ids).long(), + 'attention_mask': torch.tensor(attention_mask).long(), + 'labels': torch.tensor(np.array([feature['labels'] for feature in features])).long() + } + return features + + +@dataclass +class EvalDataCollatorForCauselLM: + """ + Data collator that will dynamically pad the inputs received, as well as the labels. + Args: + tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]): + The tokenizer used for encoding the data. + model ([`PreTrainedModel`]): + The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to + prepare the *decoder_input_ids* + This is useful when using *label_smoothing* to avoid calculating loss twice. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding index) + among: + - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single + sequence is provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + label_pad_token_id (`int`, *optional*, defaults to -100): + The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). + return_tensors (`str`): + The type of Tensor to return. Allowable values are "np", "pt" and "tf". 
+ """ + + tokenizer: Any + model: Optional[Any] = None + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + label_pad_token_id: int = -100 + return_tensors: str = "pt" + padding_side: str = 'left' + unconditional_normalization: bool = False + + def __call__(self, features, return_tensors=None): + padding_side = self.padding_side + + split_size = [] + new_features = [] + assert "labels" in features[0].keys() + for feature in features: + split_size.append(len(feature["labels"])) + for op_input_ids, op_labels in zip(feature["input_ids"], feature["labels"]): + un_mask = np.zeros_like(op_labels) + un_mask_index = np.where(op_labels == self.label_pad_token_id, 1, 0).sum() - 2 + un_mask[:un_mask_index] = 1 + new_features.append({"input_ids": op_input_ids, "labels": op_labels, "un_mask": un_mask}) + + labels = [feature["labels"] for feature in new_features] + # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the + # same length to return tensors. + if labels is not None: + max_label_length = max(len(l) for l in labels) + if self.pad_to_multiple_of is not None: + max_label_length = ( + (max_label_length + self.pad_to_multiple_of - 1) + // self.pad_to_multiple_of + * self.pad_to_multiple_of + ) + + for feature in new_features: + remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"])) + if isinstance(feature["labels"], list): + feature["labels"] = ( + feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"] + ) + elif padding_side == "right": + feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64) + feature["un_mask"] = np.concatenate([feature["un_mask"], np.ones_like(remainder)]).astype(np.int64) + else: + feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64) + feature["un_mask"] = np.concatenate([np.ones_like(remainder), feature["un_mask"]]).astype(np.int64) + + max_length = max(len(feature['input_ids']) for feature in new_features) + if padding_side == 'right': + input_ids = [feature['input_ids'] + [self.tokenizer.pad_token_id] * (max_length - len(feature['input_ids'])) + for feature in new_features] + attention_mask = [[1] * len(feature['input_ids']) + [0] * (max_length - len(feature['input_ids'])) for + feature in new_features] + elif padding_side == 'left': + input_ids = [[self.tokenizer.pad_token_id] * (max_length - len(feature['input_ids'])) + feature['input_ids'] + for feature in new_features] + attention_mask = [[0] * (max_length - len(feature['input_ids'])) + [1] * len(feature['input_ids']) for + feature in new_features] + else: + raise ValueError("Invalid padding strategy:" + str(padding_side)) + + batched_features = { + 'input_ids': torch.tensor(input_ids).long(), + 'attention_mask': torch.tensor(attention_mask).long(), + 'labels': torch.tensor(np.array([feature['labels'] for feature in new_features])).long(), + 'split_size': split_size + } + if self.unconditional_normalization: + batched_features['un_mask'] = torch.tensor(np.array([feature['un_mask'] for feature in new_features])).bool() + + return batched_features + + +class LearningRateScheduler: + r""" + Learning rate scheduler with warmup. + + :param warmup: if ``warmup`` is an integer, ``warmup`` stands for warmup steps, if ``warmup`` is a float, + such as 0.1, then it stands for warmup_ratio. 
+    :param schedule: the learning rate will be adjusted according to ``schedule`` strategy,
+        which can be: linear or constant.
+    """
+
+    def __init__(self,
+                 warmup: float,
+                 schedule: str,
+                 learning_rate: float,
+                 n_steps: int = 0):
+
+        self.warmup = max(warmup, 0.)
+        self.schedule = schedule
+        self.initial_lr = learning_rate
+
+        if self.warmup > 1:
+            # an integer warmup is interpreted as warmup steps and converted to a ratio
+            self.warmup = self.warmup / n_steps
+        self.t_steps = max(2, n_steps)
+
+        if self.schedule == 'constant':
+            self.get_lr = self._get_constant_lr
+        elif self.schedule == 'linear':
+            self.get_lr = self._get_linear_lr
+        else:
+            raise NotImplementedError("Only support 'linear', 'constant'.")
+
+    def _get_constant_lr(self, progress):
+        if progress < self.warmup:
+            return progress / self.warmup
+        return 1
+
+    def _get_linear_lr(self, progress):
+        if progress < self.warmup:
+            return progress / self.warmup
+        return max((progress - 1.) / (self.warmup - 1.), 0.)
+
+    def step(self, global_step):
+        progress = global_step / self.t_steps
+        return self.initial_lr * self.get_lr(progress)
+
+
+class WandbLogger:
+    """
+    A thin wrapper that forwards logging calls to wandb only when wandb reporting is enabled.
+
+    :param training_args: the training arguments passed to the Trainer
+    """
+
+    def __init__(self, training_args):
+        self.training_args = training_args
+        # report_to is a list
+        self.able = "wandb" in getattr(training_args, "report_to", [])
+        if self.able and 'wandb' not in sys.modules:
+            raise ModuleNotFoundError(
+                "Detected Wandb not installed while you have set "
+                "`report_to=['wandb']` in your training config. Please "
+                "either set `report_to` to another value or install wandb.")
+
+    def init(self, *args, **kwargs):
+        if self.able:
+            wandb.init(*args, **kwargs)
+
+    def log(self, *args, **kwargs):
+        if self.able:
+            wandb.log(*args, **kwargs)
+
+    def set_summary(self, key, value):
+        if self.able:
+            wandb.run.summary[key] = value
+
+
+class DynamicLossScaler:
+    def __init__(self,
+                 init_scale=2 ** 32,
+                 scale_factor=2.,
+                 scale_window=1000,
+                 min_scale=1,
+                 delayed_shift=1,
+                 consecutive_hysteresis=False,
+                 raise_error_at_min_scale=True,
+                 dtype=torch.half):
+        self.cur_scale = init_scale
+        self.cur_iter = 0
+        self.last_overflow_iter = -1
+        self.scale_factor = scale_factor
+        self.scale_window = scale_window
+        self.min_scale = min_scale
+        self.delayed_shift = delayed_shift
+        self.cur_hysteresis = delayed_shift
+        self.consecutive_hysteresis = consecutive_hysteresis
+        self.raise_error_at_min_scale = raise_error_at_min_scale
+        self.dtype = dtype
+        self.has_overflow_serial = False
+
+    @property
+    def loss_scale(self):
+        return self.cur_scale
+
+    # `x` is a torch.Tensor
+    def _has_inf_or_nan(self, x):
+        try:
+            # if x is half, the .float() incurs an additional deep copy, but it's necessary if
+            # Pytorch's .sum() creates a one-element tensor of the same type as x
+            # (which is true for some recent version of pytorch).
+            cpu_sum = float(x.float().sum())
+            # More efficient version that can be used if .sum() returns a Python scalar
+            # cpu_sum = float(x.sum())
+        except RuntimeError as instance:
+            # We want to check if inst is actually an overflow exception.
+            # RuntimeError could come from a different error.
+            # If so, we still want the exception to propagate.
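+            # (the "value cannot be converted" RuntimeError is the overflow case we
+            # expect here; it is treated as inf/nan below, anything else is re-raised)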
+ if "value cannot be converted" not in instance.args[0]: + raise + return True + else: + if cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum: + return True + return False + + # `overflow` is boolean indicating whether the gradient overflowed + def update_scale(self, overflow): + if overflow: + # self.cur_scale /= self.scale_factor + if self.delayed_shift == 1 or self.cur_hysteresis == 1: + if (self.cur_scale == self.min_scale) and self.raise_error_at_min_scale: + raise Exception( + "Current loss scale already at minimum - cannot decrease scale anymore. Exiting run.") + else: + next_scale = max(self.cur_scale / self.scale_factor, self.min_scale) + if torch.distributed.get_rank() == 0: + overflow_msg = f"[deepspeed] OVERFLOW! Rank {torch.distributed.get_rank()} Skipping step." + if self.dtype == torch.half: + overflow_msg += f" Attempted loss scale: {int(self.cur_scale)}, reducing to {int(next_scale)}" + print(overflow_msg) + self.cur_scale = next_scale + else: + if torch.distributed.get_rank() == 0: + overflow_msg = f"[deepspeed] OVERFLOW! Rank {torch.distributed.get_rank()} Skipping step." + if self.dtype == torch.half: + overflow_msg += f" Attempted loss scale: {int(self.cur_scale)}, but hysteresis is {self.cur_hysteresis}. Reducing hysteresis to {self.cur_hysteresis - 1}" + print(overflow_msg) + self.cur_hysteresis -= 1 + self.last_overflow_iter = self.cur_iter + else: + if self.consecutive_hysteresis: + if torch.distributed.get_rank() == 0: + hysteresis_msg = f"Consecutive hysteresis is enabled. Restoring hysteresis to {self.delayed_shift}" + print(hysteresis_msg) + self.cur_hysteresis = self.delayed_shift + if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: + if not self.consecutive_hysteresis: + self.cur_hysteresis = self.delayed_shift + self.cur_scale *= self.scale_factor + self.cur_iter += 1