Skip to content

Commit da6aa02

Browse files
committed
Increase alpha exponentially on retry to avoid singular-matrix errors
1 parent 5ed7030 commit da6aa02

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,3 +2,4 @@ pandas
22
scipy
33
shap
44
statsmodels
5+
numpy

shap_select/select.py

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import statsmodels.api as sm
55
import scipy.stats as stats
66
import shap
7+
import numpy as np
78

89

910
def create_shap_features(
@@ -65,9 +66,20 @@ def binary_classifier_significance(
6566
# Add a constant to the features for the intercept in logistic regression
6667
shap_features_with_constant = sm.add_constant(shap_features)
6768

68-
# Fit the logistic regression model that will generate confidence intervals
69-
logit_model = sm.Logit(target, shap_features_with_constant)
70-
result = logit_model.fit_regularized(disp=False, alpha=alpha)
69+
alpha_in_loop = alpha
70+
max_retries = 10 # Set a maximum number of retries to avoid infinite loops
71+
for attempt in range(max_retries):
72+
try:
73+
# Fit the logistic regression model that will generate confidence intervals
74+
logit_model = sm.Logit(target, shap_features_with_constant)
75+
result = logit_model.fit_regularized(disp=False, alpha=alpha_in_loop)
76+
break # Exit the loop if successful
77+
except np.linalg.LinAlgError as ex: # Catch Singular Matrix or related issues
78+
alpha_in_loop *= 5 # Increase alpha exponentially to avoid singular matrix problem
79+
except Exception as ex: # Catch any other exception
80+
raise RuntimeError(ex) # Re-raise the exception for debugging or further handling
81+
else:
82+
raise RuntimeError("Logistic regression failed to converge after maximum retries.")
7183

7284
# Extract the results
7385
summary_frame = result.summary2().tables[1]
@@ -237,11 +249,16 @@ def iterative_shap_feature_reduction(
237249
shap_features, target, task, alpha
238250
)
239251

252+
if significance_df["t-value"].isna().all():
253+
collected_rows.extend(significance_df.to_dict('records'))
254+
break
255+
256+
240257
# Find the feature with the lowest t-value
241258
min_t_value_row = significance_df.loc[significance_df["t-value"].idxmin()]
242259

243260
# Remember this row (collect it in our list)
244-
collected_rows.append(min_t_value_row)
261+
collected_rows.append(min_t_value_row.to_dict())
245262

246263
# Drop the feature corresponding to the lowest t-value from shap_features
247264
feature_to_remove = min_t_value_row["feature name"]

0 commit comments

Comments
 (0)