Replies: 1 comment
-
Got it! It was just a problem of resetting the computer.
-
Hi all,
I'm trying to solve a binary classification problem, so the binary cross-entropy loss from TensorFlow seems to be the one I should use. However, this loss isn't included in deepxde's losses.py.
I've added it to losses.py myself. Here is the code:
def binary_cross_entropy(y_true, y_pred):
    # TODO: pytorch
    return tf.keras.losses.BinaryCrossentropy(from_logits=True)(y_true, y_pred)
and I also added it to LOSS_DICT.
Unfortunately, it doesn't work: I get an error message.
Could you please help me? Thanks in advance.
Here is the error message:
KeyError Traceback (most recent call last)
Cell In[74], line 18
9 net = dde.nn.DeepONet(
10 [m, 40, p], # dimensions of the fully connected branch net
11 [n, 40, p], # dimensions of the fully connected trunk net
12 "sigmoid",
13 "Glorot normal", # initialization of parameters
14 )
16 model = dde.Model(data, net)
---> 18 model.compile("adam", lr=0.001, loss="binary cross entropy", metrics=['accuracy']) # accuracy is the mean of matches between predictions and labels
19 model.train(iterations=ITERATIONS)
20 model.compile("L-BFGS-B", metrics=['accuracy'])
File c:\Users\Paco\anaconda3\envs\deeponetcontrol\lib\site-packages\deepxde\utils\internal.py:22, in timing.<locals>.wrapper(*args, **kwargs)
19 @wraps(f)
20 def wrapper(*args, **kwargs):
21 ts = timeit.default_timer()
---> 22 result = f(*args, **kwargs)
23 te = timeit.default_timer()
24 if config.rank == 0:
File c:\Users\Paco\anaconda3\envs\deeponetcontrol\lib\site-packages\deepxde\model.py:121, in Model.compile(self, optimizer, lr, loss, metrics, decay, loss_weights, external_trainable_variables)
119 print("Compiling model...")
120 self.opt_name = optimizer
--> 121 loss_fn = losses_module.get(loss)
122 self.loss_weights = loss_weights
123 if external_trainable_variables is None:
File c:\Users\Paco\anaconda3\envs\deeponetcontrol\lib\site-packages\deepxde\losses.py:69, in get(identifier)
66 return list(map(get, identifier))
68 if isinstance(identifier, str):
---> 69 return LOSS_DICT[identifier]
70 if callable(identifier):
71 return identifier
KeyError: 'binary cross entropy'
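For anyone reading the traceback later: the lines quoted from losses.py show that a string identifier is looked up in LOSS_DICT, while a callable is returned unchanged. Below is a minimal standalone sketch of that dispatch. It is not the deepxde source: the LOSS_DICT contents, the mean_squared_error stub, and the pure-Python binary_cross_entropy (which takes probabilities, not logits) are stand-ins assumed for illustration only.

```python
import math


def binary_cross_entropy(y_true, y_pred, eps=1e-7):
    """Illustrative stand-in: mean BCE over probabilities (not logits)."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total -= t * math.log(p) + (1 - t) * math.log(1 - p)
    return total / len(y_true)


def mean_squared_error(y_true, y_pred):
    """Illustrative stand-in for a loss that is already registered."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical registry; "binary cross entropy" is deliberately missing.
LOSS_DICT = {"mean squared error": mean_squared_error}


def get(identifier):
    """Mirror of the dispatch visible in the traceback (sketch only)."""
    if isinstance(identifier, (list, tuple)):
        return list(map(get, identifier))
    if isinstance(identifier, str):
        return LOSS_DICT[identifier]  # KeyError if the key is not registered
    if callable(identifier):
        return identifier  # callables bypass the registry entirely
    raise ValueError(f"Could not interpret loss identifier: {identifier!r}")


if __name__ == "__main__":
    try:
        get("binary cross entropy")
    except KeyError as err:
        print("KeyError:", err)
    # Passing the function itself never touches LOSS_DICT.
    loss_fn = get(binary_cross_entropy)
    print(loss_fn([1, 0], [0.5, 0.5]))
```

Judging only from the traceback lines shown above, passing the function object as `loss=` to `model.compile` should take the callable branch, so editing the installed losses.py may not be needed. If the installed file is edited anyway, Python keeps already-imported modules cached, so the change is only picked up once the interpreter or notebook kernel is restarted, which is consistent with the fix reported in the reply.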