Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Iterative rectifier #534

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
improve rectifier
  • Loading branch information
leostimpfle committed Jul 2, 2024
commit b4ee73835bad5f304c795c77331da31443b02ca9
51 changes: 34 additions & 17 deletions pyfixest/estimation/fepois_.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,8 @@ def _check_for_separation(
fe : pd.DataFrame
Fixed effects.
method: list[str], optional
Methods to check for separation. Either "fe" or "ir". Executes both by default.
Methods used to check for separation. One of fixed effects ("fe") or
iterative rectifier ("ir"). Executes all methods by default.

Returns
-------
Expand All @@ -351,7 +352,7 @@ def _check_for_separation(
invalid_methods = [method for method in methods if method not in valid_methods]
if invalid_methods:
raise ValueError(
f"Invalid separation method. Expecting {valid_methods}. Got {invalid_methods}"
f"Invalid separation method. Expecting {list(valid_methods)}. Received {invalid_methods}"
)

separation_na: set[int] = set()
Expand All @@ -378,7 +379,7 @@ def __call__(self, Y: pd.DataFrame, X: pd.DataFrame, fe: pd.DataFrame) -> set[in
Returns
-------
set
Set of indices of observations that are removed due to separation.
Set of indices of separated observations.
"""
...

Expand All @@ -401,7 +402,7 @@ def _check_for_separation_fe(
Returns
-------
set
Set of indices of observations that are removed due to separation.
Set of indices of separated observations.
"""
separation_na: set[int] = set()
if fe is not None and not (Y > 0).all(axis=0).all():
Expand Down Expand Up @@ -456,42 +457,58 @@ def _check_for_separation_ir(
Returns
-------
set
Set of indices of observations that are removed due to separation.
Set of indices of separated observations.
"""
# lazy load and avoid circular import
fixest_module = import_module("pyfixest.estimation")
feols = getattr(fixest_module, "feols")

U = (Y == 0).astype(int).squeeze().rename("U")
# initialize variables
U = (Y == 0).astype(float).squeeze().rename("U")
N0 = (Y > 0).squeeze().sum()
K = N0 / tol**2
omega = pd.Series(np.where(Y.squeeze() > 0, K, 1), name="omega")
data = pd.concat([U, X, omega], axis=1)
# build formula
fml = "U"
if X is not None and not X.empty:
fml += f" ~ {' + '.join(X.columns[X.columns!='Intercept'])}"
if fe is not None and not fe.empty:
# TODO: can this rename be avoided?
fe.columns = fe.columns.str.replace("factorize", "").str.strip(r"\(|\)")
fml += f" | {' + '.join(fe.columns)}"

# construct data
data = pd.concat([U, X, fe, omega], axis=1)
separation_na: set[int] = set()
if (Y.squeeze() > 0).all():
is_interior = Y.squeeze() > 0
if is_interior.all():
# no boundary sample, can exit
# TODO: check if algorithm can be sped up based on boundary sample
return separation_na

iter = 0
while iter < maxiter:
iter += 1
iteration = 0
has_converged = False
while iteration < maxiter:
iteration += 1
# regress U on X
# TODO: check acceleration in ppmlhdfe's implementation: https://github.com/sergiocorreia/ppmlhdfe/blob/master/src/ppmlhdfe_separation_relu.mata#L135
Uhat = feols(fml, data=data, weights="omega").predict()
Uhat = np.where(np.abs(Uhat) < tol, 0, Uhat)
# update when within tolerance of zero
# need to be more strict below zero, to avoid false positives
within_zero = (Uhat > -0.1 * tol) & (Uhat < tol)
Uhat = np.where(is_interior | within_zero, 0, Uhat)
if (Uhat >= 0).all():
# all separated observations have been identified
has_converged = True
break
data["U"] = np.fmax(Uhat, 0) # rectified linear unit (ReLU)
data.loc[~is_interior, "U"] = np.fmax(
Uhat[~is_interior], 0
) # rectified linear unit (ReLU)

if has_converged:
separation_na = set(Y[Uhat > 0].index)
else:
warnings.warn(
"iterative rectivier separation check: maximum number of iterations reached before convergence"
)

separation_na = set(Y[Uhat > 0].index)
return separation_na


Expand Down