
transformers-keras


Transformer-based models implemented in TensorFlow 2.x (Keras).

中文文档 (Chinese documentation) | English

Contents

  • Installation
  • Models
      • Transformer
      • BERT
      • ALBERT

Installation

pip install -U transformers-keras

Models

Transformer

Train a new transformer:

from transformers_keras import TransformerTextFileDatasetBuilder
from transformers_keras import TransformerDefaultTokenizer
from transformers_keras import TransformerRunner


# tokenizers for the source and target vocabularies
src_tokenizer = TransformerDefaultTokenizer(vocab_file='testdata/vocab_src.txt')
tgt_tokenizer = TransformerDefaultTokenizer(vocab_file='testdata/vocab_tgt.txt')
dataset_builder = TransformerTextFileDatasetBuilder(src_tokenizer, tgt_tokenizer)

model_config = {
    'num_encoder_layers': 2,
    'num_decoder_layers': 2,
    'src_vocab_size': src_tokenizer.vocab_size,
    'tgt_vocab_size': tgt_tokenizer.vocab_size,
}

runner = TransformerRunner(model_config, dataset_builder, model_dir='/tmp/transformer')

# each entry is a (source_file, target_file) pair
train_files = [('testdata/train.src.txt', 'testdata/train.tgt.txt')]
runner.train(train_files, epochs=10, callbacks=None)
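
The training files above are plain text. As a rough sketch (assumption: the dataset builder expects line-aligned parallel corpora, one sentence per line; check TransformerTextFileDatasetBuilder for the exact format), you could create toy files like this:

import os

os.makedirs('testdata', exist_ok=True)

# write a tiny line-aligned parallel corpus (toy data, for illustration only)
src_sentences = ['你 好 吗', '谢 谢 你']
tgt_sentences = ['how are you', 'thank you']

with open('testdata/train.src.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(src_sentences))
with open('testdata/train.tgt.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(tgt_sentences))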

BERT

You can use BERT models in two ways:

Train a new BERT model

Use your own data to pretrain a BERT model.

from transformers_keras import BertForPretrainingModel

model_config = {
    'max_positions': 128,   # maximum sequence length
    'num_layers': 6,        # number of transformer encoder layers
    'vocab_size': 21128,    # vocabulary size (here: the Chinese BERT vocab)
}

model = BertForPretrainingModel(**model_config)

Load a pretrained BERT model

from transformers_keras import BertForPretrainingModel

# download the pretrained model and extract it to some path
PRETRAINED_BERT_MODEL = '/path/to/chinese_L-12_H-768_A-12'

model = BertForPretrainingModel.from_pretrained(PRETRAINED_BERT_MODEL)
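
from_pretrained expects the path of the extracted checkpoint directory. For Google's official BERT releases such as chinese_L-12_H-768_A-12, the extracted directory typically looks like this:

chinese_L-12_H-768_A-12/
├── bert_config.json
├── bert_model.ckpt.data-00000-of-00001
├── bert_model.ckpt.index
├── bert_model.ckpt.meta
└── vocab.txt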

After building the model, you can train the model with your own data.

Here is an example:

import tensorflow as tf

from transformers_keras import BertTFRecordDatasetBuilder

builder = BertTFRecordDatasetBuilder(max_sequence_length=128, record_option='GZIP')

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy(name='acc')
model.compile(optimizer='adam', loss=loss, metrics=[metric])
# run a forward pass on dummy inputs to build the model, so summary() shows shapes
model(model.dummy_inputs())
model.summary()

train_files = ['testdata/bert_custom_pretrain.tfrecord']
train_dataset = builder.build_train_dataset(train_files, batch_size=32)
model.fit(train_dataset, epochs=2)
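
Since the model supports the standard Keras training loop, saving weights with the usual Keras call should work once training finishes (a minimal sketch; the output path is arbitrary):

# save the pretrained weights; any writable path works
model.save_weights('/tmp/bert/pretrained_weights')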

ALBERT

You can use ALBERT models in two ways:

Train a new ALBERT model

You should convert your data to the TFRecord format. Modify the script transformers_keras/utils/bert_tfrecord_custom_generator.py as you need; a minimal sketch of writing such records follows.
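
This sketch is not the project's exact format: it writes a GZIP-compressed TFRecord file with tf.train.Example. The feature keys (input_ids, segment_ids, input_mask) are assumptions modeled on standard BERT pretraining data; check the script above for the keys this project actually expects.

import tensorflow as tf

def int64_feature(values):
    # wrap a list of ints as a tf.train.Feature
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

# hypothetical feature keys, following the usual BERT pretraining layout
features = {
    'input_ids': int64_feature([101, 2769, 102]),
    'segment_ids': int64_feature([0, 0, 0]),
    'input_mask': int64_feature([1, 1, 1]),
}
example = tf.train.Example(features=tf.train.Features(feature=features))

options = tf.io.TFRecordOptions(compression_type='GZIP')
with tf.io.TFRecordWriter('testdata/bert_custom_pretrain.tfrecord', options) as writer:
    writer.write(example.SerializeToString())

Then build the model: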

from transformers_keras import AlbertForPretrainingModel
from transformers_keras import BertTFRecordDatasetBuilder

# ALBERT uses the same data format as BERT
dataset_builder = BertTFRecordDatasetBuilder(
    max_sequence_length=128, record_option='GZIP', train_repeat_count=100, eos_token='T')

model_config = {
    'max_positions': 128,
    'num_layers': 6,
    'num_groups': 1,             # layer groups for cross-layer parameter sharing
    'num_layers_each_group': 1,  # distinct layers inside each group
    'vocab_size': 21128,
}

model = AlbertForPretrainingModel(**model_config)
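
With num_groups=1 and num_layers_each_group=1, ALBERT's cross-layer parameter sharing means the 6 layers reuse a single set of transformer weights rather than 6 separate ones (assuming these keys map to ALBERT's groups and inner layers as in the original paper). A quick arithmetic check, in plain Python rather than the library's API:

num_layers = 6
num_groups = 1
num_layers_each_group = 1

# distinct transformer-layer parameter sets actually stored
unique_layer_params = num_groups * num_layers_each_group
print(unique_layer_params)  # 1: all 6 layers share one set of weights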

Load a pretrained ALBERT model

from transformers_keras import AlbertForPretrainingModel

# download the pretrained model and extract it to some path
PRETRAINED_ALBERT_MODEL = '/path/to/zh_albert_large'

model = AlbertForPretrainingModel.from_pretrained(PRETRAINED_ALBERT_MODEL)

After building the model, you can train this model with your own data.

Here is an example:

import tensorflow as tf

from transformers_keras import BertTFRecordDatasetBuilder

builder = BertTFRecordDatasetBuilder(max_sequence_length=128, record_option='GZIP')

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy(name='acc')
model.compile(optimizer='adam', loss=loss, metrics=[metric])
model(model.dummy_inputs())
model.summary()

train_files = ['testdata/bert_custom_pretrain.tfrecord']
train_dataset = builder.build_train_dataset(train_files, batch_size=32)
model.fit(train_dataset, epochs=2)
