Skip to content

Commit

Permalink
Added prediction function.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander März committed May 23, 2023
1 parent 02e25f3 commit 0f268f9
Showing 1 changed file with 73 additions and 0 deletions.
73 changes: 73 additions & 0 deletions xgboostlss/distributions/distribution_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,79 @@ def draw_samples(self,

return dist_samples

def predict_dist(self,
booster: xgb.Booster,
dtest: xgb.DMatrix,
pred_type: str = "parameters",
n_samples: int = 1000,
quantiles: list = [0.1, 0.5, 0.9],
seed: str = 123
) -> pd.DataFrame:
"""
Function that predicts from the trained model.
Arguments
---------
booster : xgb.Booster
Trained model.
dtest : xgb.DMatrix
Test data.
pred_type : str
Type of prediction:
- "samples" draws n_samples from the predicted distribution.
- "quantile" calculates the quantiles from the predicted distribution.
- "parameters" returns the predicted distributional parameters.
- "expectiles" returns the predicted expectiles.
n_samples : int
Number of samples to draw from the predicted distribution.
quantiles : List[float]
List of quantiles to calculate from the predicted distribution.
seed : int
Seed for random number generator used to draw samples from the predicted distribution.
Returns
-------
pred : pd.DataFrame
Predictions.
"""
predt = np.array(booster.predict(dtest, output_margin=True)).reshape(-1, self.n_dist_param)
predt = torch.tensor(predt, dtype=torch.float32)

# Transform predicted parameters to response scale
dist_params_predt = np.concatenate(
[
response_fun(
predt[:, i].reshape(-1, 1)).numpy() for i, (dist_param, response_fun) in
enumerate(self.param_dict.items())
],
axis=1,
)
dist_params_predt = pd.DataFrame(dist_params_predt)
dist_params_predt.columns = self.param_dict.keys()

# Draw samples from predicted response distribution
pred_samples_df = self.draw_samples(predt_params=dist_params_predt,
n_samples=n_samples,
seed=seed)

# Calculate quantiles from predicted response distribution
pred_quant_df = pred_samples_df.quantile(quantiles, axis=1).T
pred_quant_df.columns = [str("quant_") + str(quantiles[i]) for i in range(len(quantiles))]
if self.discrete:
pred_quant_df = pred_quant_df.astype(int)

if pred_type == "parameters":
return dist_params_predt

elif pred_type == "expectiles":
return dist_params_predt

elif pred_type == "samples":
return pred_samples_df

elif pred_type == "quantiles":
return pred_quant_df


def compute_gradients_and_hessians(nll: torch.tensor,
predt: torch.tensor,
Expand Down

0 comments on commit 0f268f9

Please sign in to comment.