Skip to content

Commit

Permalink
upgrading to nbdev2; WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
ohmeow committed Aug 27, 2022
1 parent 8b73691 commit c3d8ca7
Show file tree
Hide file tree
Showing 55 changed files with 9,023 additions and 9,637 deletions.
254 changes: 67 additions & 187 deletions README.md

Large diffs are not rendered by default.

720 changes: 483 additions & 237 deletions blurr/_modidx.py

Large diffs are not rendered by default.

25 changes: 17 additions & 8 deletions blurr/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_callbacks.ipynb.
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00b_callbacks.ipynb.

# %% auto 0
__all__ = ['CheckpointingNotSupported', 'GradientCheckpointing']

# %% ../nbs/00_callbacks.ipynb 5
# %% ../nbs/00b_callbacks.ipynb 4
import os

import importlib, sys, torch
from typing import Any, Callable, Dict, List, Optional, Union, Type

Expand All @@ -14,31 +16,38 @@
from fastai.torch_core import *
from transformers import PreTrainedModel

# %% ../nbs/00_callbacks.ipynb 9

# %% ../nbs/00b_callbacks.ipynb 6
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# %% ../nbs/00b_callbacks.ipynb 10
class CheckpointingNotSupported(Exception):
    """Raised when a model does not support gradient checkpointing."""

    def __init__(self, msg="Model does not support gradient checkpointing."):
        # Forward the (overridable) message so `str(exc)` and `exc.args` carry it.
        super().__init__(msg)

# %% ../nbs/00_callbacks.ipynb 10

# %% ../nbs/00b_callbacks.ipynb 11
class GradientCheckpointing(Callback):
    """A fastai callback to enable gradient checkpointing for compatible HuggingFace models.

    Checkpointing is switched on in `before_fit` and switched back off in
    `after_fit`, but only when this callback was the one that enabled it, so a
    model that already had checkpointing enabled is left as it was found.
    """

    def before_fit(self):
        """Enable gradient checkpointing on before_fit event.

        Raises:
            CheckpointingNotSupported: if the wrapped HuggingFace model reports
                that it does not support gradient checkpointing.
        """
        hf_model = self.model.hf_model

        # Check that huggingface model supports gradient checkpointing
        if not hf_model.supports_gradient_checkpointing:
            raise CheckpointingNotSupported()

        # Remember whether we flipped the switch so `after_fit` only undoes
        # a change made by this callback.
        self._enabled_by_callback = not hf_model.is_gradient_checkpointing
        if self._enabled_by_callback:
            hf_model.gradient_checkpointing_enable()

    def after_fit(self):
        """Disable gradient checkpointing on after_fit event if this callback enabled it."""
        hf_model = self.model.hf_model

        # `getattr` with a default covers `after_fit` running without a prior `before_fit`.
        if getattr(self, "_enabled_by_callback", False) and hf_model.is_gradient_checkpointing:
            hf_model.gradient_checkpointing_disable()
        self._enabled_by_callback = False

    @staticmethod
    def supported(model: PreTrainedModel):
        """Tests whether a HuggingFace `PreTrainedModel` supports gradient checkpointing."""
        return model.supports_gradient_checkpointing

12 changes: 8 additions & 4 deletions blurr/examples/text/causal_lm_gpt2.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/99e_text-examples-causal-lm-gpt2.ipynb.
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/99e_text-examples-causal-lm-gpt2.ipynb.

# %% auto 0
__all__ = []

# %% ../nbs/99e_text-examples-causal-lm-gpt2.ipynb 4
# %% ../../../nbs/99e_text-examples-causal-lm-gpt2.ipynb 5
import warnings
from transformers import *
from transformers.utils import logging as hf_logging
from fastai.text.all import *


from ...text.data.core import *
from ...text.data.language_modeling import *
from ...text.modeling.core import *
from ...text.modeling.language_modeling import *
from ...text.utils import *
from ...utils import *

logging.set_verbosity_error()
# %% ../../../nbs/99e_text-examples-causal-lm-gpt2.ipynb 7
# silence all the HF warnings
warnings.simplefilter("ignore")
hf_logging.set_verbosity_error()
12 changes: 8 additions & 4 deletions blurr/examples/text/glue.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/99b_text-examples-glue.ipynb.
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/99b_text-examples-glue.ipynb.

# %% auto 0
__all__ = []

# %% ../nbs/99b_text-examples-glue.ipynb 4
import torch
# %% ../../../nbs/99b_text-examples-glue.ipynb 5
import torch, warnings
from fastai.text.all import *

from datasets import load_dataset, concatenate_datasets
from transformers import *
from transformers.utils import logging as hf_logging

from ...text.data.core import *
from ...text.modeling.core import *
from ...text.utils import *
from ...utils import *

logging.set_verbosity_error()

# %% ../../../nbs/99b_text-examples-glue.ipynb 7
# silence all the HF warnings
warnings.simplefilter("ignore")
hf_logging.set_verbosity_error()
11 changes: 8 additions & 3 deletions blurr/examples/text/glue_low_level_api.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/99c_text-examples-glue-plain-pytorch.ipynb.
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/99c_text-examples-glue-plain-pytorch.ipynb.

# %% auto 0
__all__ = []

# %% ../nbs/99c_text-examples-glue-plain-pytorch.ipynb 4
# %% ../../../nbs/99c_text-examples-glue-plain-pytorch.ipynb 5
import warnings
from dataclasses import dataclass

import torch
from fastai.text.all import *
from transformers import *
from transformers.utils import logging as hf_logging
from datasets import load_dataset

from ...text.data.core import *
from ...text.modeling.core import *
from ...text.utils import *
from ...utils import *

logging.set_verbosity_error()
# %% ../../../nbs/99c_text-examples-glue-plain-pytorch.ipynb 7
# silence all the HF warnings
warnings.simplefilter("ignore")
hf_logging.set_verbosity_error()
12 changes: 8 additions & 4 deletions blurr/examples/text/high_level_api.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/99a_text-examples-high-level-api.ipynb.
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/99a_text-examples-high-level-api.ipynb.

# %% auto 0
__all__ = []

# %% ../nbs/99a_text-examples-high-level-api.ipynb 4
import os
# %% ../../../nbs/99a_text-examples-high-level-api.ipynb 5
import os, warnings

from datasets import load_dataset, concatenate_datasets
from transformers import *
from transformers.utils import logging as hf_logging
from fastai.text.all import *

from ...text.data.core import *
Expand All @@ -21,5 +22,8 @@
from ...text.utils import *
from ...utils import *

logging.set_verbosity_error()

# %% ../../../nbs/99a_text-examples-high-level-api.ipynb 7
# silence all the HF warnings
warnings.simplefilter("ignore")
hf_logging.set_verbosity_error()
12 changes: 8 additions & 4 deletions blurr/examples/text/multilabel_classification.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/99d_text-examples-multilabel.ipynb.
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/99d_text-examples-multilabel.ipynb.

# %% auto 0
__all__ = []

# %% ../nbs/99d_text-examples-multilabel.ipynb 4
import os
# %% ../../../nbs/99d_text-examples-multilabel.ipynb 5
import os, warnings

import datasets
from transformers import *
from transformers.utils import logging as hf_logging
from fastai.text.all import *
from fastai.callback.hook import _print_shapes

Expand All @@ -17,5 +18,8 @@
from ...text.utils import *
from ...utils import *

logging.set_verbosity_error()

# %% ../../../nbs/99d_text-examples-multilabel.ipynb 7
# silence all the HF warnings
warnings.simplefilter("ignore")
hf_logging.set_verbosity_error()
4 changes: 2 additions & 2 deletions blurr/text/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_text-callbacks.ipynb.
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/01_text-callbacks.ipynb.

# %% auto 0
__all__ = []

# %% ../nbs/01_text-callbacks.ipynb 3
# %% ../../nbs/01_text-callbacks.ipynb 3
import importlib, sys, torch
from typing import Any, Callable, Dict, List, Optional, Union, Type

Expand Down
40 changes: 22 additions & 18 deletions blurr/text/data/core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/11_text-data-core.ipynb.
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/11_text-data-core.ipynb.

# %% auto 0
__all__ = ['Preprocessor', 'ClassificationPreprocessor', 'TextInput', 'BatchTokenizeTransform', 'BatchDecodeTransform',
'blurr_sort_func', 'TextBlock', 'get_blurr_tfm', 'first_blurr_tfm', 'show_batch', 'TextBatchCreator',
'TextDataLoader', 'preproc_hf_dataset']

# %% ../nbs/11_text-data-core.ipynb 4
import os, inspect
# %% ../../../nbs/11_text-data-core.ipynb 4
import os, inspect, warnings
from dataclasses import dataclass
from functools import reduce, partial
from typing import Any, Callable, List, Optional, Union, Type
Expand All @@ -26,15 +26,19 @@
PretrainedConfig,
PreTrainedTokenizerBase,
PreTrainedModel,
logging,
)
from transformers.utils import logging as hf_logging

from ..utils import get_hf_objects

logging.set_verbosity_error()

# %% ../../../nbs/11_text-data-core.ipynb 6
# silence all the HF warnings
warnings.simplefilter("ignore")
hf_logging.set_verbosity_error()

# %% ../nbs/11_text-data-core.ipynb 13

# %% ../../../nbs/11_text-data-core.ipynb 15
class Preprocessor:
def __init__(
self,
Expand Down Expand Up @@ -98,7 +102,7 @@ def _tokenize_function(self, example):
return self.hf_tokenizer(txts, txt_pairs, **self.tok_kwargs)


# %% ../nbs/11_text-data-core.ipynb 15
# %% ../../../nbs/11_text-data-core.ipynb 17
class ClassificationPreprocessor(Preprocessor):
def __init__(
self,
Expand Down Expand Up @@ -190,14 +194,14 @@ def _process_df_batch(self, batch_df):
return batch_df


# %% ../nbs/11_text-data-core.ipynb 23
# %% ../../../nbs/11_text-data-core.ipynb 25
class TextInput(TensorBase):
    """The base representation of your inputs; used by the various fastai `show` methods"""

    pass


# %% ../nbs/11_text-data-core.ipynb 26
# %% ../../../nbs/11_text-data-core.ipynb 28
class BatchTokenizeTransform(Transform):
"""
Handles everything you need to assemble a mini-batch of inputs and targets, as well as
Expand Down Expand Up @@ -296,7 +300,7 @@ def encodes(self, samples, return_batch_encoding=False):
return updated_samples


# %% ../nbs/11_text-data-core.ipynb 29
# %% ../../../nbs/11_text-data-core.ipynb 31
class BatchDecodeTransform(Transform):
"""A class used to cast your inputs as `input_return_type` for fastai `show` methods"""

Expand Down Expand Up @@ -327,7 +331,7 @@ def decodes(self, items: dict):
return self.input_return_type(items["input_ids"])


# %% ../nbs/11_text-data-core.ipynb 32
# %% ../../../nbs/11_text-data-core.ipynb 34
def blurr_sort_func(
example,
# A Hugging Face tokenizer
Expand All @@ -343,7 +347,7 @@ def blurr_sort_func(
return len(txt) if is_split_into_words else len(hf_tokenizer.tokenize(txt, **tok_kwargs))


# %% ../nbs/11_text-data-core.ipynb 34
# %% ../../../nbs/11_text-data-core.ipynb 36
class TextBlock(TransformBlock):
"""The core `TransformBlock` to prepare your inputs for training in Blurr with fastai's `DataBlock` API"""

Expand Down Expand Up @@ -438,7 +442,7 @@ def __init__(
return super().__init__(dl_type=dl_type, dls_kwargs={"before_batch": batch_tokenize_tfm}, batch_tfms=batch_decode_tfm)


# %% ../nbs/11_text-data-core.ipynb 37
# %% ../../../nbs/11_text-data-core.ipynb 39
def get_blurr_tfm(
# A list of transforms (e.g., dls.after_batch, dls.before_batch, etc...)
tfms_list: Pipeline,
Expand All @@ -452,7 +456,7 @@ def get_blurr_tfm(
return next(filter(lambda el: issubclass(type(el), tfm_class), tfms_list), None)


# %% ../nbs/11_text-data-core.ipynb 39
# %% ../../../nbs/11_text-data-core.ipynb 41
def first_blurr_tfm(
# Your fast.ai `DataLoaders
dls: DataLoaders,
Expand All @@ -474,7 +478,7 @@ def first_blurr_tfm(
return found_tfm


# %% ../nbs/11_text-data-core.ipynb 42
# %% ../../../nbs/11_text-data-core.ipynb 44
@typedispatch
def show_batch(
# This typedispatched `show_batch` will be called for `TextInput` typed inputs
Expand Down Expand Up @@ -526,7 +530,7 @@ def show_batch(
return ctxs


# %% ../nbs/11_text-data-core.ipynb 73
# %% ../../../nbs/11_text-data-core.ipynb 77
@dataclass
class TextBatchCreator:
"""
Expand Down Expand Up @@ -561,7 +565,7 @@ def __call__(self, features):
return batch


# %% ../nbs/11_text-data-core.ipynb 75
# %% ../../../nbs/11_text-data-core.ipynb 79
@delegates()
class TextDataLoader(TfmdDL):
"""
Expand Down Expand Up @@ -646,7 +650,7 @@ def new(
return super().new(dataset, cls, **kwargs)


# %% ../nbs/11_text-data-core.ipynb 81
# %% ../../../nbs/11_text-data-core.ipynb 85
def preproc_hf_dataset(
# A standard PyTorch Dataset or fast.ai Datasets
dataset: Union[torch.utils.data.dataset.Dataset, Datasets],
Expand Down
Loading

0 comments on commit c3d8ca7

Please sign in to comment.