# feynman_test.py
from feynmandiag import init_feynfunc
import torch
from MCintegration import (
MonteCarlo,
MarkovChainMonteCarlo,
Vegas,
set_seed,
get_device,
)
import time
# Benchmark script: build the integrand returned by `init_feynfunc`, train a
# VEGAS map on it, then integrate it with four integrators (plain Monte Carlo,
# MCMC, VEGAS-mapped Monte Carlo, VEGAS-mapped MCMC), printing every result
# together with the elapsed time of the training and integration phases.

# --- Run configuration --------------------------------------------------------
device = get_device()
batch_size = 5000
n_eval = 1000000  # `neval` requested from every integrator
n_therm = 10  # forwarded as `nburnin` to the two MCMC integrators
# Looked up with `order - 1` to obtain `f_dim`.
# NOTE(review): presumably the number of integrand components for each
# diagram order -- confirm against feynmandiag.
num_roots = [1, 2, 3, 4, 5, 6]
order = 2
beta = 10.0

# --- Integrand ----------------------------------------------------------------
feynfunc = init_feynfunc(order, beta, batch_size)
feynfunc.to(device)
f_dim = num_roots[order - 1]

# --- VEGAS map training -------------------------------------------------------
vegas_map = Vegas(feynfunc.ndims, ninc=1000, device=device)
print("Training the vegas map...")
# perf_counter() is a monotonic, high-resolution clock, so the measured
# interval cannot be skewed by system clock adjustments as with time.time().
begin_time = time.perf_counter()
vegas_map.adaptive_training(batch_size, feynfunc, f_dim=f_dim, epoch=10, alpha=1.0)
print("training time: ", time.perf_counter() - begin_time, "s\n")

# --- Integration --------------------------------------------------------------
# The timer is restarted here, so the "Total time" printed at the end covers
# the four integrations only, not the map training above.
begin_time = time.perf_counter()

# One independent [0, 1] pair per integration dimension (no aliased inner
# lists, unlike `[[0, 1]] * n`).
bounds = [[0, 1] for _ in range(feynfunc.ndims)]

# Keyword arguments common to all four integrators.
integrator_kwargs = {"f_dim": f_dim, "batch_size": batch_size, "device": device}

# NOTE(review): `mix_rate=0.5` is passed to both MonteCarlo calls and to the
# VEGAS-MCMC call below, but not to the plain MCMC call -- confirm that this
# asymmetry is intended.
mc_integrator = MonteCarlo(bounds, feynfunc, **integrator_kwargs)
res = mc_integrator(neval=n_eval, mix_rate=0.5)
print("Plain MC Integral results: ", res)

mcmc_integrator = MarkovChainMonteCarlo(
    bounds,
    feynfunc,
    nburnin=n_therm,
    **integrator_kwargs,
)
res = mcmc_integrator(neval=n_eval)
print("MCMC Integral results: ", res)

# Same integrators again, this time sampling through the trained VEGAS map.
vegas_integrator = MonteCarlo(
    bounds,
    feynfunc,
    maps=vegas_map,
    **integrator_kwargs,
)
res = vegas_integrator(neval=n_eval, mix_rate=0.5)
print("VEGAS Integral results: ", res)

vegasmcmc_integrator = MarkovChainMonteCarlo(
    bounds,
    feynfunc,
    nburnin=n_therm,
    maps=vegas_map,
    **integrator_kwargs,
)
res = vegasmcmc_integrator(neval=n_eval, mix_rate=0.5)
print("VEGAS-MCMC Integral results: ", res)

print("Total time: ", time.perf_counter() - begin_time, "s")