Expand source code
from typing import Type

import numpy as np

from ..sklearnmodel import SklearnModel


class OLS(SklearnModel):

    def __init__(self,
                 stat_model: Type,
                 **kwargs):
        # A statsmodels-style regression class; instances must expose .fit(),
        # and the fitted result must expose .resid and .predict()
        self.stat_model = stat_model
        self.stat_model_fit = None
        super().__init__(**kwargs)

    def fit(self, X: np.ndarray, y: np.ndarray) -> 'OLS':
        # Fit the parametric model first, then fit BART to its residuals
        self.stat_model_fit = self.stat_model(y, X).fit()
        SklearnModel.fit(self, X, self.stat_model_fit.resid)
        return self

    def predict(self, X: np.ndarray = None) -> np.ndarray:
        if X is None:
            X = self.data.X
        # Combined prediction: parametric fit plus the BART model of the residuals
        sm_prediction = self.stat_model_fit.predict(X)
        bart_prediction = SklearnModel.predict(self, X)
        return sm_prediction + bart_prediction
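
The class expects stat_model to be a statsmodels-style regression class whose instances expose .fit(), and whose fitted results expose .resid and .predict(). A minimal usage sketch, assuming statsmodels is installed and that the BART keyword arguments documented below are forwarded to SklearnModel; the import path is a guess, adjust it to wherever OLS lives in your install:

import numpy as np
import statsmodels.api as sm

from bartpy.extensions.ols import OLS  # NOTE: import path is an assumption

# Synthetic data: a linear trend plus a nonlinear bump for BART to pick up
rng = np.random.default_rng(0)
X = rng.uniform(size=(200, 3))
y = 2.0 * X[:, 0] + np.sin(4 * X[:, 1]) + rng.normal(scale=0.1, size=200)

# Fit statsmodels OLS first, then BART on the OLS residuals
model = OLS(stat_model=sm.OLS, n_trees=50, n_samples=200, n_burn=200)
model.fit(X, y)

# Predictions are the sum of the linear fit and the BART correction
y_hat = model.predict(X)
print(y_hat.shape)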

Classes

class OLS (stat_model: Type, **kwargs)

Combines a statsmodels-style regression with BART: the parametric model is fit first, BART is then fit to its residuals, and predictions sum the two. Built on SklearnModel, the main access point to building BART models in BartPy; the parameters below are forwarded via **kwargs (a configuration sketch follows the parameter list).

Parameters

n_trees : int
the number of trees to use; more trees give a smoother fit but make fitting slower
n_chains : int
the number of independent chains to run; more chains improve the quality of the samples but require more computation
sigma_a : float
shape parameter of the prior on sigma
sigma_b : float
scale parameter of the prior on sigma
n_samples : int
how many recorded samples to take
n_burn : int
how many samples to run without recording to reach convergence
thin : float
percentage of samples to store; use this to save memory when running large models
p_grow : float
probability of choosing a grow mutation in tree mutation sampling
p_prune : float
probability of choosing a prune mutation in tree mutation sampling
alpha : float
prior parameter on tree structure
beta : float
prior parameter on tree structure
store_in_sample_predictions : bool
whether to store the full set of in-sample prediction samples; set to False if you don't need in-sample results, which saves a lot of memory
store_acceptance_trace : bool
whether to store acceptance rates of the Gibbs samples; useful for diagnostics, and unless you're very memory constrained there is little reason to set this to False
tree_sampler : TreeMutationSampler
method of sampling used on trees; defaults to bartpy.samplers.unconstrainedtree
initializer : Initializer
Class that handles the initialization of tree structure and leaf values
n_jobs : int
how many cores to use when computing MCMC samples; set to -1 to use all cores
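
All of the keyword arguments above belong to SklearnModel and reach it through OLS's **kwargs. A configuration sketch, with values that are purely illustrative rather than recommendations; the import path is again a guess:

import statsmodels.api as sm

from bartpy.extensions.ols import OLS  # NOTE: import path is an assumption

model = OLS(
    stat_model=sm.OLS,                   # any statsmodels-style regression class
    n_trees=200,                         # more trees -> smoother fit, slower sampling
    n_chains=4,                          # independent chains improve sample quality
    n_samples=500,                       # recorded posterior samples
    n_burn=500,                          # burn-in iterations discarded before recording
    thin=0.1,                            # fraction of samples to keep (illustrative)
    store_in_sample_predictions=False,   # skip in-sample draws to save memory
    n_jobs=-1,                           # use all available cores
)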
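Because predict returns the sum of the statsmodels prediction and the BART prediction of the residuals, the two components can be recovered separately after fitting. A small sketch, reusing the fitted model and X from the usage example above:

# Linear component from the fitted statsmodels model
linear_part = model.stat_model_fit.predict(X)

# The BART correction is whatever remains of the combined prediction
bart_part = model.predict(X) - linear_part

print(linear_part[:5], bart_part[:5])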

Ancestors

  • SklearnModel
  • sklearn.base.BaseEstimator
  • sklearn.utils._estimator_html_repr._HTMLDocumentationLinkMixin
  • sklearn.utils._metadata_requests._MetadataRequester
  • sklearn.base.RegressorMixin

Inherited members