forked from OpenLMLab/LOMO
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
22 changed files
with
2,916 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
.idea |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,11 @@ | ||
# LOMO | ||
LOMO: LOw-Memory Optimization | ||
# LOMO: LOw-Memory Optimization | ||
This is the implementation for [Full Parameter Fine-Tuning for Large Language Models with Limited Resorcces](). | ||
|
||
--- | ||
## Run the code | ||
```shell | ||
bash run.sh | ||
``` | ||
|
||
## Reproduce our results | ||
We provide the sampled datasets used in our experiments [here](https://drive.google.com/drive/folders/1zV7sXvU7YHKWyS3fYV0yyi7FyTjIpEuO?usp=sharing). Due to the limited computational resources, we reported the highest results obtained from experiments conducted with the same random seed (`42`). We acknolwedge this limitation in our work and plan to conduct repeated experiments in the next version to address it. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
{ | ||
|
||
"bf16": { | ||
"enabled": false | ||
}, | ||
"fp16": { | ||
"enabled": true | ||
}, | ||
"zero_allow_untested_optimizer": true, | ||
"zero_force_ds_cpu_optimizer": false, | ||
|
||
"zero_optimization": { | ||
"stage": 3, | ||
"overlap_comm": true, | ||
"contiguous_gradients": true, | ||
"sub_group_size": 1e8, | ||
"stage3_max_live_parameters": 1e8, | ||
"stage3_max_reuse_distance": 1e8, | ||
"stage3_gather_16bit_weights_on_model_save": true | ||
}, | ||
|
||
|
||
"gradient_accumulation_steps": 1, | ||
"steps_per_print": 2000, | ||
"train_micro_batch_size_per_gpu": 2, | ||
"wall_clock_breakdown": false | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
{ | ||
|
||
"bf16": { | ||
"enabled": true | ||
}, | ||
"fp16": { | ||
"enabled": false | ||
}, | ||
"zero_allow_untested_optimizer": true, | ||
"zero_force_ds_cpu_optimizer": false, | ||
|
||
"zero_optimization": { | ||
"stage": 3, | ||
"overlap_comm": true, | ||
"contiguous_gradients": true, | ||
"sub_group_size": 1e8, | ||
"stage3_max_live_parameters": 1e8, | ||
"stage3_max_reuse_distance": 1e8, | ||
"stage3_gather_16bit_weights_on_model_save": true | ||
}, | ||
|
||
|
||
"gradient_accumulation_steps": 1, | ||
"steps_per_print": 2000, | ||
"train_micro_batch_size_per_gpu": 2, | ||
"wall_clock_breakdown": false | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# model | ||
model_name_or_path: '/home/ubuntu/projects/tunelite/cache/llama-7b' | ||
# data | ||
dataset_name: 'wsc' | ||
refresh: false | ||
data_tag: 'base' | ||
train_on_inputs: false | ||
data_max_length: 1024 | ||
# training | ||
# trainer | ||
tag: 'inplace-sgd' | ||
output_dir: 'outputs' | ||
overwrite_output_dir: true | ||
deepspeed: 'config/ds_config.json' | ||
do_train: true | ||
do_eval: true | ||
evaluation_strategy: 'epoch' | ||
per_device_train_batch_size: 2 | ||
per_device_eval_batch_size: 2 | ||
learning_rate: 0.05 | ||
weight_decay: 0 | ||
num_train_epochs: 10 | ||
lr_scheduler_type: 'linear' | ||
warmup: 0.1 | ||
clip_grad_norm: 1.0 | ||
log_level: 'info' | ||
logging_steps: 1 | ||
save_strategy: 'no' | ||
save_total_limit: 0 | ||
seed: 42 | ||
#bf16: true | ||
remove_unused_columns: false | ||
load_best_model_at_end: false | ||
metric_for_best_model: 'acc' | ||
optim: 'sgd' | ||
group_by_length: false | ||
#report_to: 'wandb' | ||
dataloader_pin_memory: false | ||
gradient_checkpointing: true | ||
predict_with_generate: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# model | ||
model_name_or_path: '/home/ubuntu/projects/tunelite/cache/llama-7b' | ||
# data | ||
dataset_name: 'wsc' | ||
refresh: false | ||
data_tag: 'base' | ||
train_on_inputs: false | ||
data_max_length: 1024 | ||
# training | ||
# trainer | ||
peft_type: 'lora' | ||
lora_only: false | ||
hf_learning_rate: 0.0005 | ||
hf_weight_decay: 0 | ||
hf_lr_scheduler_type: 'linear' | ||
hf_warmup: 0.05 | ||
tag: 'lora-qv-r2-inplace-sgd' | ||
output_dir: 'outputs' | ||
overwrite_output_dir: true | ||
deepspeed: 'config/ds_config_lora.json' | ||
do_train: true | ||
do_eval: true | ||
evaluation_strategy: 'epoch' | ||
per_device_train_batch_size: 2 | ||
per_device_eval_batch_size: 2 | ||
learning_rate: 0.005 | ||
weight_decay: 0 | ||
num_train_epochs: 10 | ||
lr_scheduler_type: 'linear' | ||
warmup: 0.05 | ||
clip_grad_norm: 1.0 | ||
#clip_grad_value: 1.0 | ||
#clip_loss_value: 5.0 | ||
log_level: 'info' | ||
logging_steps: 1 | ||
save_strategy: 'no' | ||
save_total_limit: 0 | ||
seed: 42 | ||
#bf16: true | ||
remove_unused_columns: false | ||
load_best_model_at_end: false | ||
metric_for_best_model: 'acc' | ||
optim: 'sgd' | ||
group_by_length: false | ||
#report_to: 'wandb' | ||
dataloader_pin_memory: false | ||
gradient_checkpointing: true | ||
predict_with_generate: true | ||
lora_r: 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
__all__ = [ | ||
'logger', | ||
"print" | ||
] | ||
|
||
from .logger import logger | ||
from .print import print | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import logging | ||
import sys | ||
from logging import getLevelName | ||
|
||
try: | ||
from tqdm.auto import tqdm | ||
except ImportError: | ||
tqdm = None | ||
|
||
__all__ = [] | ||
|
||
if tqdm is not None: | ||
class TqdmLoggingHandler(logging.Handler): | ||
def __init__(self, level=logging.INFO): | ||
super().__init__(level) | ||
|
||
def emit(self, record): | ||
try: | ||
msg = self.format(record) | ||
tqdm.write(msg) | ||
self.flush() | ||
except (KeyboardInterrupt, SystemExit): | ||
raise | ||
except: | ||
self.handleError(record) | ||
else: | ||
class TqdmLoggingHandler(logging.StreamHandler): | ||
def __init__(self, level=logging.INFO): | ||
super().__init__(sys.stdout) | ||
self.setLevel(level) | ||
|
||
|
||
class StdoutStreamHandler(logging.StreamHandler): | ||
""" | ||
重载 StreamHandler 使得替换 sys.stdout 的时候能够生效。 | ||
""" | ||
def __init__(self): | ||
super(StdoutStreamHandler, self).__init__() | ||
|
||
def flush(self): | ||
""" | ||
Flushes the stream. | ||
""" | ||
self.acquire() | ||
try: | ||
sys.stdout.flush() | ||
finally: | ||
self.release() | ||
|
||
def emit(self, record): | ||
""" | ||
Emit a record. | ||
If a formatter is specified, it is used to format the record. | ||
The record is then written to the stream with a trailing newline. If | ||
exception information is present, it is formatted using | ||
traceback.print_exception and appended to the stream. If the stream | ||
has an 'encoding' attribute, it is used to determine how to do the | ||
output to the stream. | ||
""" | ||
try: | ||
msg = self.format(record) | ||
stream = sys.stdout | ||
# issue 35046: merged two stream.writes into one. | ||
stream.write(msg + self.terminator) | ||
self.flush() | ||
except RecursionError: # See issue 36272 | ||
raise | ||
except Exception: | ||
self.handleError(record) | ||
|
||
def setStream(self, stream): | ||
""" | ||
Sets the StreamHandler's stream to the specified value, | ||
if it is different. | ||
Returns the old stream, if the stream was changed, or None | ||
if it wasn't. | ||
""" | ||
raise RuntimeError("Cannot set the stream of FStreamHandler.") | ||
|
||
def __repr__(self): | ||
level = getLevelName(self.level) | ||
name = getattr(sys.stdout, 'name', '') | ||
# bpo-36015: name can be an int | ||
name = str(name) | ||
if name: | ||
name += ' ' | ||
return '<%s %s(%s)>' % (self.__class__.__name__, name, level) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from rich.highlighter import Highlighter | ||
|
||
__all__ = [] | ||
|
||
class ColorHighlighter(Highlighter): | ||
def __init__(self, color='black'): | ||
self.color = color | ||
|
||
def highlight(self, text): | ||
text.stylize(self.color) |
Oops, something went wrong.