PyTorch Frame is a tabular deep learning extension library for PyTorch. Modern data is often stored in tables with heterogeneous columns of different semantic types, e.g., numerical (age, price), categorical (gender, product type), time, text (descriptions), and images (pictures). The goal of PyTorch Frame is to provide a deep learning framework for effective machine learning on such complex data.
PyTorch Frame allows existing (and future) methods to be implemented easily and intuitively in a modular way. The library includes methods for deep learning on tables from a variety of published papers. In addition, it includes easy-to-use mini-batch loaders, a large number of common benchmark datasets, and intuitive interfaces for custom dataset integration.
With PyTorch Frame, we aim to democratize deep learning research for tabular data. Whether you're an experienced deep learning researcher, a novice delving into machine learning, or a Kaggle enthusiast, PyTorch Frame makes experimenting with different architectures a breeze.
Our aspirations for PyTorch Frame are twofold:
- To Propel Deep Learning Research for Tabular Data: Historically, tree-based models have had superior performance on tabular datasets. However, tree-based models have many limitations: for example, they cannot be trained jointly with downstream models like GNNs, RNNs, and Transformers, which makes them hard to integrate into larger systems. They also cannot handle diverse column types, like text or sequences. Recent research shows that some deep learning models achieve comparable, if not better, performance on larger datasets, not to mention their advantages in training efficiency at massive data scales.
- To Support Enhanced Semantic Types and Model Architectures: We aim to extend PyTorch Frame's functionality to handle a wider variety of semantic types, such as time sequences. Concurrently, we are extending PyTorch Frame to work with the latest technologies, such as large language models.
PyTorch Frame emphasizes a tensor-centric API and maintains design elements similar to vanilla PyTorch. For those acquainted with PyTorch, adapting to PyTorch Frame is a seamless process. Here are our major library highlights:
- Versatility with Tabular Data: PyTorch Frame provides in-house support for multimodal learning on a variety of semantic types, including categories, numbers, and text. Extensions to other types, including but not limited to sequences, multicategorical columns, images, and time, are on our roadmap.
- Modular Implementation of Diverse Models: We provide a framework for users to implement deep learning models in a modular way, enhancing module reusability, code clarity, and ease of experimentation. See the next section for more details.
- Empowerment through Multimodal Learning: PyTorch Frame can mesh with a variety of transformers on different semantic types, e.g., large language models on text, as illustrated in this example.
- PyTorch Integration: PyTorch Frame integrates seamlessly with other PyTorch libraries, like PyG, enabling end-to-end training of PyTorch Frame models together with any other PyTorch models.
Models in PyTorch Frame follow a modular design of FeatureEncoder, TableConv, and Decoder, as shown in the figure below:
In essence, this modular setup empowers users to effortlessly experiment with myriad architectures:
- Materialization handles converting the raw pandas DataFrame into a TensorFrame that is amenable to PyTorch-based training and modeling (see the sketch after this list).
- FeatureEncoder encodes different semantic types into hidden embeddings.
- TableConv handles column-wise interactions between different semantic types.
- Decoder summarizes the embeddings and generates the prediction outputs.
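For concreteness, here is a minimal sketch of the materialization step. The Dataset class and stype enum are part of PyTorch Frame, while the DataFrame columns and their semantic-type assignments below are purely illustrative:

import pandas as pd
from torch_frame import stype
from torch_frame.data import Dataset

# A hypothetical DataFrame with mixed semantic types.
df = pd.DataFrame({
    'age': [25, 40, 31],           # numerical
    'product': ['a', 'b', 'a'],    # categorical
    'label': [0, 1, 0],            # prediction target
})
# Declare the semantic type (stype) of each column, then materialize.
dataset = Dataset(
    df,
    col_to_stype={
        'age': stype.numerical,
        'product': stype.categorical,
        'label': stype.categorical,
    },
    target_col='label',
)
dataset.materialize()  # builds dataset.tensor_frame and dataset.col_stats
print(dataset.tensor_frame)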
In this quick tour, we showcase the ease of creating and training a deep tabular model with only a few lines of code.
In the first example, we implement a simple ExampleTransformer following the modular architecture of PyTorch Frame. A model maps a TensorFrame into embeddings. We decompose ExampleTransformer, like most other models in PyTorch Frame, into three modular components:
- self.encoder: The encoder maps an input Tensor of size [batch_size, num_cols] to an embedding of size [batch_size, num_cols, channels]. To handle inputs of different semantic types, we use StypeWiseFeatureEncoder, where users can specify an encoder per semantic type using a dictionary. In this example, we use EmbeddingEncoder for categorical features and LinearEncoder for numerical features; both are built-in encoders in PyTorch Frame. For a comprehensive list of encoders, check out this file.
- self.tab_transformer_convs: We create two layers of TabTransformerConv. Each TabTransformerConv module transforms an embedding of size [batch_size, num_cols, channels] into an embedding of the same size.
- self.decoder: We use a mean-based decoder that maps the embedding from size [batch_size, num_cols, channels] to [batch_size, out_channels].
from typing import Any, Dict, List
from torch import Tensor
from torch.nn import Linear, Module, ModuleList
import torch_frame
from torch_frame import TensorFrame, stype
from torch_frame.data.stats import StatType
from torch_frame.nn.conv import TabTransformerConv
from torch_frame.nn.encoder import (
EmbeddingEncoder,
LinearEncoder,
StypeWiseFeatureEncoder,
)
class ExampleTransformer(Module):
    def __init__(
        self,
        channels: int,
        out_channels: int,
        num_layers: int,
        num_heads: int,
        col_stats: Dict[str, Dict[StatType, Any]],
        col_names_dict: Dict[torch_frame.stype, List[str]],
    ):
        super().__init__()
        # Encode each semantic type with its own encoder:
        # [batch_size, num_cols] -> [batch_size, num_cols, channels]
        self.encoder = StypeWiseFeatureEncoder(
            out_channels=channels,
            col_stats=col_stats,
            col_names_dict=col_names_dict,
            stype_encoder_dict={
                stype.categorical: EmbeddingEncoder(),
                stype.numerical: LinearEncoder(),
            },
        )
        # Model column-wise interactions; the embedding shape is preserved.
        self.tab_transformer_convs = ModuleList([
            TabTransformerConv(
                channels=channels,
                num_heads=num_heads,
            ) for _ in range(num_layers)
        ])
        # Map the pooled embedding to the output dimension.
        self.decoder = Linear(channels, out_channels)

    def forward(self, tf: TensorFrame) -> Tensor:
        x, _ = self.encoder(tf)
        for tab_transformer_conv in self.tab_transformer_convs:
            x = tab_transformer_conv(x)
        # Mean-pool over the column dimension before decoding:
        # [batch_size, num_cols, channels] -> [batch_size, out_channels]
        out = self.decoder(x.mean(dim=1))
        return out
Once we have defined the model, we can load the Adult Census Income dataset and create a train DataLoader.
from torch_frame.data import DataLoader
from torch_frame.datasets import Yandex

dataset = Yandex(root='/tmp/adult', name='adult')
dataset.materialize()
train_dataset = dataset[:0.8]
train_loader = DataLoader(train_dataset.tensor_frame, batch_size=128,
                          shuffle=True)
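Each iteration of the loader yields a mini-batch as a TensorFrame. A quick way to sanity-check this (an illustrative snippet, not part of the original example):

# Fetch one mini-batch; the loader yields TensorFrame objects directly.
tf = next(iter(train_loader))
print(tf.num_rows)        # 128 rows per batch (except possibly the last)
print(tf.col_names_dict)  # column names grouped by semantic type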
We can now optimize the model in a training loop, similar to the standard PyTorch training procedure.
import torch
import torch.nn.functional as F
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ExampleTransformer(
    channels=32,
    out_channels=dataset.num_classes,
    num_layers=2,
    num_heads=8,
    col_stats=train_dataset.col_stats,
    col_names_dict=train_dataset.tensor_frame.col_names_dict,
).to(device)
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(50):
    for tf in tqdm(train_loader):
        tf = tf.to(device)  # move the mini-batch to the model's device
        pred = model(tf)
        loss = F.cross_entropy(pred, tf.y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # update the model parameters
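After training, evaluation follows the same pattern. Here is a minimal sketch, assuming the remaining 20% of the dataset serves as a held-out split; the split and the accuracy metric are our illustration, not part of the original snippet:

# Evaluate classification accuracy on the held-out split (illustrative).
test_dataset = dataset[0.8:]
test_loader = DataLoader(test_dataset.tensor_frame, batch_size=128)
model.eval()
correct = total = 0
with torch.no_grad():
    for tf in test_loader:
        tf = tf.to(device)
        pred = model(tf)
        correct += int((pred.argmax(dim=-1) == tf.y).sum())
        total += len(tf.y)
print(f'Accuracy: {correct / total:.4f}')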
We list currently supported deep tabular models:
- Trompt from Chen et al.: Trompt: Towards a Better Deep Neural Network for Tabular Data (ICML 2023) [Example]
- FTTransformer from Gorishniy et al.: Revisiting Deep Learning Models for Tabular Data (NeurIPS 2021) [Example]
- ResNet from Gorishniy et al.: Revisiting Deep Learning Models for Tabular Data (NeurIPS 2021) [Example]
- TabNet from Arık et al.: TabNet: Attentive Interpretable Tabular Learning (AAAI 2021) [Example]
- ExcelFormer from Chen et al.: ExcelFormer: A Neural Network Surpassing GBDTs on Tabular Data [Example]
- TabTransformer from Huang et al.: TabTransformer: Tabular Data Modeling Using Contextual Embeddings [Example]
In addition, we implemented XGBoost and CatBoost examples with hyperparameter tuning using Optuna for users who'd like to compare their model performance with GBDTs.
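The call pattern below is a sketch of how one of these GBDT wrappers can be invoked on a materialized dataset; it follows our reading of the torch_frame.gbdt interface, and the split fractions and trial count are illustrative, so consult the linked examples for the exact API:

from torch_frame.gbdt import XGBoost

# Hypothetical train/val/test fractions for illustration.
train_dataset = dataset[:0.7]
val_dataset = dataset[0.7:0.8]
test_dataset = dataset[0.8:]

gbdt = XGBoost(task_type=dataset.task_type)
# Tune hyperparameters with Optuna on the validation split.
gbdt.tune(tf_train=train_dataset.tensor_frame,
          tf_val=val_dataset.tensor_frame, num_trials=20)
pred = gbdt.predict(tf_test=test_dataset.tensor_frame)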
PyTorch Frame is available for Python 3.8 to Python 3.11.
pip install pytorch_frame
See the installation guide for other options.