Skip to content

Commit

Permalink
Added run & checkpoint for PhysioNet
Browse files Browse the repository at this point in the history
  • Loading branch information
dwromero committed Nov 29, 2021
1 parent 5c3e778 commit 6a1e0a6
Show file tree
Hide file tree
Showing 11 changed files with 46 additions and 15 deletions.
Binary file modified __pycache__/tester.cpython-37.pyc
Binary file not shown.
Binary file modified __pycache__/trainer.cpython-37.pyc
Binary file not shown.
Binary file modified ckconv/nn/__pycache__/ckconv.cpython-37.pyc
Binary file not shown.
35 changes: 24 additions & 11 deletions ckconv/nn/ckconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,17 +243,30 @@ def __init__(
:param weight_dropout: Dropout rate applied to the sampled convolutional kernels.
"""
super().__init__()
self.Kernel = KernelNet(
dim_linear,
out_channels * in_channels,
hidden_channels,
activation_function,
norm_type,
dim_linear,
bias,
omega_0,
weight_dropout,
)
if activation_function == "RandomFourier":
self.Kernel = RFNet(
in_channels=dim_linear,
out_channels=out_channels * in_channels,
hidden_channels=hidden_channels,
activation_function="ReLU",
norm_type=norm_type,
dim_linear=dim_linear,
bias=bias,
omega_0=omega_0,
weight_dropout=weight_dropout,
)
else:
self.Kernel = KernelNet(
in_channels=dim_linear,
out_channels=out_channels * in_channels,
hidden_channels=hidden_channels,
activation_function=activation_function,
norm_type=norm_type,
dim_linear=dim_linear,
bias=bias,
omega_0=omega_0,
weight_dropout=weight_dropout,
)

if bias:
self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
Expand Down
Binary file modified ckernel_fitting/__pycache__/config.cpython-37.pyc
Binary file not shown.
Binary file modified ckernel_fitting/__pycache__/functions.cpython-37.pyc
Binary file not shown.
5 changes: 5 additions & 0 deletions ckernel_fitting/fit_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
# project
import ckconv.nn
from ckernel_fitting.functions import get_function_to_fit
from ckconv.utils import num_params


FLAGS = flags.FLAGS
Expand Down Expand Up @@ -177,6 +178,10 @@ def get_model(config):
model.to(config.device)
torch.backends.cudnn.benchmark = True

no_params = num_params(model)
print("Number of parameters:", no_params)
wandb.run.summary["no_params"] = no_params

return model


Expand Down
15 changes: 12 additions & 3 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,14 @@ def dataset_constructor(
"SpeechCommands": SpeechCommands,
"CharTrajectories": CharTrajectories,
"PhysioNet": PhysioNet,
'PennTreeBankChar': PennTreeBankChar,
"PennTreeBankChar": PennTreeBankChar,
}[config.dataset]
if config.dataset == 'PennTreeBankChar':

if config.dataset == "PennTreeBankChar":
eval_batch_size = 10
else:
eval_batch_size = config.batch_size

training_set = dataset(
partition="train",
seq_length=config.seq_length,
Expand All @@ -57,7 +61,12 @@ def dataset_constructor(
valid_seq_len=config.valid_seq_len,
batch_size=eval_batch_size,
)
if config.dataset in ["SpeechCommands", "CharTrajectories", "PhysioNet", "PennTreeBankChar"]:
if config.dataset in [
"SpeechCommands",
"CharTrajectories",
"PhysioNet",
"PennTreeBankChar",
]:
validation_set = dataset(
partition="val",
seq_length=config.seq_length,
Expand Down
Binary file modified datasets/__pycache__/physionet.cpython-37.pyc
Binary file not shown.
6 changes: 5 additions & 1 deletion runs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,8 @@ One can change the frequency of the test set by varying the key `--config.sr_tes

###### Data drop percentage = 70%

`run_experiment.py --config.batch_size=32 --config.clip=0 --config.dataset=SpeechCommands --config.device=cuda --config.drop_rate=70 --config.dropout=0.1 --config.dropout_in=0.1 --config.epochs=300 --config.kernelnet_activation_function=Sine --config.kernelnet_no_hidden=32 --config.kernelnet_norm_type=LayerNorm --config.kernelnet_omega_0=23.29255834358289 --config.lr=0.001 --config.mfcc=False --config.model=CKCNN --config.no_blocks=2 --config.no_hidden=30 --config.optimizer=Adam --config.permuted=False --config.sched_decay_factor=5 --config.sched_decay_steps=(75,) --config.sched_patience=20 --config.scheduler=plateau --config.sr_train=1 --config.weight_decay=0 --config.weight_dropout=0.1`
`run_experiment.py --config.batch_size=32 --config.clip=0 --config.dataset=SpeechCommands --config.device=cuda --config.drop_rate=70 --config.dropout=0.1 --config.dropout_in=0.1 --config.epochs=300 --config.kernelnet_activation_function=Sine --config.kernelnet_no_hidden=32 --config.kernelnet_norm_type=LayerNorm --config.kernelnet_omega_0=23.29255834358289 --config.lr=0.001 --config.mfcc=False --config.model=CKCNN --config.no_blocks=2 --config.no_hidden=30 --config.optimizer=Adam --config.permuted=False --config.sched_decay_factor=5 --config.sched_decay_steps=(75,) --config.sched_patience=20 --config.scheduler=plateau --config.sr_train=1 --config.weight_decay=0 --config.weight_dropout=0.1`

### PhysioNet

`run_experiment.py --config.batch_size=1024 --config.clip=0 --config.dataset=PhysioNet --config.device=cuda --config.dropout=0.1 --config.dropout_in=0.1 --config.epochs=200 --config.kernelnet_activation_function=Sine --config.kernelnet_no_hidden=32 --config.kernelnet_norm_type=LayerNorm --config.kernelnet_omega_0=9.779406396796968 --config.lr=0.001 --config.model=CKCNN --config.no_blocks=2 --config.no_hidden=30 --config.optimizer=Adam --config.permuted=False --config.report_auc=True --config.sched_decay_factor=5 --config.sched_patience=20 --config.scheduler=plateau --config.weight_decay=0.0001 --config.weight_dropout=0`
Binary file not shown.

0 comments on commit 6a1e0a6

Please sign in to comment.