-
Notifications
You must be signed in to change notification settings - Fork 236
/
Copy pathtest_simple_2d_data.py
45 lines (35 loc) · 1.43 KB
/
test_simple_2d_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import numpy as np
from dtw import dtw
def l2_norm(x, y):
return (x - y) ** 2
def test_simple_2d_data():
x = np.array([2, 0, 1, 1, 2, 4, 2, 1, 2, 0]).reshape(-1, 1)
y = np.array([1, 1, 2, 4, 2, 1, 2, 0]).reshape(-1, 1)
dist, cost_matrix, acc_cost_matrix, path = dtw(x, y, dist=l2_norm)
assert dist == 2
assert (cost_matrix == np.array(
[[1., 1., 0., 4., 0., 1., 0., 4.],
[1., 1., 4., 16., 4., 1., 4., 0.],
[0., 0., 1., 9., 1., 0., 1., 1.],
[0., 0., 1., 9., 1., 0., 1., 1.],
[1., 1., 0., 4., 0., 1., 0., 4.],
[9., 9., 4., 0., 4., 9., 4., 16.],
[1., 1., 0., 4., 0., 1., 0., 4.],
[0., 0., 1., 9., 1., 0., 1., 1.],
[1., 1., 0., 4., 0., 1., 0., 4.],
[1., 1., 4., 16., 4., 1., 4., 0.]],
)).all()
assert (acc_cost_matrix == np.array(
[[1., 2., 2., 6., 6., 7., 7., 11.],
[2., 2., 6., 18., 10., 7., 11., 7.],
[2., 2., 3., 12., 11., 7., 8., 8.],
[2., 2., 3., 12., 12., 7., 8., 9.],
[3., 3., 2., 6., 6., 7., 7., 11.],
[12., 12., 6., 2., 6., 15., 11., 23.],
[13., 13., 6., 6., 2., 3., 3., 7.],
[13., 13., 7., 15., 3., 2., 3., 4.],
[14., 14., 7., 11., 3., 3., 2., 6.],
[15., 15., 11., 23., 7., 4., 6., 2.]]
)).all()
assert (path[0] == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])).all()
assert (path[1] == np.array([0, 0, 0, 1, 2, 3, 4, 5, 6, 7])).all()