Skip to content

Commit

Permalink
Changed PhysioNet loss to BCE & added Random Fourier Networks to the …
Browse files Browse the repository at this point in the history
…function fitting experiment.
  • Loading branch information
dwromero committed Nov 13, 2021
1 parent 891bc32 commit 0d5d192
Show file tree
Hide file tree
Showing 22 changed files with 95 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[flake8]
ignore = E203, E266, E501, W503, F403, F401
max-line-length = 100
max-complexity = 22
max-complexity = 23
select = B,C,E,F,W,T4,B9
Binary file modified __pycache__/config.cpython-37.pyc
Binary file not shown.
Binary file modified __pycache__/model.cpython-37.pyc
Binary file not shown.
Binary file modified __pycache__/tester.cpython-37.pyc
Binary file not shown.
Binary file modified __pycache__/trainer.cpython-37.pyc
Binary file not shown.
2 changes: 1 addition & 1 deletion ckconv/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .activation_functions import Swish, Sine
from .linear import Linear1d, Linear2d
from .norm import LayerNorm
from .ckconv import CKConv, KernelNet
from .ckconv import CKConv, KernelNet, RFNet
from .ck_block import CKBlock
from .loss import LnLoss
from .conv import CausalConv1d
Binary file modified ckconv/nn/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file modified ckconv/nn/__pycache__/ckconv.cpython-37.pyc
Binary file not shown.
63 changes: 63 additions & 0 deletions ckconv/nn/ckconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,69 @@
import numpy as np
import ckconv.nn.functional as ckconv_f
from torch.nn.utils import weight_norm
import math


class RFNet(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int,
activation_function: str,
norm_type: str,
dim_linear: int,
bias: bool,
omega_0: float,
weight_dropout: float,
):
super().__init__()

ActivationFunction = torch.nn.ReLU
Linear = {1: ckconv.nn.Linear1d, 2: ckconv.nn.Linear2d}[dim_linear]

self.kernel_net = torch.nn.Sequential(
InputMapping(
in_channels,
hidden_channels // 2,
omega_0=omega_0,
bias=True,
),
Linear(hidden_channels, hidden_channels, bias=bias),
ActivationFunction(),
Linear(hidden_channels, out_channels, bias=bias),
)

def forward(self, x):
return self.kernel_net(x)


class InputMapping(torch.nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
omega_0: float,
stride: int = 1,
bias: bool = True,
):
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=stride,
padding=0,
bias=bias,
)
self.omega_0 = omega_0

# Initialize:
self.weight.data.normal_(0.0, 2 * math.pi * self.omega_0)

def forward(self, x):
out = super().forward(x)
out = torch.cat([torch.cos(out), torch.sin(out)], dim=1)
return out


class KernelNet(torch.nn.Module):
Expand Down
Binary file modified ckernel_fitting/__pycache__/config.cpython-37.pyc
Binary file not shown.
Binary file modified ckernel_fitting/__pycache__/functions.cpython-37.pyc
Binary file not shown.
1 change: 1 addition & 0 deletions ckernel_fitting/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def get_config():
# Parameters of SIREN
kernelnet_omega_0=0.0,
# If model == CKCNN, kernelnet_activation_function==Sine, the value of the omega_0 parameter, e.g., 30.
kernelnet_type="",
comment="",
# An additional comment to be added to the config.path parameter specifying where
# the network parameters will be saved / loaded from.
Expand Down
12 changes: 9 additions & 3 deletions ckernel_fitting/fit_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,13 @@ def main(_):
np.random.seed(config.seed)

# initialize weight and bias
os.environ["WANDB_API_KEY"] = "3fe624d6a1979f80f1277200966d17bed042ec31"
wandb.init(
project="ckconv",
config=copy.deepcopy(dict(config)),
group="kernelfit_{}".format(config.function),
entity="vu_uva_team",
tags=["kernelfit", config.function],
save_code=True,
# save_code=True,
# job_type=config.function,
)

Expand Down Expand Up @@ -152,7 +151,14 @@ def main(_):

def get_model(config):
# Load the model: The model is always equal to a continuous kernel
model = ckconv.nn.KernelNet(
if config.kernelnet_type == "SIREN":
model_type = ckconv.nn.KernelNet
elif config.kernelnet_type == "RFNet":
model_type = ckconv.nn.RFNet
else:
raise NotImplementedError(f"{config.kernelnet_type}")

model = model_type(
in_channels=1,
out_channels=1,
hidden_channels=config.kernelnet_no_hidden,
Expand Down
2 changes: 1 addition & 1 deletion ckernel_fitting/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _gaussian(config, x):
def _constant(config, x):
# apply function
f = np.ones_like(x)
# f[:int(len(f)/2)] = -1.0
f[int(len(f) / 2) :] = -1.0
# return
return f

Expand Down
2 changes: 1 addition & 1 deletion model.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def get_model(config):
),
"PhysioNet_CKCNN": lambda: models.seqImg_CKCNN(
in_channels=in_channels,
out_channels=2,
out_channels=1,
hidden_channels=config.no_hidden,
num_blocks=config.no_blocks,
kernelnet_hidden_channels=config.kernelnet_no_hidden,
Expand Down
Binary file modified models/__pycache__/ckcnn.cpython-37.pyc
Binary file not shown.
2 changes: 2 additions & 0 deletions models/ckcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,6 @@ def __init__(
def forward(self, x):
out = self.backbone(x)
out = self.finallyr(out[:, :, -1])
if out.shape[-1] == 1:
out = out.squeeze(-1)
return out
1 change: 0 additions & 1 deletion run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def main(_):
np.random.seed(config.seed)

# initialize weight and bias
os.environ["WANDB_API_KEY"] = "" ## Place here your API key.
if not config.train:
os.environ["WANDB_MODE"] = "dryrun"
tags = [
Expand Down
Binary file not shown.
Binary file not shown.
11 changes: 8 additions & 3 deletions tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,20 @@ def _test_classif(model, test_loader, config):
inputs = inputs[:, :, permutation]

outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)

if len(outputs.shape) == 1:
labels = labels.float()
preds = (torch.sigmoid(outputs) > 0.5).int()
else:
_, preds = torch.max(outputs, 1)

total += labels.size(0)
correct += (predicted == labels).sum().item()
correct += (preds == labels).sum().item()

# Save for AUC
if config.report_auc:
true_y_cpus.append(labels.detach().cpu())
pred_y_cpus.append(predicted.detach().cpu())
pred_y_cpus.append(outputs.detach().cpu())

# Print results
test_acc = correct / total
Expand Down
12 changes: 8 additions & 4 deletions trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def train(model, dataloaders, config, test_loader):
"CIFAR10": torch.nn.CrossEntropyLoss(),
"SpeechCommands": torch.nn.CrossEntropyLoss(),
"CharTrajectories": torch.nn.CrossEntropyLoss(),
"PhysioNet": torch.nn.CrossEntropyLoss(),
"PhysioNet": torch.nn.BCEWithLogitsLoss(),
}[config.dataset]

train_function = {
Expand Down Expand Up @@ -189,17 +189,21 @@ def _train_classif(
inputs = torch.dropout(inputs, config.dropout_in, train)
outputs = model(inputs)

if len(outputs.shape) == 1:
labels = labels.float()
preds = (torch.sigmoid(outputs) > 0.5).int()
else:
_, preds = torch.max(outputs, 1)

loss = criterion(outputs, labels)
# Regularization:
if config.weight_decay != 0.0:
loss = loss + weight_regularizer(model)

_, preds = torch.max(outputs, 1)

# Save for AUC
if config.report_auc:
true_y_cpus.append(labels.detach().cpu())
pred_y_cpus.append(preds.detach().cpu())
pred_y_cpus.append(outputs.detach().cpu())

# BwrdPhase:
if phase == "train":
Expand Down

0 comments on commit 0d5d192

Please sign in to comment.