Skip to content

Commit

Permalink
commplying with isort
Browse files Browse the repository at this point in the history
  • Loading branch information
VictorSanh committed Jun 1, 2020
1 parent db2a3b2 commit 5c8e5b3
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 28 deletions.
10 changes: 5 additions & 5 deletions examples/movement-pruning/bertarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
as a standard :class:`~transformers.BertForSequenceClassification`.
"""

import argparse
import os
import shutil
import argparse

import torch

from emmental.modules import MagnitudeBinarizer, TopKBinarizer, ThresholdBinarizer
from emmental.modules import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer


def main(args):
Expand All @@ -40,13 +40,13 @@ def main(args):
for name, tensor in model.items():
if "embeddings" in name or "LayerNorm" in name or "pooler" in name:
pruned_model[name] = tensor
print(f"Pruned layer {name}")
print(f"Copied layer {name}")
elif "classifier" in name or "qa_output" in name:
pruned_model[name] = tensor
print(f"Pruned layer {name}")
print(f"Copied layer {name}")
elif "bias" in name:
pruned_model[name] = tensor
print(f"Pruned layer {name}")
print(f"Copied layer {name}")
else:
if pruning_method == "magnitude":
mask = MagnitudeBinarizer.apply(inputs=tensor, threshold=threshold)
Expand Down
4 changes: 2 additions & 2 deletions examples/movement-pruning/counts_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
Count remaining (non-zero) weights in the encoder (i.e. the transformer layers).
Sparsity and remaining weights levels are equivalent: sparsity % = 100 - remaining weights %.
"""
import os
import argparse
import os

import torch

from emmental.modules import TopKBinarizer, ThresholdBinarizer
from emmental.modules import ThresholdBinarizer, TopKBinarizer


def main(args):
Expand Down
8 changes: 3 additions & 5 deletions examples/movement-pruning/emmental/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from .modules import *

from .configuration_bert_masked import MaskedBertConfig

from .modeling_bert_masked import (
MaskedBertModel,
MaskedBertForMultipleChoice,
MaskedBertForQuestionAnswering,
MaskedBertForSequenceClassification,
MaskedBertForTokenClassification,
MaskedBertForMultipleChoice,
MaskedBertModel,
)
from .modules import *
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@

import logging

from transformers.configuration_utils import PretrainedConfig
from transformers.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from transformers.configuration_utils import PretrainedConfig


logger = logging.getLogger(__name__)

Expand Down
11 changes: 7 additions & 4 deletions examples/movement-pruning/emmental/modeling_bert_masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from emmental import MaskedBertConfig, MaskedLinear
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
from transformers.modeling_bert import (
ACT2FN,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BertLayerNorm,
load_tf_weights_in_bert,
)
from transformers.modeling_utils import PreTrainedModel, prune_linear_layer
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers.modeling_bert import load_tf_weights_in_bert, ACT2FN, BertLayerNorm

from emmental import MaskedLinear
from emmental import MaskedBertConfig

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion examples/movement-pruning/emmental/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .binarizer import ThresholdBinarizer, TopKBinarizer, MagnitudeBinarizer
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
from .masked_nn import MaskedLinear
6 changes: 3 additions & 3 deletions examples/movement-pruning/emmental/modules/masked_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
The pruned weight matrix is then multiplied against the inputs (and if necessary, the bias is added).
"""

import math

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import init

import math

from .binarizer import ThresholdBinarizer, TopKBinarizer, MagnitudeBinarizer
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer


class MaskedLinear(nn.Linear):
Expand Down
6 changes: 3 additions & 3 deletions examples/movement-pruning/masked_run_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from emmental import MaskedBertConfig, MaskedBertForSequenceClassification
from transformers import (
WEIGHTS_NAME,
AdamW,
Expand All @@ -43,7 +44,6 @@
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors

from emmental import MaskedBertConfig, MaskedBertForSequenceClassification

try:
from torch.utils.tensorboard import SummaryWriter
Expand Down
7 changes: 3 additions & 4 deletions examples/movement-pruning/masked_run_squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from emmental import MaskedBertConfig, MaskedBertForQuestionAnswering
from transformers import (
WEIGHTS_NAME,
AdamW,
Expand All @@ -48,8 +49,6 @@
from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor


from emmental import MaskedBertConfig, MaskedBertForQuestionAnswering

try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
Expand Down

0 comments on commit 5c8e5b3

Please sign in to comment.