Skip to content

Commit

Permalink
checkin copy
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 18, 2015
1 parent e6b8b23 commit 91a5390
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 13 deletions.
1 change: 0 additions & 1 deletion demo/guide-python/basic_walkthrough.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import scipy.sparse
import pickle
import xgboost as xgb
import copy

### simple example
# load file from text file, also binary buffer generated by xgboost
Expand Down
61 changes: 49 additions & 12 deletions wrapper/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ def __init__(self, data, label=None, missing=0.0, weight=None):
weight : list or numpy 1-D array (optional)
Weight for each instance.
"""

# force into void_p, mac need to pass things in as void_p
if data is None:
self.handle = None
Expand Down Expand Up @@ -348,6 +347,46 @@ def __init__(self, params=None, cache=(), model_file=None):
def __del__(self):
xglib.XGBoosterFree(self.handle)

def __getstate__(self):
# can't pickle ctypes pointers
# put model content in bytearray
this = self.__dict__.copy()
handle = this['handle']
if handle is not None:
raw = self.save_raw()
this["handle"] = raw
return this

def __setstate__(self, state):
# reconstruct handle from raw data
handle = state['handle']
if handle is not None:
buf = handle
dmats = c_array(ctypes.c_void_p, [])
handle = ctypes.c_void_p(xglib.XGBoosterCreate(dmats, 0))
length = ctypes.c_ulong(len(buf))
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
xglib.XGBoosterLoadModelFromBuffer(handle, ptr, length)
state['handle'] = handle
self.__dict__.update(state)
self.set_param({'seed': 0})

def __copy__(self):
return self.__deepcopy__()

def __deepcopy__(self):
return Booster(model_file = self.save_raw())

def copy(self):
"""
Copy the booster object
Returns
--------
a copied booster model
"""
return self.__copy__()

def set_param(self, params, pv=None):
if isinstance(params, collections.Mapping):
params = params.items()
Expand Down Expand Up @@ -440,6 +479,11 @@ def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False):
"""
Predict with data.
NOTE: This function is not thread safe.
For each booster object, predict can only be called from one thread.
If you want to run prediction using multiple thread, call bst.copy() to make copies
of model object and then call predict
Parameters
----------
data : DMatrix
Expand Down Expand Up @@ -874,18 +918,12 @@ def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, silent=True

self._Booster = None

def __getstate__(self):
# can't pickle ctypes pointers so put _Booster in a bytearray object
this = self.__dict__.copy() # don't modify in place
bst = this["_Booster"]
if bst is not None:
raw = this["_Booster"].save_raw()
this["_Booster"] = raw
return this

def __setstate__(self, state):
# backward compatiblity code
# load booster from raw if it is raw
# the booster now support pickle
bst = state["_Booster"]
if bst is not None:
if bst is not None and not isinstance(bst, Booster):
state["_Booster"] = Booster(model_file=bst)
self.__dict__.update(state)

Expand Down Expand Up @@ -977,7 +1015,6 @@ def predict_proba(self, X):
classzero_probs = 1.0 - classone_probs
return np.vstack((classzero_probs, classone_probs)).transpose()


class XGBRegressor(XGBModel, XGBRegressor):
__doc__ = """
Implementation of the scikit-learn API for XGBoost regression
Expand Down

0 comments on commit 91a5390

Please sign in to comment.