Skip to content

Commit

Permalink
compute EMD with ITML and t-SNE ground metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
daureg committed Oct 31, 2014
1 parent 27d6a6c commit 76c0931
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
5 changes: 4 additions & 1 deletion ClosestNeighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ def gather_info(city, knn=2, mat=None, raw_features=True, hide_category=False):
mask = res['features'][:, 5] == cat
venues = matrix['i'][mask]
if len(venues) > 0:
idx_subset = np.ix_(mask, RESTRICTED) # pylint: disable=E1101
frange = np.array(range(len(FEATURES)))
if city.endswith('_tsne.mat'):
frange = np.arange(5)
idx_subset = np.ix_(mask, frange) # pylint: disable=E1101
algo = NN(knn) if mat is None else NN(knn, metric='mahalanobis',
VI=np.linalg.inv(mat))
res[int(cat)] = (algo.fit(res['features'][idx_subset]), venues)
Expand Down
19 changes: 13 additions & 6 deletions neighborhood.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# pylint: disable=E1101
# pylint: disable=W0621
NB_CLUSTERS = 3
JUST_READING = True
JUST_READING = False
MAX_EMD_POINTS = 750
NO_WEIGHT = True
QUERY_NAME = None
Expand Down Expand Up @@ -276,7 +276,9 @@ def brute_search(city_desc, hsize, distance_function, threshold,
def interpret_query(from_city, to_city, region, metric):
"""Load informations about cities and compute useful quantities."""
# Load info of the first city
left = cn.gather_info(from_city, knn=1, raw_features='lmnn' not in metric,
suffix = '_tsne.mat' if metric == 'emd-tsne' else ''
left = cn.gather_info(from_city+suffix, knn=1,
raw_features='lmnn' not in metric,
hide_category=metric != 'jsd')
left_infos = load_surroundings(from_city)
left_support = features_support(left['features'])
Expand All @@ -301,6 +303,10 @@ def interpret_query(from_city, to_city, region, metric):
if 'emd' in metric:
from emd import emd
from emd_dst import dist_for_emd
if 'tsne' in metric:
from specific_emd_dst import dst_tsne as dist_for_emd
if 'itml' in metric:
from specific_emd_dst import dst_itml as dist_for_emd
query_num = features_as_lists(features)

@profile
Expand Down Expand Up @@ -336,7 +342,8 @@ def regions_distance(r_density, r_global):
theta)

# Load info of the target city
right = cn.gather_info(to_city, knn=2, raw_features='lmnn' not in metric,
right = cn.gather_info(to_city+suffix, knn=2,
raw_features='lmnn' not in metric,
hide_category=metric != 'jsd')
right_infos = load_surroundings(to_city)
minx, miny, maxx, maxy = right_infos[1]
Expand All @@ -357,7 +364,7 @@ def best_match(from_city, to_city, region, tradius, progressive=False,
"""Try to match a `region` from `from_city` to `to_city`. If progressive,
yield intermediate result."""
assert metric in ['jsd', 'emd', 'jsd-nospace', 'jsd-greedy', 'cluster',
'leftover', 'emd-lmnn']
'leftover', 'emd-lmnn', 'emd-itml', 'emd-tsne']

infos = interpret_query(from_city, to_city, region, metric)
left, right, right_desc, regions_distance, vids, threshold = infos
Expand Down Expand Up @@ -605,7 +612,7 @@ def batch_matching(query_city='paris'):
cities = sorted(regions.values()[0]['gold'].keys())
assert query_city in cities
cities.remove(query_city)
OTMPDIR = os.path.join(OTMPDIR, 'comparaison_'+query_city)
OTMPDIR = os.path.join(OTMPDIR, 'www_comparaison_'+query_city)
try:
os.mkdir(OTMPDIR)
except OSError:
Expand All @@ -622,7 +629,7 @@ def batch_matching(query_city='paris'):
rgeo = choose_query_region(possible_regions)
if not rgeo:
continue
for metric in ['emd']:
for metric in ['emd-itml', 'emd-tsne']:
# for metric in ['jsd', 'emd', 'cluster', 'emd-lmnn', 'leftover']:
print(metric)
for radius in np.linspace(200, 500, 5):
Expand Down
20 changes: 20 additions & 0 deletions specific_emd_dst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#! /usr/bin/python2
# vim: set fileencoding=utf-8
"""Distance function between two venues for ITML and t-SNE."""
import math

import numpy as np
import scipy.io as sio
import scipy.spatial.distance

# Load the matrix stored under the key 'A' of ITMLall.mat. The path is
# relative, so it is resolved against the working directory at import time.
COV = sio.loadmat('ITMLall.mat')['A']
# pylint: disable=E1101
# Grow the matrix by one column and one row of zeros at index 5, then put a 1
# on the new diagonal entry; a pure zero row/column would make the matrix
# singular and the inversion below would fail.
# NOTE(review): index 5 is presumably the venue category feature, which the
# learned matrix does not cover -- confirm against the feature layout.
COV = np.insert(COV, 5, values=0, axis=1)
COV = np.insert(COV, 5, values=0, axis=0)
COV[5, 5] = 1.0
# From here on COV holds the inverse of the padded matrix; dst_itml passes it
# as the third argument of scipy's mahalanobis().
COV = np.linalg.inv(COV)


def dst_itml(u, v, _):
    """Return the Mahalanobis distance between the venues `u` and `v`.

    The module-level matrix `COV` is handed to SciPy's `mahalanobis` as its
    third argument. `u` and `v` are passed through unchanged.

    The third parameter is ignored.
    NOTE(review): it is presumably kept so that this function has the same
    call signature as `emd_dst.dist_for_emd`, whose name it takes over when
    imported by neighborhood.py -- confirm against the emd caller.
    """
    # Relies on the module-level `import scipy.spatial.distance`:
    # `import scipy.io as sio` alone does not bind the name `scipy`.
    return scipy.spatial.distance.mahalanobis(u, v, COV)


def dst_tsne(u, v, _):
    """Return the Euclidean distance between `u` and `v` in the plane.

    Only the first two components (`[0]` and `[1]`) of each argument are
    read; any further components are ignored.

    The third parameter is ignored.
    NOTE(review): it is presumably kept so that this function has the same
    call signature as `emd_dst.dist_for_emd`, whose name it takes over when
    imported by neighborhood.py -- confirm against the emd caller.
    """
    # math.hypot avoids overflowing the intermediate squares, which the
    # explicit sqrt(dx*dx + dy*dy) form turns into `inf` for huge offsets.
    return math.hypot(u[0] - v[0], u[1] - v[1])

0 comments on commit 76c0931

Please sign in to comment.