-
Notifications
You must be signed in to change notification settings - Fork 270
/
Copy pathtest_foundations.py
500 lines (458 loc) · 16.9 KB
/
test_foundations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
from pathlib import Path
import numpy as np
import pytest
import torch
import torch.nn.functional
from ase.build import molecule
from e3nn import o3
from e3nn.util import jit
from scipy.spatial.transform import Rotation as R
from mace import data, modules, tools
from mace.calculators import mace_mp, mace_off
from mace.tools import torch_geometric
from mace.tools.finetuning_utils import load_foundations_elements
from mace.tools.scripts_utils import extract_config_mace_model, remove_pt_head
from mace.tools.utils import AtomicNumberTable
# MACE-MP checkpoint shipped inside the repository (two levels above this test
# file). It is passed explicitly as ``model=`` to both ``mace_mp`` and
# ``mace_off`` in the parametrized tests below.
MODEL_PATH = Path(__file__).parent.parent.joinpath(
    "mace",
    "calculators",
    "foundations_models",
    "2023-12-03-mace-mp.model",
)
# All models in this module are compared in double precision.
# NOTE: this changes torch's process-wide default dtype at import time.
torch.set_default_dtype(torch.float64)


def _h2coh_config(positions):
    """Return an H2COH ``data.Configuration`` placed at ``positions``.

    The attached properties are placeholders (e.g. "forces" is just the
    unrotated reference geometry). Each call builds fresh arrays, so no two
    configurations share property buffers.
    """
    numbers_source = molecule("H2COH")
    forces_source = molecule("H2COH")
    charges_source = molecule("H2COH")
    weights = {key: 1.0 for key in ("forces", "energy", "charges", "dipole")}
    return data.Configuration(
        atomic_numbers=numbers_source.numbers,
        positions=positions,
        properties={
            "forces": forces_source.positions,
            "energy": -1.5,
            "charges": charges_source.numbers,
            "dipole": np.array([-1.5, 1.5, 2.0]),
        },
        property_weights=weights,
    )


config = _h2coh_config(molecule("H2COH").positions)

# The same molecule rotated by 60 degrees about the z axis.
rot = R.from_euler("z", 60, degrees=True).as_matrix()
positions_rotated = np.array(rot @ config.positions.T).T
config_rotated = _h2coh_config(positions_rotated)

table = tools.AtomicNumberTable([1, 6, 8])  # H, C, O
atomic_energies = np.zeros(3, dtype=float)  # one (zero) E0 per element of `table`
def test_foundations():
    """A 3-element model loaded from MACE-MP "small" must reproduce its forces.

    The earlier version compared ``model`` against the value returned by
    ``load_foundations_elements(model, ...)``. That assertion can only pass if
    both names compute identical forces on a freshly initialised ``model``, so
    it never exercised the loaded weights. The reference is now the foundation
    calculator evaluated on the same two structures (the base and the rotated
    configuration), the same approach as ``test_multi_reference``.
    """
    model_config = dict(
        r_max=6,
        num_bessel=10,
        num_polynomial_cutoff=5,
        max_ell=3,
        interaction_cls=modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],
        interaction_cls_first=modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],
        num_interactions=2,
        num_elements=3,
        hidden_irreps=o3.Irreps("128x0e"),
        MLP_irreps=o3.Irreps("16x0e"),
        gate=torch.nn.functional.silu,
        atomic_energies=atomic_energies,
        avg_num_neighbors=3,
        atomic_numbers=table.zs,
        correlation=3,
        radial_type="bessel",
        atomic_inter_scale=0.1,
        atomic_inter_shift=0.0,
    )
    model = modules.ScaleShiftMACE(**model_config)
    calc = mace_mp(
        model="small",
        device="cpu",
        default_dtype="float64",
    )
    model_foundations = calc.models[0]
    model_loaded = load_foundations_elements(
        model,
        model_foundations,
        table=table,
        load_readout=True,
        use_shift=False,
        max_L=0,
    )

    configs = [config, config_rotated]
    data_loader = torch_geometric.dataloader.DataLoader(
        dataset=[
            data.AtomicData.from_config(cfg, z_table=table, cutoff=6.0)
            for cfg in configs
        ],
        batch_size=2,
        # Keep batch order aligned with `configs` so the per-structure
        # reference forces below can be concatenated in the same order.
        shuffle=False,
        drop_last=False,
    )
    batch = next(iter(data_loader))
    forces_loaded = model_loaded(batch.to_dict())["forces"].detach().numpy()

    # Reference: the untouched foundation calculator, one structure at a time.
    reference_forces = []
    for cfg in configs:
        atoms = molecule("H2COH")  # same atom order as `cfg.atomic_numbers`
        atoms.positions = cfg.positions
        atoms.calc = calc
        reference_forces.append(atoms.get_forces())
    forces_reference = np.concatenate(reference_forces, axis=0)

    assert forces_loaded.shape == forces_reference.shape
    assert np.allclose(forces_reference, forces_loaded, atol=1e-5, rtol=1e-5)
def test_multi_reference():
    """A two-head model loaded from MACE-MP "medium" must reproduce the
    foundation calculator's forces on an H2COH structure tagged with the
    "MP2" head."""
    config_multi = data.Configuration(
        atomic_numbers=molecule("H2COH").numbers,
        positions=molecule("H2COH").positions,
        properties={
            "forces": molecule("H2COH").positions,
            "energy": -1.5,
            "charges": molecule("H2COH").numbers,
            "dipole": np.array([-1.5, 1.5, 2.0]),
        },
        property_weights={
            "forces": 1.0,
            "energy": 1.0,
            "charges": 1.0,
            "dipole": 1.0,
        },
        head="MP2",
    )
    # Use this local table everywhere (model, loading, data) instead of mixing
    # it with the module-level `table`.
    table_multi = tools.AtomicNumberTable([1, 6, 8])
    # One row of (zero) E0s per head, one column per element of `table_multi`.
    atomic_energies_multi = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=float)
    model_config = dict(
        r_max=6,
        num_bessel=10,
        num_polynomial_cutoff=5,
        max_ell=3,
        interaction_cls=modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],
        interaction_cls_first=modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],
        num_interactions=2,
        num_elements=3,
        hidden_irreps=o3.Irreps("128x0e + 128x1o"),
        MLP_irreps=o3.Irreps("16x0e"),
        gate=torch.nn.functional.silu,
        atomic_energies=atomic_energies_multi,
        avg_num_neighbors=61,
        atomic_numbers=table_multi.zs,
        correlation=3,
        radial_type="bessel",
        atomic_inter_scale=[1.0, 1.0],
        atomic_inter_shift=[0.0, 0.0],
        heads=["MP2", "DFT"],
    )
    model = modules.ScaleShiftMACE(**model_config)
    calc_foundation = mace_mp(model="medium", device="cpu", default_dtype="float64")
    model_loaded = load_foundations_elements(
        model,
        calc_foundation.models[0],
        table=table_multi,
        load_readout=True,
        use_shift=False,
        max_L=1,
    )
    atomic_data = data.AtomicData.from_config(
        config_multi, z_table=table_multi, cutoff=6.0, heads=["MP2", "DFT"]
    )
    data_loader = torch_geometric.dataloader.DataLoader(
        dataset=[atomic_data, atomic_data],
        batch_size=2,
        shuffle=True,
        drop_last=False,
    )
    batch = next(iter(data_loader))
    forces_loaded = model_loaded(batch.to_dict())["forces"]

    # Build a fresh calculator for the reference so the expected forces come
    # from a foundation model that `load_foundations_elements` never touched.
    calc_reference = mace_mp(model="medium", device="cpu", default_dtype="float64")
    atoms = molecule("H2COH")
    atoms.info["head"] = "MP2"
    atoms.calc = calc_reference
    forces = atoms.get_forces()
    # Both batch entries are the same structure; compare against the first.
    n_atoms = len(atoms)
    assert np.allclose(
        forces, forces_loaded.detach().numpy()[:n_atoms, :], atol=1e-5, rtol=1e-5
    )
@pytest.mark.parametrize(
    "make_calc",
    [
        pytest.param(
            lambda: mace_mp(device="cpu", default_dtype="float64"),
            id="mace_mp-default",
        ),
        pytest.param(
            lambda: mace_mp(model="small", device="cpu", default_dtype="float64"),
            id="mace_mp-small",
        ),
        pytest.param(
            lambda: mace_mp(model="medium", device="cpu", default_dtype="float64"),
            id="mace_mp-medium",
        ),
        pytest.param(
            lambda: mace_mp(model="large", device="cpu", default_dtype="float64"),
            id="mace_mp-large",
        ),
        pytest.param(
            lambda: mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64"),
            id="mace_mp-model_path",
        ),
        pytest.param(
            lambda: mace_off(model="small", device="cpu", default_dtype="float64"),
            id="mace_off-small",
        ),
        pytest.param(
            lambda: mace_off(model="medium", device="cpu", default_dtype="float64"),
            id="mace_off-medium",
        ),
        pytest.param(
            lambda: mace_off(model="large", device="cpu", default_dtype="float64"),
            id="mace_off-large",
        ),
        pytest.param(
            lambda: mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64"),
            id="mace_off-model_path",
        ),
    ],
)
def test_compile_foundation(make_calc):
    """TorchScript-compiled foundation models must match their eager outputs.

    Calculators are constructed inside the test rather than in the
    ``parametrize`` list. Building them there would load every model at
    collection time, even when only an unrelated test is selected, and a
    single load failure would break collection of the whole module.
    """
    calc = make_calc()
    model = calc.models[0]
    atoms = molecule("CH4")
    # Seeded perturbation keeps the test reproducible from run to run.
    rng = np.random.default_rng(0)
    atoms.positions += rng.normal(scale=0.1, size=atoms.positions.shape)
    batch = calc._atoms_to_batch(atoms)  # pylint: disable=protected-access
    output_eager = model(batch.to_dict())
    model_compiled = jit.compile(model)
    output_compiled = model_compiled(batch.to_dict())
    for key, value in output_eager.items():
        if isinstance(value, torch.Tensor):
            assert torch.allclose(
                value, output_compiled[key], atol=1e-5
            ), f"compiled output differs for {key!r}"
@pytest.mark.parametrize(
    "make_model",
    [
        pytest.param(
            lambda: mace_mp(
                model="small", device="cpu", default_dtype="float64"
            ).models[0],
            id="mace_mp-small",
        ),
        pytest.param(
            lambda: mace_mp(
                model="medium", device="cpu", default_dtype="float64"
            ).models[0],
            id="mace_mp-medium",
        ),
        pytest.param(
            lambda: mace_mp(
                model="large", device="cpu", default_dtype="float64"
            ).models[0],
            id="mace_mp-large",
        ),
        pytest.param(
            lambda: mace_mp(
                model=MODEL_PATH, device="cpu", default_dtype="float64"
            ).models[0],
            id="mace_mp-model_path",
        ),
        pytest.param(
            lambda: mace_off(
                model="small", device="cpu", default_dtype="float64"
            ).models[0],
            id="mace_off-small",
        ),
        pytest.param(
            lambda: mace_off(
                model="medium", device="cpu", default_dtype="float64"
            ).models[0],
            id="mace_off-medium",
        ),
        pytest.param(
            lambda: mace_off(
                model="large", device="cpu", default_dtype="float64"
            ).models[0],
            id="mace_off-large",
        ),
        pytest.param(
            lambda: mace_off(
                model=MODEL_PATH, device="cpu", default_dtype="float64"
            ).models[0],
            id="mace_off-model_path",
        ),
    ],
)
def test_extract_config(make_model):
    """Rebuild each foundation model from ``extract_config_mace_model`` and check
    that, with the original weights, it reproduces the original outputs.

    Models are loaded inside the test (not in the ``parametrize`` list), so that
    collecting this module does not eagerly load every foundation model.
    """
    model = make_model()
    assert isinstance(model, modules.ScaleShiftMACE)
    model_copy = modules.ScaleShiftMACE(**extract_config_mace_model(model))
    model_copy.load_state_dict(model.state_dict())
    z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers])
    atomic_data = data.AtomicData.from_config(config, z_table=z_table, cutoff=6.0)
    data_loader = torch_geometric.dataloader.DataLoader(
        dataset=[atomic_data, atomic_data],
        batch_size=2,
        shuffle=True,
        drop_last=False,
    )
    batch = next(iter(data_loader))
    output = model(batch.to_dict())
    output_copy = model_copy(batch.to_dict())
    # Every tensor in the original output must be reproduced by the copy.
    for key, value in output.items():
        if isinstance(value, torch.Tensor):
            assert torch.allclose(
                value, output_copy[key], atol=1e-5
            ), f"rebuilt model output differs for {key!r}"
def test_remove_pt_head():
    """Removing the "pt_head" head from a two-head model must leave a
    single-head "DFT" model whose energy and forces match the original model
    evaluated on a DFT-tagged water molecule."""
    torch.manual_seed(42)
    z_table = AtomicNumberTable([1, 8])  # H and O
    residual_block = modules.interaction_classes[
        "RealAgnosticResidualInteractionBlock"
    ]
    # Two heads, each with its own row of per-element E0s.
    model = modules.ScaleShiftMACE(
        r_max=5.0,
        num_bessel=8,
        num_polynomial_cutoff=5,
        max_ell=2,
        interaction_cls=residual_block,
        interaction_cls_first=residual_block,
        num_interactions=2,
        num_elements=len(z_table),
        hidden_irreps=o3.Irreps("32x0e + 32x1o"),
        MLP_irreps=o3.Irreps("16x0e"),
        gate=torch.nn.functional.silu,
        atomic_energies=np.array([[1.0, 2.0], [3.0, 4.0]], dtype=float),
        avg_num_neighbors=8,
        atomic_numbers=z_table.zs,
        correlation=3,
        heads=["pt_head", "DFT"],
        atomic_inter_scale=[1.0, 1.0],
        atomic_inter_shift=[0.0, 0.1],
    )

    water = molecule("H2O")
    config_pt_head = data.Configuration(
        atomic_numbers=water.numbers,
        positions=water.positions,
        properties={"energy": 1.0, "forces": np.random.randn(len(water), 3)},
        property_weights={"forces": 1.0, "energy": 1.0},
        head="DFT",
    )

    def evaluate(net, head_names):
        """Run ``net`` on ``config_pt_head`` encoded against ``head_names``."""
        encoded = data.AtomicData.from_config(
            config_pt_head, z_table=z_table, cutoff=5.0, heads=head_names
        )
        loader = torch_geometric.dataloader.DataLoader(
            dataset=[encoded], batch_size=1, shuffle=False
        )
        return net(next(iter(loader)).to_dict())

    output_orig = evaluate(model, ["pt_head", "DFT"])

    new_model = remove_pt_head(model, head_to_keep="DFT")

    # Structure: exactly one head, one row of E0s, one scale and one shift.
    assert len(new_model.heads) == 1
    assert new_model.heads[0] == "DFT"
    assert new_model.atomic_energies_fn.atomic_energies.shape[0] == 1
    assert len(torch.atleast_1d(new_model.scale_shift.scale)) == 1
    assert len(torch.atleast_1d(new_model.scale_shift.shift)) == 1

    # Outputs: the single-head model must match the original on the DFT head.
    output_new = evaluate(new_model, ["DFT"])
    for quantity in ("energy", "forces"):
        torch.testing.assert_close(
            output_orig[quantity], output_new[quantity], rtol=1e-5, atol=1e-5
        )
def test_remove_pt_head_multihead():
    """Every head of a four-head model can be extracted on its own.

    For each head, ``remove_pt_head`` must return a single-head model with
    one row of E0s and that head's scale/shift. Its energy and forces must
    match the original multi-head model on a configuration tagged with the
    same head. The test also checks the error raised for an unknown head and
    which head is chosen by default.
    """
    torch.manual_seed(42)
    # Rows are heads in `heads` order (pt_head, DFT, MP2, CCSD); columns are
    # elements in `z_table` order (H, O). The list repetition yields shape (4, 2).
    atomic_energies_pt_head = np.array(
        [
            [1.0, 2.0],
            [3.0, 4.0],
        ]
        * 2
    )
    z_table = AtomicNumberTable([1, 8])  # H and O
    model_config = {
        "r_max": 5.0,
        "num_bessel": 8,
        "num_polynomial_cutoff": 5,
        "max_ell": 2,
        "interaction_cls": modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],
        "interaction_cls_first": modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],
        "num_interactions": 2,
        "num_elements": len(z_table),
        "hidden_irreps": o3.Irreps("32x0e + 32x1o"),
        "MLP_irreps": o3.Irreps("16x0e"),
        "gate": torch.nn.functional.silu,
        "atomic_energies": atomic_energies_pt_head,
        "avg_num_neighbors": 8,
        "atomic_numbers": z_table.zs,
        "correlation": 3,
        "heads": ["pt_head", "DFT", "MP2", "CCSD"],
        "atomic_inter_scale": [1.0, 1.0, 1.0, 1.0],
        "atomic_inter_shift": [0.0, 0.1, 0.2, 0.3],
    }
    model = modules.ScaleShiftMACE(**model_config)

    mol = molecule("H2O")
    configs = {}
    original_outputs = {}
    # Reference outputs from the multi-head model, one configuration per head.
    for head in model.heads:
        config_pt_head = data.Configuration(
            atomic_numbers=mol.numbers,
            positions=mol.positions,
            properties={"energy": 1.0, "forces": np.random.randn(len(mol), 3)},
            property_weights={"forces": 1.0, "energy": 1.0},
            head=head,
        )
        configs[head] = config_pt_head
        atomic_data = data.AtomicData.from_config(
            config_pt_head, z_table=z_table, cutoff=5.0, heads=model.heads
        )
        dataloader = torch_geometric.dataloader.DataLoader(
            dataset=[atomic_data], batch_size=1, shuffle=False
        )
        batch = next(iter(dataloader))
        original_outputs[head] = model(batch.to_dict())

    def single_head_batch(head):
        """Batch dict for ``configs[head]`` encoded with ``head`` as the only head."""
        single_head_data = data.AtomicData.from_config(
            configs[head], z_table=z_table, cutoff=5.0, heads=[head]
        )
        single_head_loader = torch_geometric.dataloader.DataLoader(
            dataset=[single_head_data], batch_size=1, shuffle=False
        )
        return next(iter(single_head_loader)).to_dict()

    # Extract each head in turn and compare against the multi-head reference.
    for i, head in enumerate(model.heads):
        new_model = remove_pt_head(model, head_to_keep=head)

        # Structure: one head, one row of E0s, one scale and one shift.
        assert len(new_model.heads) == 1, f"Failed for head {head}"
        assert new_model.heads[0] == head, f"Failed for head {head}"
        assert (
            new_model.atomic_energies_fn.atomic_energies.shape[0] == 1
        ), f"Failed for head {head}"
        assert (
            len(torch.atleast_1d(new_model.scale_shift.scale)) == 1
        ), f"Failed for head {head}"
        assert (
            len(torch.atleast_1d(new_model.scale_shift.shift)) == 1
        ), f"Failed for head {head}"

        # The kept scale/shift must be this head's entry of the original.
        assert torch.allclose(
            new_model.scale_shift.scale, model.scale_shift.scale[i : i + 1]
        ), f"Failed for head {head}"
        assert torch.allclose(
            new_model.scale_shift.shift, model.scale_shift.shift[i : i + 1]
        ), f"Failed for head {head}"

        new_output = new_model(single_head_batch(head))
        torch.testing.assert_close(
            original_outputs[head]["energy"],
            new_output["energy"],
            rtol=1e-5,
            atol=1e-5,
            msg=f"Energy mismatch for head {head}",
        )
        torch.testing.assert_close(
            original_outputs[head]["forces"],
            new_output["forces"],
            rtol=1e-5,
            atol=1e-5,
            msg=f"Forces mismatch for head {head}",
        )

    # Unknown head names are rejected.
    with pytest.raises(ValueError, match="Head non_existent not found in model"):
        remove_pt_head(model, head_to_keep="non_existent")

    # Without `head_to_keep`, the first non-PT head ("DFT") is kept.
    default_model = remove_pt_head(model)
    assert default_model.heads[0] == "DFT"

    # Models extracted for different heads must be independent of each other.
    models = {head: remove_pt_head(model, head_to_keep=head) for head in model.heads}
    results = {
        head: head_model(single_head_batch(head))
        for head, head_model in models.items()
    }
    energies = torch.stack([results[head]["energy"] for head in model.heads])
    assert not torch.allclose(
        energies[0], energies[1], rtol=1e-3
    ), "Different heads should produce different outputs"