For:
We will build a SARIMAX model for monthly data on an industrial production time series for the 1988-2017 period. As illustrated in the first section on analytical tools, the data has been log-transformed, and we are using seasonal (lag-12) differences. We estimate the model for a range of both ordinary and conventional AR and MA parameters using a rolling window of 10 years of training data, and evaluate the RMSE of the 1-step-ahead forecast.
Finding the optimal number of lags
from joblib import Parallel, delayed
from tqdm import tqdm
train_size = 120  # 10 years of training data
test_set = industrial_production_log_diff.iloc[train_size:]
def fit_and_predict(params, train_size, industrial_production_log_diff):
    p1, q1, p2, q2 = params
    preds = test_set.copy().to_frame('y_true').assign(y_pred=np.nan)
    aic, bic = [], []
    if p1 == 0 and q1 == 0:
        return None
    convergence_error = stationarity_error = 0
    y_pred = []
    for i, T in enumerate(range(train_size, len(industrial_production_log_diff))):
        train_set = industrial_production_log_diff.iloc[T-train_size:T]
        try:
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore")
                model = SARIMAX(endog=train_set.values,
                                order=(p1, 0, q1),
                                seasonal_order=(p2, 0, q2, 12)).fit(disp=0)
        except LinAlgError:
            convergence_error += 1
            continue
        except ValueError:
            stationarity_error += 1
            continue
        preds.iloc[i, 1] = model.forecast(steps=1)[0]
        aic.append(model.aic)
        bic.append(model.bic)
    preds.dropna(inplace=True)
    if preds.empty:
        return None
    mse = mean_squared_error(preds.y_true, preds.y_pred)
    return (p1, q1, p2, q2), [np.sqrt(mse),
                              preds.y_true.sub(preds.y_pred).pow(2).std(),
                              np.mean(aic),
                              np.std(aic),
                              np.mean(bic),
                              np.std(bic),
                              convergence_error,
                              stationarity_error]
# Parallel processing
results = Parallel(n_jobs=-1)(delayed(fit_and_predict)(params, train_size, industrial_production_log_diff) for params in tqdm(params))
# Filter out None results
results = {k: v for k, v in results if v is not None}
This version introduces parallel processing. 100% in ~14 minutes.
The following are some of the cells and their respective outputs just to show that the refactored code will not break anything that follows.
sarimax_results.nsmallest(5, columns='RMSE')
sarimax_results[['RMSE', 'AIC', 'BIC']].sort_values('RMSE').head()

sns.jointplot(y='RMSE', x='BIC', data=sarimax_results[['RMSE', 'BIC']].rank());
.
.
.
print(best_model.summary())
plot_correlogram(pd.Series(best_model.resid),                  lags=20,                  title=f'SARIMAX ({p1}, 0, {q1}) x ({p2}, 0, {q2}, 12) | Model Diagnostics')
Hope this helps.