Skip to content

Commit

Permalink
Merge branch 'master' of github.com:/proycon/python-timbl
Browse files Browse the repository at this point in the history
  • Loading branch information
proycon committed Dec 8, 2019
2 parents 629fd8f + e3c8767 commit 068a0b5
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 19 deletions.
5 changes: 5 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
:alt: Project Status: Active – The project has reached a stable, usable state and is being actively developed.
:target: https://www.repostatus.org/#active

.. image:: https://zenodo.org/badge/8136669.svg
:target: https://zenodo.org/badge/latestdoi/8136669

======================
README: python-timbl
======================
Expand Down Expand Up @@ -203,4 +206,6 @@ manually call the ``initthreading()`` method.
Three TiMBL API methods print information to a standard C++ output stream object (ShowBestNeighbors, ShowOptions, ShowSettings, ShowSettings). In the Python interface, these methods will only work with Python (stream) objects that have a fileno method returning a valid file descriptor. Alternatively, three new methods are provided (bestNeighbo(u)rs, options, settings); these methods return the same information as a Python string object.


**scikit-learn wrapper**

A wrapper for use in scikit-learn has been added. It was designed for use in scikit-learn Pipeline objects. The wrapper is not finished and has to date only been tested on sparse data. Note that TiMBL does not work well with large amounts of features. It is suggested to reduce the amount of features to a number below 100 to keep system performance reasonable. Use on servers with large amounts of memory and processing cores advised.
60 changes: 41 additions & 19 deletions timbl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
stderr = sys.stderr
stdout = sys.stdout

from tempfile import mktemp
import timblapi
import io
import os
Expand Down Expand Up @@ -59,15 +60,19 @@ def u(s, encoding = 'utf-8', errors='strict'):


class TimblClassifier(object):
def __init__(self, fileprefix, timbloptions, format = "Tabbed", dist=True, encoding = 'utf-8', overwrite = True, flushthreshold=10000, threading=False, normalize=True, debug=False):
def __init__(self, fileprefix, timbloptions, format = "Tabbed", dist=True, encoding = 'utf-8', overwrite = True, flushthreshold=10000, threading=False, normalize=True, debug=False, sklearn=False, flushdir=None):
if format.lower() == "tabbed":
self.format = "Tabbed"
self.delimiter = "\t"
elif format.lower() == "columns":
self.format = "Columns"
self.delimiter = " "
elif format.lower() == 'sparse': # for sparse arrays, e.g. scipy.sparse.csr
self.format = "Sparse"
self.delimiter = ""
else:
raise ValueError("Only Tabbed and Columns are supported input format for the python wrapper, not " + format)
raise ValueError("Only Tabbed, Columns, and Sparse are supported input format for the python wrapper, not " + format)

self.timbloptions = timbloptions
self.fileprefix = fileprefix

Expand All @@ -80,11 +85,17 @@ def __init__(self, fileprefix, timbloptions, format = "Tabbed", dist=True, encod
self.instances = []
self.api = None
self.debug = debug
self.sklearn = sklearn

if os.path.exists(self.fileprefix + ".train") and overwrite:
if sklearn:
import scipy as sp
self.flushfile = mktemp(prefix=self.fileprefix, dir=flushdir)
self.flushed = 0
else:
self.flushed = 1
if os.path.exists(self.fileprefix + ".train") and overwrite:
self.flushed = 0
else:
self.flushed = 1

self.threading = threading

Expand All @@ -94,8 +105,10 @@ def validatefeatures(self,features):
for feature in features:
if isinstance(feature, int) or isinstance(feature, float):
validatedfeatures.append( str(feature) )
elif self.delimiter in feature:
elif self.delimiter in feature and not self.sklearn:
raise ValueError("Feature contains delimiter: " + feature)
elif self.sklearn and isinstance(feature, str): #then is sparse added together
validatedfeatures.append(feature)
else:
validatedfeatures.append(feature)
return validatedfeatures
Expand All @@ -106,21 +119,24 @@ def append(self, features, classlabel):

features = self.validatefeatures(features)

if self.delimiter in classlabel:
if self.delimiter in classlabel and self.delimiter != '':
raise ValueError("Class label contains delimiter: " + self.delimiter)

self.instances.append(self.delimiter.join(features) + self.delimiter + classlabel)
self.instances.append(self.delimiter.join(features) + (self.delimiter if not self.delimiter == '' else ' ') + classlabel)
if len(self.instances) >= self.flushthreshold:
self.flush()

def flush(self):
if self.debug: print("Flushing...",file=sys.stderr)
if len(self.instances) == 0: return False

if self.flushed:
f = io.open(self.fileprefix + ".train",'a', encoding=self.encoding)
if hasattr(self, 'flushfile'):
f = io.open(self.flushfile,'w', encoding=self.encoding)
else:
f = io.open(self.fileprefix + ".train",'w', encoding=self.encoding)
if self.flushed:
f = io.open(self.fileprefix + ".train",'a', encoding=self.encoding)
else:
f = io.open(self.fileprefix + ".train",'w', encoding=self.encoding)

for instance in self.instances:
f.write(instance + "\n")
Expand All @@ -135,8 +151,18 @@ def __delete__(self):

def train(self, save=False):
self.flush()
if not os.path.exists(self.fileprefix + ".train"):
raise LoadException("Training file '"+self.fileprefix+".train' not found. Did you forget to add instances with append()?")

if hasattr(self, 'flushfile'):
if not os.path.exists(self.flushfile):
raise LoadException("Training file '"+self.flushfile+"' not found. Did you forget to add instances with append()?")
else:
filepath = self.flushfile
else:
if not os.path.exists(self.fileprefix + ".train"):
raise LoadException("Training file '"+self.fileprefix+".train' not found. Did you forget to add instances with append()?")
else:
filepath = self.fileprefix + '.train'

options = "-F " + self.format + " " + self.timbloptions
if self.dist:
options += " +v+db +v+di"
Expand All @@ -149,7 +175,7 @@ def train(self, save=False):
print("Enabling debug for timblapi",file=stderr)
self.api.enableDebug()

trainfile = self.fileprefix + ".train"
trainfile = filepath
self.api.learn(b(trainfile))
if save:
self.save()
Expand All @@ -168,7 +194,8 @@ def classify(self, features, allowtopdistribution=True):

if not self.api:
self.load()
testinstance = self.delimiter.join(features) + self.delimiter + "?"

testinstance = self.delimiter.join(features) + (self.delimiter if not self.delimiter == '' else ' ') + "?"
if self.dist:
if self.threading:
result, cls, distribution, distance = self.api.classify3safe(b(testinstance), self.normalize, int(not allowtopdistribution))
Expand Down Expand Up @@ -347,8 +374,3 @@ def _parsedistribution(self, instance, start=0, end =None):

return dist






122 changes: 122 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils import check_X_y, check_array
from timbl import TimblClassifier
import scipy as sp
import numpy as np

class skTiMBL(BaseEstimator, ClassifierMixin):
def __init__(self, prefix='timbl', algorithm=4, dist_metric=None,
k=1, normalize=False, debug=0, flushdir=None):
self.prefix = prefix
self.algorithm = algorithm
self.dist_metric = dist_metric
self.k = k
self.normalize = normalize
self.debug = debug
self.flushdir = flushdir


def _make_timbl_options(self, *options):
"""
-a algorithm
-m metric
-w weighting
-k amount of neighbours
-d class voting weights
-L frequency threshold
-T which feature index is label
-N max number of features
-H turn hashing on/off
This function still has to be made, for now the appropriate arguments
can be passed in fit()
"""
pass


def fit(self, X, y):
X, y = check_X_y(X, y, dtype=np.int64, accept_sparse='csr')

n_rows = X.shape[0]
self.classes_ = np.unique(y)

if sp.sparse.issparse(X):
if self.debug: print('Features are sparse, choosing faster learning')

self.classifier = TimblClassifier(self.prefix, "-a{} -k{} -N{} -vf".format(self.algorithm,self.k, X.shape[1]),
format='Sparse', debug=True, sklearn=True, flushdir=self.flushdir,
flushthreshold=20000, normalize=self.normalize)

for i in range(n_rows):
sparse = ['({},{})'.format(i+1, c) for i,c in zip(X[i].indices, X[i].data)]
self.classifier.append(sparse,str(y[i]))

else:

self.classifier = TimblClassifier(self.prefix, "-a{} -k{} -N{} -vf".format(self.algorithm, self.k, X.shape[1]),
debug=True, sklearn=True, flushdir=self.flushdir, flushthreshold=20000,
normalize=self.normalize)

if y.dtype != 'O':
y = y.astype(str)

for i in range(n_rows):
self.classifier.append(list(X[i].toarray()[0]), y[i])

self.classifier.train()
return self


def _timbl_predictions(self, X, part_index, y=None):
choices = {0 : lambda x : x.append(np.int64(label)),
1 : lambda x : x.append([np.float(distance)]),
}
X = check_array(X, dtype=np.float64, accept_sparse='csr')

n_samples = X.shape[0]

pred = []
func = choices[part_index]
if sp.sparse.issparse(X):
if self.debug: print('Features are sparse, choosing faster predictions')

for i in range(n_samples):
sparse = ['({},{})'.format(i+1, c) for i,c in zip(X[i].indices, X[i].data)]
label,proba, distance = self.classifier.classify(sparse)
func(pred)

else:
for i in range(n_samples):
label,proba, distance = self.classifier.classify(list(X[i].toarray()[0]))
func(pred)

return np.array(pred)



def predict(self, X, y=None):
return self._timbl_predictions(X, part_index=0)


def predict_proba(self, X, y=None):
"""
TIMBL is a discrete classifier. It cannot give probability estimations.
To ensure that scikit-learn functions with TIMBL (and especially metrics
such as ROC_AUC), this method is implemented.
For ROC_AUC, the classifier corresponds to a single point in ROC space,
instead of a probabilistic continuum such as classifiers that can give
a probability estimation (e.g. Linear classifiers). For an explanation,
see Fawcett (2005).
"""
return predict(X)


def decision_function(self, X, y=None):
"""
The decision function is interpreted here as being the distance between
the instance that is being classified and the nearest point in k space.
"""
return self._timbl_predictions(X, part_index=1)


0 comments on commit 068a0b5

Please sign in to comment.