|
4 | 4 | import statsmodels.api as sm |
5 | 5 | import scipy.stats as stats |
6 | 6 | import shap |
| 7 | +import numpy as np |
7 | 8 |
|
8 | 9 |
|
9 | 10 | def create_shap_features( |
@@ -65,9 +66,20 @@ def binary_classifier_significance( |
65 | 66 | # Add a constant to the features for the intercept in logistic regression |
66 | 67 | shap_features_with_constant = sm.add_constant(shap_features) |
67 | 68 |
|
68 | | - # Fit the logistic regression model that will generate confidence intervals |
69 | | - logit_model = sm.Logit(target, shap_features_with_constant) |
70 | | - result = logit_model.fit_regularized(disp=False, alpha=alpha) |
| 69 | + alpha_in_loop = alpha |
| 70 | + max_retries = 10 # Set a maximum number of retries to avoid infinite loops |
| 71 | + for attempt in range(max_retries): |
| 72 | + try: |
| 73 | + # Fit the logistic regression model that will generate confidence intervals |
| 74 | + logit_model = sm.Logit(target, shap_features_with_constant) |
| 75 | + result = logit_model.fit_regularized(disp=False, alpha=alpha_in_loop) |
| 76 | + break # Exit the loop if successful |
| 77 | + except np.linalg.LinAlgError as ex: # Catch Singular Matrix or related issues |
| 78 | + alpha_in_loop *= 5 # Increase alpha exponentially to avoid singular matrix problem |
| 79 | + except Exception as ex: # Catch any other exception |
| 80 | + raise RuntimeError(ex) # Re-raise the exception for debugging or further handling |
| 81 | + else: |
| 82 | + raise RuntimeError("Logistic regression failed to converge after maximum retries.") |
71 | 83 |
|
72 | 84 | # Extract the results |
73 | 85 | summary_frame = result.summary2().tables[1] |
@@ -237,11 +249,16 @@ def iterative_shap_feature_reduction( |
237 | 249 | shap_features, target, task, alpha |
238 | 250 | ) |
239 | 251 |
|
| 252 | + if significance_df["t-value"].isna().all(): |
| 253 | + collected_rows.extend(significance_df.to_dict('records')) |
| 254 | + break |
| 255 | + |
| 256 | + |
240 | 257 | # Find the feature with the lowest t-value |
241 | 258 | min_t_value_row = significance_df.loc[significance_df["t-value"].idxmin()] |
242 | 259 |
|
243 | 260 | # Remember this row (collect it in our list) |
244 | | - collected_rows.append(min_t_value_row) |
| 261 | + collected_rows.append(min_t_value_row.to_dict()) |
245 | 262 |
|
246 | 263 | # Drop the feature corresponding to the lowest t-value from shap_features |
247 | 264 | feature_to_remove = min_t_value_row["feature name"] |
|
0 commit comments