Skip to content

Commit

Permalink
fix(tf2_gnn/models/*): let __init__ pass extra kwargs to super()
Browse files Browse the repository at this point in the history
  • Loading branch information
mmjb committed Mar 11, 2021
1 parent 9025691 commit 16153fa
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 9 deletions.
3 changes: 0 additions & 3 deletions tf2_gnn/models/graph_binary_classification_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ def get_default_hyperparameters(
super_params.update(these_hypers)
return super_params

def __init__(self, params: Dict[str, Any], dataset: GraphDataset, name: str = None):
super().__init__(params, dataset=dataset, name=name)

def compute_task_output(
self,
batch_features: Dict[str, tf.Tensor],
Expand Down
4 changes: 2 additions & 2 deletions tf2_gnn/models/graph_regression_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def get_default_hyperparameters(
super_params.update(these_hypers)
return super_params

def __init__(self, params: Dict[str, Any], dataset: GraphDataset, name: str = None):
super().__init__(params, dataset=dataset, name=name)
def __init__(self, params: Dict[str, Any], dataset: GraphDataset, name: str = None, **kwargs):
super().__init__(params, dataset=dataset, name=name, **kwargs)
self._node_to_graph_aggregation = None

# Construct sublayers:
Expand Down
4 changes: 2 additions & 2 deletions tf2_gnn/models/node_multiclass_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def get_default_hyperparameters(cls, mp_style: Optional[str] = None) -> Dict[str
super_params.update(these_hypers)
return super_params

def __init__(self, params: Dict[str, Any], dataset: GraphDataset, name: str = None):
super().__init__(params, dataset=dataset, name=name)
def __init__(self, params: Dict[str, Any], dataset: GraphDataset, name: str = None, **kwargs):
super().__init__(params, dataset=dataset, name=name, **kwargs)
if not hasattr(dataset, "num_node_target_labels"):
raise ValueError(f"Provided dataset of type {type(dataset)} does not provide num_node_target_labels information.")
self._num_labels = dataset.num_node_target_labels
Expand Down
4 changes: 2 additions & 2 deletions tf2_gnn/models/qm9_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def get_default_hyperparameters(
super_params.update(these_hypers)
return super_params

def __init__(self, params: Dict[str, Any], dataset: GraphDataset, name: str = None):
super().__init__(params, dataset=dataset, name=name)
def __init__(self, params: Dict[str, Any], dataset: GraphDataset, name: str = None, **kwargs):
super().__init__(params, dataset=dataset, name=name, **kwargs)
assert isinstance(dataset, QM9Dataset)

self._task_id = int(dataset._params["task_id"])
Expand Down

0 comments on commit 16153fa

Please sign in to comment.