Skip to content

Commit

Permalink
update conversion scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
thomwolf committed Sep 5, 2019
1 parent d77abd4 commit 121f88c
Show file tree
Hide file tree
Showing 9 changed files with 20 additions and 9 deletions.
11 changes: 11 additions & 0 deletions pytorch_transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
__version__ = "1.2.0"
# Work around to update TensorFlow's absl.logging threshold which alters the
# default Python logging output behavior when present.
# see: https://github.com/abseil/abseil-py/issues/99
# and: https://github.com/tensorflow/tensorflow/issues/26691#issuecomment-500369493
try:
import absl.logging
absl.logging.set_verbosity('info')
absl.logging.set_stderrthreshold('info')
absl.logging._warn_preinit_stderr = False
except:
pass

# Tokenizer
from .tokenization_utils import (PreTrainedTokenizer)
Expand Down
2 changes: 1 addition & 1 deletion pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import torch

from pytorch_transformers.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME,
from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME,
GPT2Config,
GPT2Model,
load_tf_weights_in_gpt2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import torch

from pytorch_transformers.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME,
from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME,
OpenAIGPTConfig,
OpenAIGPTModel,
load_tf_weights_in_openai_gpt)
Expand Down
2 changes: 1 addition & 1 deletion pytorch_transformers/convert_pytorch_checkpoint_to_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import torch
import numpy as np
import tensorflow as tf
from pytorch_transformers.modeling import BertModel
from pytorch_transformers import BertModel


def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str):
Expand Down
4 changes: 2 additions & 2 deletions pytorch_transformers/convert_roberta_checkpoint_to_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@

from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from fairseq.modules import TransformerSentenceEncoderLayer
from pytorch_transformers.modeling_bert import (BertConfig, BertEncoder,
from pytorch_transformers import (BertConfig, BertEncoder,
BertIntermediate, BertLayer,
BertModel, BertOutput,
BertSelfAttention,
BertSelfOutput)
from pytorch_transformers.modeling_roberta import (RobertaEmbeddings,
from pytorch_transformers import (RobertaEmbeddings,
RobertaForMaskedLM,
RobertaForSequenceClassification,
RobertaModel)
Expand Down
2 changes: 1 addition & 1 deletion pytorch_transformers/convert_tf_checkpoint_to_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import argparse
import torch

from pytorch_transformers.modeling_bert import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from pytorch_transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

import logging
logging.basicConfig(level=logging.INFO)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import pytorch_transformers.tokenization_transfo_xl as data_utils

from pytorch_transformers import CONFIG_NAME, WEIGHTS_NAME
from pytorch_transformers.modeling_transfo_xl import (TransfoXLConfig, TransfoXLLMHeadModel,
from pytorch_transformers import (TransfoXLConfig, TransfoXLLMHeadModel,
load_tf_weights_in_transfo_xl)
from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES)

Expand Down
2 changes: 1 addition & 1 deletion pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import torch
import numpy

from pytorch_transformers.modeling_utils import CONFIG_NAME, WEIGHTS_NAME
from pytorch_transformers import CONFIG_NAME, WEIGHTS_NAME
from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES

import logging
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import argparse
import torch

from pytorch_transformers.modeling_xlnet import (CONFIG_NAME, WEIGHTS_NAME,
from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME,
XLNetConfig,
XLNetLMHeadModel, XLNetForQuestionAnswering,
XLNetForSequenceClassification,
Expand Down

0 comments on commit 121f88c

Please sign in to comment.