-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcmd.py
36 lines (29 loc) · 831 Bytes
/
cmd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# -*- coding: utf-8 -*-
import itertools
from torch.utils import data
import torch
def l2diff(x1, x2):
    """
    Euclidean (L2) distance between ``x1`` and ``x2``.

    The element-wise difference is squared, summed over every element and
    then square-rooted, so a single scalar value is returned regardless of
    the shape of the inputs.
    """
    delta = x1 - x2
    squared_total = (delta ** 2).sum()
    return squared_total.sqrt()
def moment_diff(sx1, sx2, k):
    """
    Distance between the k-th order moments of two sample sets.

    Each input is raised element-wise to the power ``k`` and averaged along
    dimension 0; the two resulting moment estimates are then compared with
    ``l2diff``.
    """
    powered_1, powered_2 = sx1 ** k, sx2 ** k
    return l2diff(powered_1.mean(dim=0), powered_2.mean(dim=0))
class CMD(object):
    """
    Moment-matching discrepancy between two batches of samples.

    The value is the L2 distance between the means taken over dimension 0,
    plus the L2 distances between the centered moments of order 2 up to
    ``n_moments``.
    """

    def __init__(self, n_moments=5):
        """
        :param n_moments: highest moment order included in the sum; orders
            2..n_moments are added on top of the mean difference.
        """
        self.n_moments = n_moments

    def forward(self, x1, x2):
        """
        Compute the discrepancy between ``x1`` and ``x2``.

        :param x1: first batch (assumed to be a tensor whose dimension 0
            indexes samples -- the mean is taken over that dimension)
        :param x2: second batch, reduced the same way as ``x1``
        :return: scalar sum of the per-order moment distances
        """
        mx1 = x1.mean(dim=0)
        mx2 = x2.mean(dim=0)
        # Centered samples, used for every moment of order >= 2.
        sx1 = x1 - mx1
        sx2 = x2 - mx2
        # The running total is built out-of-place. An in-place ``+=`` on the
        # result of ``sqrt`` would overwrite a value autograd saved for the
        # backward pass, which previously forced the means to be detached and
        # silently removed the mean-difference term from the gradient.
        scms = l2diff(mx1, mx2)
        for order in range(2, self.n_moments + 1):
            # moment diff of centralized samples
            scms = scms + moment_diff(sx1, sx2, order)
        return scms