Skip to content

Commit

Permalink
ENH: GEE add Wrapper, closes statsmodels#1904
Browse files (browse the repository at this point in the history)
  • Loading branch information
josef-pkt committed Aug 22, 2014
1 parent 3062a80 commit 274c9fe
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 6 deletions.
27 changes: 24 additions & 3 deletions statsmodels/genmod/generalized_estimating_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
from statsmodels.tools.decorators import (cache_readonly,
resettable_cache)
import statsmodels.base.model as base
# used for wrapper:
import statsmodels.regression.linear_model as lm
import statsmodels.base.wrapper as wrap

from statsmodels.genmod import families
from statsmodels.genmod import dependence_structures
from statsmodels.genmod.dependence_structures import CovStruct
Expand Down Expand Up @@ -935,7 +939,7 @@ def fit(self, maxiter=60, ctol=1e-6, start_params=None,
"params_niter", "first_dep_update", "ctol",
"maxiter"]

return results
return GEEResultsWrapper(results)

fit.__doc__ = _gee_fit_doc

Expand Down Expand Up @@ -1453,6 +1457,13 @@ def params_sensitivity(self, dep_params_first,

return results

class GEEResultsWrapper(lm.RegressionResultsWrapper):
    """Wrapper class paired with ``GEEResults``.

    Extends the attribute-wrapping table inherited from
    ``lm.RegressionResultsWrapper`` with GEE-specific entries.
    """

    # Extra attributes to wrap in addition to the regression wrapper's own.
    # NOTE(review): 'rows' presumably marks the value as aligned with the
    # rows (observations) of the data so the row index is attached --
    # confirm against statsmodels.base.wrapper.
    _attrs = {
        'centered_resid' : 'rows',
    }
    # Combine the parent's wrap table with the additions above.
    _wrap_attrs = wrap.union_dicts(lm.RegressionResultsWrapper._wrap_attrs,
                                   _attrs)
# Associate the wrapper class with the results class it wraps.
wrap.populate_wrapper(GEEResultsWrapper, GEEResults)


class OrdinalGEE(GEE):
Expand Down Expand Up @@ -1555,6 +1566,7 @@ def fit(self, maxiter=60, ctol=1e-6, start_params=None,
params_niter, first_dep_update,
cov_type=cov_type)

rslt = rslt._results # use unwrapped instance
res_kwds = dict(((k, getattr(rslt, k)) for k in rslt._props))
# Convert the GEEResults to an OrdinalGEEResults
ord_rslt = OrdinalGEEResults(self, rslt.params,
Expand All @@ -1565,10 +1577,11 @@ def fit(self, maxiter=60, ctol=1e-6, start_params=None,
#for k in rslt._props:
# setattr(ord_rslt, k, getattr(rslt, k))

return ord_rslt
return OrdinalGEEResultsWrapper(ord_rslt)

fit.__doc__ = _gee_fit_doc


class OrdinalGEEResults(GEEResults):

__doc__ = (
Expand Down Expand Up @@ -1663,6 +1676,10 @@ def plot_distribution(self, ax=None, exog_values=None):

return fig

class OrdinalGEEResultsWrapper(GEEResultsWrapper):
    """Wrapper class paired with ``OrdinalGEEResults``.

    Adds nothing of its own; everything is inherited from
    ``GEEResultsWrapper``.
    """
    pass
# Associate the wrapper class with the results class it wraps.
wrap.populate_wrapper(OrdinalGEEResultsWrapper, OrdinalGEEResults)


class NominalGEE(GEE):

Expand Down Expand Up @@ -1769,6 +1786,7 @@ def fit(self, maxiter=60, ctol=1e-6, start_params=None,
ConvergenceWarning)
return None

rslt = rslt._results # use unwrapped instance
res_kwds = dict(((k, getattr(rslt, k)) for k in rslt._props))
# Convert the GEEResults to a NominalGEEResults
nom_rslt = NominalGEEResults(self, rslt.params,
Expand All @@ -1779,7 +1797,7 @@ def fit(self, maxiter=60, ctol=1e-6, start_params=None,
#for k in rslt._props:
# setattr(nom_rslt, k, getattr(rslt, k))

return nom_rslt
return NominalGEEResultsWrapper(nom_rslt)

fit.__doc__ = _gee_fit_doc

Expand Down Expand Up @@ -1870,6 +1888,9 @@ def plot_distribution(self, ax=None, exog_values=None):

return fig

class NominalGEEResultsWrapper(GEEResultsWrapper):
    """Wrapper class paired with ``NominalGEEResults``.

    Adds nothing of its own; everything is inherited from
    ``GEEResultsWrapper``.
    """
    pass
# Associate the wrapper class with the results class it wraps.
wrap.populate_wrapper(NominalGEEResultsWrapper, NominalGEEResults)


class MultinomialLogit(Link):
Expand Down
72 changes: 69 additions & 3 deletions statsmodels/genmod/tests/test_gee.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
assert_array_less, assert_raises, assert_)
from statsmodels.genmod.generalized_estimating_equations import (GEE,
OrdinalGEE, NominalGEE, GEEMargins, Multinomial,
NominalGEEResults, OrdinalGEEResults)
NominalGEEResults, OrdinalGEEResults,
NominalGEEResultsWrapper, OrdinalGEEResultsWrapper)
from statsmodels.genmod.families import Gaussian, Binomial, Poisson
from statsmodels.genmod.dependence_structures import (Exchangeable,
Independence, GlobalOddsRatio, Autoregressive, Nested)
Expand Down Expand Up @@ -52,6 +53,19 @@ def load_data(fname, icept=True):
return endog,exog,group


def check_wrapper(results):
    """Check that a fitted results object is wrapped correctly.

    The attributes ``params``, ``fittedvalues``, ``resid`` and
    ``centered_resid`` must be ``pandas.Series`` when read from the
    wrapper, and ``numpy.ndarray`` when read from the unwrapped instance
    stored in ``results._results``.

    Parameters
    ----------
    results : wrapped results instance
        Object exposing the four attributes above and a ``_results``
        attribute holding the unwrapped results.

    Raises
    ------
    AssertionError
        If any attribute has the wrong type.  The message names the
        offending attribute and the type that was actually found.
    """
    names = ("params", "fittedvalues", "resid", "centered_resid")

    # Wrapped attributes are checked first, in the same order as before.
    for name in names:
        value = getattr(results, name)
        assert_(isinstance(value, pd.Series),
                "results.%s is %s, expected pandas Series"
                % (name, type(value).__name__))

    # Then the raw attributes on the unwrapped results instance.
    raw = results._results
    for name in names:
        value = getattr(raw, name)
        assert_(isinstance(value, np.ndarray),
                "results._results.%s is %s, expected numpy ndarray"
                % (name, type(value).__name__))


class TestGEE(object):


Expand Down Expand Up @@ -559,7 +573,8 @@ def test_ordinal(self):
assert_almost_equal(rslt.bse, se, decimal=5)

# Check that we get the correct results type
assert_equal(type(rslt), OrdinalGEEResults)
assert_equal(type(rslt), OrdinalGEEResultsWrapper)
assert_equal(type(rslt._results), OrdinalGEEResults)

def test_nominal(self):

Expand Down Expand Up @@ -591,7 +606,8 @@ def test_nominal(self):
assert_almost_equal(rslt2.standard_errors(), se2, decimal=5)

# Make sure we get the correct results type
assert_equal(type(rslt1), NominalGEEResults)
assert_equal(type(rslt1), NominalGEEResultsWrapper)
assert_equal(type(rslt1._results), NominalGEEResults)


def test_poisson(self):
Expand Down Expand Up @@ -745,6 +761,9 @@ def test_formulas(self):
assert_almost_equal(rslt1.params, rslt4.params, decimal=8)
assert_almost_equal(rslt1.params, rslt5.params, decimal=8)

check_wrapper(rslt2)


def test_compare_logit(self):

vs = Independence()
Expand Down Expand Up @@ -883,6 +902,21 @@ def setup_class(cls):
0.57628591, -0.0046566, -0.47709315])


def test_wrapper(self):
    """Fit a Poisson GEE with a DataFrame ``exog`` and check the wrapping."""

    endog, exog, group_n = load_data("gee_poisson_1.csv",
                                     icept=False)
    # endog stays as returned by load_data; the Series conversion below
    # is disabled, so only exog is converted to a pandas object.
    #endog = pd.Series(endog)
    exog = pd.DataFrame(exog)

    family = Poisson()
    vi = Independence()

    mod = GEE(endog, exog, group_n, None, family, vi)
    rslt2 = mod.fit()

    # Wrapper must return Series; the unwrapped results must return ndarrays.
    check_wrapper(rslt2)


class TestGEEPoissonFormulaCovType(CheckConsistency):

Expand Down Expand Up @@ -925,6 +959,22 @@ def setup_class(cls):
-0.01812116, 0.03023969, 1.18258516,
0.01803453, -1.10203381])

def test_wrapper(self):
    """Fit an OrdinalGEE on pandas inputs and check the wrapping."""

    endog, exog, groups = load_data("gee_ordinal_1.csv",
                                    icept=False)


    # Both endog and exog are converted to pandas objects here.
    endog = pd.Series(endog)
    exog = pd.DataFrame(exog)

    family = Binomial()
    va = GlobalOddsRatio("ordinal")
    mod = OrdinalGEE(endog, exog, groups, None, family, va)
    rslt2 = mod.fit()

    # Wrapper must return Series; the unwrapped results must return ndarrays.
    check_wrapper(rslt2)


class TestGEEMultinomialCovType(CheckConsistency):

Expand All @@ -943,6 +993,22 @@ def setup_class(cls):
-0.46766728])


def test_wrapper(self):
    """Fit a NominalGEE with a DataFrame ``exog`` and check the wrapping."""

    endog, exog, groups = load_data("gee_nominal_1.csv",
                                    icept=False)
    # endog stays as returned by load_data; the Series conversion below
    # is disabled, so only exog is converted to a pandas object.
    #endog = pd.Series(endog)
    exog = pd.DataFrame(exog)

    family = Multinomial(3)
    va = Independence()
    mod = NominalGEE(endog, exog, groups, None, family, va)
    rslt2 = mod.fit()

    # Wrapper must return Series; the unwrapped results must return ndarrays.
    check_wrapper(rslt2)



if __name__=="__main__":

import nose
Expand Down

0 comments on commit 274c9fe

Please sign in to comment.