#include <array>
#include <memory>
#include <list>
#include <lp_lib.h>
#include "utilities_cpp.h"
#include "lpSolve.h"
#include <optional>

// For every posterior draw, obtain the Cholesky factor of the covariance
// matrix via build_sigma and store its transpose as one slice of the result.
// Factor models pass the draw's loadings to build_sigma; otherwise the draw's
// column of U_vecs is passed.
// Returns an n_variables x n_variables x n_posterior_draws cube.
inline cube Sigma_chol_t_draws (
	const uword n_variables,
	const uword n_posterior_draws,
	const cube& factor_loadings, //rows: variables, columns: factors, slices: draws
	const mat& U_vecs, //rows: entries of a (variables x variables)-upper triangular matrix with ones on the diagonals, cols: draws
	const mat& logvar_T
) {
	const uword n_factors = factor_loadings.n_cols;
	const bool factor_model = n_factors > 0;

	cube chol_t(n_variables, n_variables, n_posterior_draws, fill::none);
	for (uword draw = 0; draw < n_posterior_draws; ++draw) {
		// only the input matching the model type is filled; the other stays empty
		mat facload_draw;
		vec u_draw;
		if (factor_model) facload_draw = factor_loadings.slice(draw);
		else u_draw = U_vecs.col(draw);

		mat sigma_ignored, sigma_chol;
		build_sigma(sigma_ignored, sigma_chol, factor_model, facload_draw,
				logvar_T.col(draw).as_row(), n_factors,
				n_variables, u_draw, true);

		chol_t.slice(draw) = sigma_chol.t();
	}
	return chol_t;
}

// Computes derived parameters for every posterior draw, but only those whose
// names appear in `restrictions`; intermediate results are computed only as far
// as the requested outputs need them.
//
// Factor model (factor_loadings non-empty):
//   "facload" -> factor loadings with column k scaled by exp(logvar_T(k, r) / 2)
//                (first n_factors rows of logvar_T)
//   "IR_inf"  -> inv((I - sum of lag blocks of reduced_coefficients)') * facload
// Cholesky model:
//   "B0_inv_t"         -> transposed Cholesky factors (Sigma_chol_t_draws)
//   "B0"               -> inverse of B0_inv_t' treated as upper triangular
//   "structural_coeff" -> reduced_coefficients * B0
//   "IR_inf"           -> inv((B0 - sum of lag blocks of structural_coeff)')
//
// NOTE(review): the long-run entry is looked up as "restrictions_long_run_ir"
// while the other keys carry no "restrictions_" prefix; confirm this matches
// the names produced on the R side.
// [[Rcpp::export]]
Rcpp::List compute_parameter_transformations (
	const arma::cube& reduced_coefficients, //rows: lagged variables + intercept, columns: variables, slices: draws
	const arma::cube& factor_loadings, //rows: variables, columns: factors, slices: draws
	const arma::mat& U_vecs, //rows: entries of a (variables x variables)-upper triagonal matrix with ones on the diagonals, cols: draws
	const arma::mat& logvar_T, //cols: draws
	const Rcpp::List& restrictions
) {
	const uword n_variables = reduced_coefficients.n_cols;
	const uword n_posterior_draws = reduced_coefficients.n_slices;
	const uword n_factors = factor_loadings.n_cols;
	
	const bool include_facload = restrictions.containsElementNamed("facload");
	const bool include_B0_inv_t = restrictions.containsElementNamed("B0_inv_t");
	const bool include_B0 = restrictions.containsElementNamed("B0");
	const bool include_structural_coeff = restrictions.containsElementNamed("structural_coeff");
	const bool include_long_run_ir = restrictions.containsElementNamed("restrictions_long_run_ir");
	
	// Sum of the consecutive n_variables x n_variables lag blocks of a
	// coefficient matrix. A block starting at row l is only summed if at least
	// one more row follows it, so the trailing intercept row is excluded.
	// NOTE(review): without an intercept row the last lag block would be
	// dropped as well; the layout "lags, then intercept" is assumed.
	const auto sum_of_lag_blocks = [n_variables](const mat& coeff) {
		// explicitly zero-filled: Armadillo before 10.5 leaves mat(n, n) uninitialized
		mat total(n_variables, n_variables, fill::zeros);
		for (uword l = 0; l + n_variables < coeff.n_rows; l += n_variables) {
			total += coeff.rows(l, l + n_variables - 1);
		}
		return total;
	};
	
	Rcpp::List ret = Rcpp::List::create();
	
	// parameter transformations available for the factor model
	if (factor_loadings.n_elem > 0) {
		if (!(include_facload || include_long_run_ir)) return ret;
		cube factor_loadings_T (factor_loadings);
		for (uword r = 0; r < n_posterior_draws; r++) {
			factor_loadings_T.slice(r).each_row() %= exp(logvar_T.col(r).head(n_factors)/2).t();
		}
		if (include_facload) ret.push_back(factor_loadings_T, "facload");
		
		if (!include_long_run_ir) return ret;
		cube IR_inf(n_variables, n_factors, n_posterior_draws, fill::none);
		const mat I = eye(n_variables, n_variables);
		for (uword r = 0; r < n_posterior_draws; r++) {
			const mat sum_of_lags = sum_of_lag_blocks(reduced_coefficients.slice(r));
			IR_inf.slice(r) = inv((I - sum_of_lags).t()) * factor_loadings_T.slice(r);
		}
		ret.push_back(IR_inf, "IR_inf");
		return ret;
	}
	
	// parameter transformations available for the Cholesky model
	if (!(include_B0_inv_t || include_B0 || include_structural_coeff || include_long_run_ir)) return ret;
	const cube B0_inv_t = Sigma_chol_t_draws(
		n_variables,
		n_posterior_draws,
		factor_loadings,
		U_vecs,
		logvar_T
	);
	if (include_B0_inv_t) ret.push_back(B0_inv_t, "B0_inv_t");

	if (!(include_B0 || include_structural_coeff || include_long_run_ir)) return ret;
	cube B0(n_variables, n_variables, n_posterior_draws, fill::none);
	for (uword r = 0; r < n_posterior_draws; r++) {
		B0.slice(r) = inv(trimatu(B0_inv_t.slice(r).t()));
	}
	if (include_B0) ret.push_back(B0, "B0");

	if (!(include_structural_coeff || include_long_run_ir)) return ret;
	cube structural_coeff(reduced_coefficients);
	for (uword r = 0; r < n_posterior_draws; r++) {
		structural_coeff.slice(r) *= B0.slice(r);
	}
	if (include_structural_coeff) ret.push_back(structural_coeff, "structural_coeff");

	if (!include_long_run_ir) return ret;
	cube IR_inf(n_variables, n_variables, n_posterior_draws, fill::none);
	for (uword r = 0; r < n_posterior_draws; r++) {
		const mat sum_of_lags = sum_of_lag_blocks(structural_coeff.slice(r));
		IR_inf.slice(r) = inv((B0.slice(r) - sum_of_lags).t());
	}
	ret.push_back(IR_inf, "IR_inf");
	return ret;
}

// Builds the zero-restriction selection matrix for one column of a
// restriction specification.
// See doi:10.1111/j.1467-937X.2009.00578.x section "5.2. A monetary SVAR"
// for an example of what the matrices produced by this function look like.
// In short: for every index j with spec[j] == 0, the unit vector e_j is added
// as a row of the result (NA entries compare unequal to 0 and add nothing).
// Returns std::nullopt when spec contains no zero restrictions.
std::optional<mat> construct_zero_restriction(const NumericMatrix::ConstColumn& spec) {
	const auto n = spec.size();
	// explicitly zero-filled: only the selected unit entries are written below
	mat zero_restriction_matrix(n, n, fill::zeros);
	uword n_restrictions = 0;
	for (R_xlen_t j = 0; j < n; j++) {
		if (spec[j] == 0) {
			zero_restriction_matrix(n_restrictions++, j) = 1;
		}
	}
	if (n_restrictions == 0) {
		// A 0 x n matrix would be perfectly valid here, but Armadillo triggers
		// CRAN's UBSAN checks for matrices of size zero, so return nothing.
		return std::nullopt;
	}
	return zero_restriction_matrix.head_rows(n_restrictions);
}

// Builds the sign-restriction matrix for one column of a restriction
// specification: for every index j with spec[j] > 0 the unit vector e_j, and
// for every j with spec[j] < 0 the vector -e_j, is added as a row of the
// result. Entries that are 0 or NA produce no row.
// Returns std::nullopt when spec contains no sign restrictions (a zero-sized
// Armadillo matrix would trigger CRAN's UBSAN checks).
std::optional<mat> construct_sign_restriction (const NumericMatrix::ConstColumn& spec) {
	const auto n = spec.size();
	// explicitly zero-filled: only the selected +/-1 entries are written below
	mat sign_restriction_matrix(n, n, fill::zeros);
	uword n_restrictions = 0;
	for (R_xlen_t j = 0; j < n; j++) {
		if (spec[j] > 0) {
			sign_restriction_matrix(n_restrictions++, j) = 1;
		}
		else if (spec[j] < 0) {
			sign_restriction_matrix(n_restrictions++, j) = -1;
		}
		// else: spec[j] is either 0 or NA, so not a sign restriction
	}
	if (n_restrictions == 0) {
		return std::nullopt;
	}
	return sign_restriction_matrix.head_rows(n_restrictions);
}

// Interface of the per-column solvers used by find_rotation_cpp.
// Constraints are accumulated with add_constraints(), solve() returns a
// candidate vector, recycle() discards the constraints added since
// construction so the object can be reused, and print() dumps the state.
struct Solver {
	virtual ~Solver() = default;

	virtual void add_constraints(mat& constraints, int constr_type, double rhs) = 0;
	virtual vec solve() = 0;
	virtual void recycle() = 0;
	virtual void print() = 0;
};

class LPSolver : public Solver {
	using make_lp_func_ptr= lprec*(*)(int, int);
	// template<typename R, typename... T> using lp_func_ptr = R(*)(lprec*, T...);
	using get_statustext_func_ptr = char* (*)(lprec*, int);
	
	private:
	static constexpr double bigM = 10.0;
	
	//import remaining routines from the R package "lpSolveAPI"
	make_lp_func_ptr make_lp = (make_lp_func_ptr)R_GetCCallable("lpSolveAPI", "make_lp");
	get_statustext_func_ptr get_statustext = (get_statustext_func_ptr) R_GetCCallable("lpSolveAPI", "get_statustext");
	
	lprec *lp;
	uword n; //dimension of x
	ivec x_cols;
	ivec U_cols;
	ivec y_cols;
	int n_initial_rows;

	public:
	LPSolver(const uword dim_x) : n(dim_x) {
		lp = (*make_lp)(0, 3*n);
		if (lp == NULL) throw std::bad_alloc();
		
		setverbose(lp, IMPORTANT);
		setmaxim(lp);
		
		// we have variables x_i, U_i, y_i for i=1...n
		x_cols = regspace<ivec>(1, n);
		U_cols = x_cols + n;
		y_cols = U_cols + n;
		for (uword j = 0; j < n; j++) {
			std::string col_name("x");
			col_name += std::to_string(j+1);
			setcol_name(lp, x_cols[j], col_name.data());
			col_name[0] = 'U';
			setcol_name(lp, U_cols[j], col_name.data());
			col_name[0] = 'y';
			setcol_name(lp, y_cols[j], col_name.data());
		}
		
		setadd_rowmode(lp, TRUE);
				
		// maximize over sum of U
		vec ones(U_cols.n_elem, fill::ones);
		setobj_fnex(lp, U_cols.n_elem, ones.memptr(), U_cols.memptr());
		
		// set variable bounds
		for (uword j = 0; j < n; j++) {
			setbounds(lp, x_cols[j], -1, 1); //x_i in [-1, 1]
			setbounds(lp, U_cols[j], 0, 1); //U_i in [0, 1]
			setbinary(lp, y_cols[j], TRUE); //y_i in {0, 1}
		}
		
		//setup initial constraints
		//x_i must not be within the interval [-U_i, U_i]
		//y_i switches if x_i is right to the interval or left to the interval
		//=> in an optimal solution, U_i is the greatest possible value of |x_i|
		std::array<double, 3> left_constraint  = { 1.0 /* x_i */, 1.0 /*U_i*/, -bigM /*y*/};
		std::array<double, 3> right_constraint = {-1.0 /* x_i */, 1.0 /*U_i*/,  bigM /*y*/};
		for (uword j = 0; j < n; j++) {
			std::array<int, 3> cols = {x_cols[j], U_cols[j], y_cols[j]};
			addconstraintex(lp, cols.size(), left_constraint.data(), cols.data(), LE, 0);
			addconstraintex(lp, cols.size(), right_constraint.data(), cols.data(), LE, bigM);
		}
		
		//recycling will remove all constraints added by `add_constraint`
		n_initial_rows = getNrows(lp);
	}
	
	void add_constraints(mat& constraints, int constr_type, double rhs) override {
		constraints.each_row([&](rowvec& row) {
			const bool success = addconstraintex(lp, n, row.memptr(), x_cols.memptr(), constr_type, rhs);
			if (success == FALSE) {
				throw std::runtime_error("Could not add constraints");
			}
		});
	}
	
	vec solve() override {
		setadd_rowmode(lp, FALSE);
		int ret = lpsolve(lp);
		if(ret != 0) {
			print();
			throw std::logic_error((*get_statustext)(lp, ret));
		};
		vec lp_vars_store(3*n);
		getvariables(lp, lp_vars_store.memptr());
		return lp_vars_store.head_rows(n); // only return x
	}
	
	void recycle() override {
		// delete all non-initial constraints
		resizelp(lp, n_initial_rows, getNcolumns(lp));
		setadd_rowmode(lp, TRUE);
	}
	
	void print() override {
		setadd_rowmode(lp, FALSE);
		printlp(lp);
	}
		
	~LPSolver() {
		deletelp(lp);
	}
	
};

class RandomizedSolver : public Solver {
	private:
		mat zero_restrictions;
		mat sign_restrictions;
		const vec zero;
		const uword n_shocks;
		const double tol;
	public:
	RandomizedSolver(const uword n_shocks, const double tol) :
		zero_restrictions(mat(0, n_shocks)),
		sign_restrictions(mat(0, n_shocks)),
		zero(arma::zeros(n_shocks)),
		n_shocks(n_shocks), tol(tol)
	{}
	
	void add_constraints(mat& constraints, int constr_type, double rhs) override {
		if ( constr_type == EQ && rhs == 0) {
			zero_restrictions = join_vert(zero_restrictions, constraints);
		}
		else if ( constr_type == GE && rhs == 0) {
			sign_restrictions = join_vert(sign_restrictions, constraints);
		}
		else throw std::domain_error("Constraint type not implemented");
	};
	virtual vec solve() override {
		mat zero_restrictions_solution_space;
		if (zero_restrictions.n_rows == 0) {
			zero_restrictions_solution_space = eye(n_shocks, n_shocks);
		} else {
			zero_restrictions_solution_space = null(zero_restrictions, tol);
		}

		// draw random unit-length vector
		// here we are not allowed to draw until we find a vector that
		// satisfies the sign restrictions!
		// see https://doi.org/10.3982/ECTA14468 -> Ctrl+F "tempting"
		vec y(zero_restrictions_solution_space.n_cols);
		y.imbue(R::norm_rand);
		y = normalise(y);
		
		const vec q_candidate = zero_restrictions_solution_space*y;
		if (all(sign_restrictions * q_candidate >= 0-tol)) {
			return q_candidate;
		}
		else if (sign_restrictions.n_rows == 1) {
			//if there is exactly one sign restriction and it is not satisfied
			//switching the sign will satisfy the sign restriction
			return -q_candidate;
		}
		else return zero;
	};
	virtual void recycle() override {
		zero_restrictions = mat(0, zero_restrictions.n_cols);
		sign_restrictions = mat(0, sign_restrictions.n_cols);
	};
	virtual void print() override {
		Rcout << "Zero restrictions:" << endl << zero_restrictions <<
		         "Sign restrictions:" << endl << sign_restrictions;
	};
};

// Searches, for each posterior draw, for rotation matrices with orthonormal
// columns such that every column satisfies the zero and sign restrictions of
// `restriction_specs` applied to the corresponding `parameter_transformations`.
// Columns are constructed one at a time (most restricted first); each must be
// orthogonal to the columns built before it.
//
// Returns a list with
//   rotation:            cube of accepted rotations (variables x variables x accepted)
//   rotation_sample_map: one-based index of the posterior draw each accepted
//                        rotation belongs to
// Draws for which the solver finds no admissible column are skipped; with the
// "randomized" solver up to `randomized_max_rotations_per_sample` rotations
// may be accepted per draw.
// [[Rcpp::export]]
Rcpp::List find_rotation_cpp(
	const arma::field<arma::cube>& parameter_transformations, //each field element: rows: transformation size, cols: variables, slices: draws
	const arma::field<Rcpp::NumericMatrix>& restriction_specs, //each field element: rows: transformation size, cols: variables
	const std::string& solver_option = "randomized", // "randomized" or "lp"
	const arma::uword randomized_max_rotations_per_sample = 2, // ignored when solver is "lp"
	const double tol = 1e-6
) {
	if (restriction_specs.n_elem != parameter_transformations.n_elem) {
		throw std::logic_error("Number of restrictions does not match number of parameter transformations.");
	}
	if (parameter_transformations.n_elem == 0) {
		// dimensions are read from the first transformation below
		throw std::logic_error("At least one parameter transformation is required.");
	}

	const uword n_transformations = parameter_transformations.n_elem;
	const uword n_variables = parameter_transformations(0).n_cols;
	const uword n_posterior_draws = parameter_transformations(0).n_slices;
	
	// rotation at position i belongs to the sample with one-based index rotation_sample_map[i]
	std::list<uword> rotation_sample_map;
	std::list<mat> rotations;
	// the number of rejected samples is unknown in advance, so the returned
	// cube is assembled only after the search

	std::unique_ptr<Solver> solver;
	uword max_rotations_per_sample = 0;
	if (solver_option == "randomized") {
		max_rotations_per_sample = randomized_max_rotations_per_sample;
		solver = std::make_unique<RandomizedSolver>(n_variables, tol);
	}
	else if (solver_option == "lp") {
		max_rotations_per_sample = 1;
		solver = std::make_unique<LPSolver>(n_variables);
	}
	else throw std::domain_error("Unknown solver option");
	
	//field rows: transformations, field cols: columns of the transformation
	//each field element: rows: number of restrictions, cols: transformation size
	field<mat> zero_restrictions(n_transformations, n_variables);
	field<mat> sign_restrictions(n_transformations, n_variables);
	uvec n_zero_restrictions(n_variables, fill::zeros);
	uvec n_sign_restrictions(n_variables, fill::zeros);
	for (uword i = 0; i < n_transformations; i++) {
		// placeholder with the column count of a real restriction matrix of transformation i
		const mat no_restrictions(0, static_cast<uword>(restriction_specs(i).nrow()));
		for (uword j = 0; j < n_variables; j++) {
			const NumericMatrix::ConstColumn column_restriction_spec = restriction_specs(i).column(j);
			zero_restrictions(i, j) = construct_zero_restriction(column_restriction_spec).value_or(no_restrictions);
			sign_restrictions(i, j) = construct_sign_restriction(column_restriction_spec).value_or(no_restrictions);
			n_zero_restrictions(j) += zero_restrictions(i, j).n_rows; //rank = n_rows by construction!
			n_sign_restrictions(j) += sign_restrictions(i, j).n_rows; //rank = n_rows by construction!
		}
	}
		
	//iterate over columns in order of descending number of restrictions.
	//since each column must be orthogonal to the ones that came before,
	//we start with the column with the most restrictions
	const uvec col_order = sort_index(2 * n_zero_restrictions + n_sign_restrictions, "descend");
	
	for (uword r = 0; r < n_posterior_draws; r++) {
		if (r % 100 == 0) Rcpp::checkUserInterrupt();
		for (uword attempt = 0; attempt < max_rotations_per_sample; attempt++) {
			mat rotation(n_variables, n_variables, fill::none);
			bool reject_draw = false;
			for (uword j_index = 0; j_index < n_variables; j_index++) {
				const uword j = col_order[j_index];
				
				// the column j of the rotation matrix must be orthogonal to columns that came before
				if (j_index > 0) {
					mat previous_columns = rotation.cols(col_order.head(j_index)).t();
					solver->add_constraints(previous_columns, EQ, 0);
				}
				
				// Empty restriction matrices are skipped: adding zero rows is a
				// no-op, and multiplying them would require their column count to
				// match the transformation's row count.
				for (uword i = 0; i < n_transformations; i++) {
					if (zero_restrictions(i, j).n_rows == 0) continue;
					mat zero_constraints = zero_restrictions(i, j) * parameter_transformations(i).slice(r);
					solver->add_constraints(zero_constraints, EQ, 0);
				}
				
				for (uword i = 0; i < n_transformations; i++) {
					if (sign_restrictions(i, j).n_rows == 0) continue;
					mat sign_constraints = sign_restrictions(i, j) * parameter_transformations(i).slice(r);
					solver->add_constraints(sign_constraints, GE, 0);
				}
				
				const vec p_j = solver->solve().clean(tol);
				solver->recycle();
				if (p_j.is_zero()) {
					//zero was the optimal solution
					//reject this draw.
					reject_draw = true;
					break;
				}
				rotation.col(j) = normalise(p_j);
			}
			if (!reject_draw) {
				rotation_sample_map.push_back(r+1); //one-based index because R
				rotations.push_back(std::move(rotation));
			}
		}
	}
	
	// convert `rotations` to a cube
	cube rotation_cube(n_variables, n_variables, rotations.size());
	uword slice_index = 0;
	for (const mat& rot : rotations) {
		rotation_cube.slice(slice_index++) = rot;
	}
	
	Rcpp::List ret = Rcpp::List::create();
	ret["rotation"] = rotation_cube;
	ret["rotation_sample_map"] = rotation_sample_map;
	return ret;
}

// Moves every lag block of X one lag further back and writes new_y into the
// leading (most recent) block.
//
// X is expected to hold blocks of new_y.n_cols columns followed by a trailing
// intercept column: y1,y2,y3, y1.l1,y2.l1,y3.l1,...,1. The shift never touches
// the last column, and the content of the oldest block is discarded.
inline void shift_and_insert(
	mat& X, //the columns of X should be y1,y2,y3, y1.l1,y2.l1,y3.l1,...,1
	const mat& new_y //what to insert in y1,y2,y3
) {
	const uword block = new_y.n_cols;
	// nothing to insert; also keeps the unsigned countdown below from wrapping
	if (block == 0) return;

	// Shift right by `block`, iterating from the back so every source column is
	// read before it is overwritten. The guard avoids the unsigned underflow of
	// X.n_cols - 2 for very narrow X; otherwise the iterations are unchanged.
	if (X.n_cols >= block + 2) {
		for (uword i = X.n_cols - 2; i >= block; i--) {
			X.col(i) = X.col(i - block);
		}
	}
	X.head_cols(block) = new_y;
}

// Computes impulse responses for every posterior draw r.
//
// The horizon-0 response is built from rotated_shocks, which is
// rotation.slice(r) * shocks when a rotation is supplied and shocks otherwise:
//   - factor model:   facload_r * diag(exp(h/2)) * rotated_shocks, h being the
//                     first n_factors entries of logvar_t.col(r)
//   - Cholesky model: solve(U_r', diag(exp(logvar_t.col(r)/2)) * rotated_shocks)
//                     with U_r unit upper triangular, filled from U_vecs.col(r)
//   - otherwise:      rotated_shocks itself.
// Horizons 1..ahead are propagated through the reduced-form coefficients,
// starting from a zero predictor matrix.
// Returns one (variables x shocks x (ahead + 1)) cube per draw.
// [[Rcpp::export]]
arma::field<arma::cube> irf_cpp(
	const arma::cube& coefficients, //rows: lagged variables + intercept, columns: variables, slices: draws
	const arma::cube& factor_loadings, //rows: variables, columns: factors, slices: draws
	const arma::mat& U_vecs, //rows: entries of a (variables x variables)-upper triagonal matrix with ones on the diagonals, cols: draws
	const arma::mat& logvar_t, //rows: log variances, cols: draws
	const arma::mat& shocks, //rows: dim shock, cols: shocks
	const arma::uword ahead, //how far to predict ahead
	const Rcpp::Nullable<Rcpp::NumericMatrix> rotation_ = R_NilValue //rows: variables, cols: dim shock, slices: draws
) {
	const uword n_shocks = shocks.n_cols;
	const uword n_variables = coefficients.n_cols;
	const uword n_posterior_draws = coefficients.n_slices;
	const bool is_factor_model = factor_loadings.n_cols > 0;
	const bool is_cholesky_model = U_vecs.n_cols > 0;
	const uvec upper_indices = trimatu_ind(size(n_variables, n_variables), 1);
	cube rotation;
	if (rotation_.isNotNull()) {
		rotation = Rcpp::as<cube>(rotation_);
	}
	const bool is_rotated = rotation.n_slices > 0;
	// report a clear error instead of an out-of-bounds slice access in the loop
	if (is_rotated && rotation.n_slices < n_posterior_draws) {
		throw std::invalid_argument("rotation must contain one slice per posterior draw");
	}

	field<cube> ret(n_posterior_draws);
	for (uword r = 0; r < n_posterior_draws; r++) {
		cube irf(n_variables, n_shocks, ahead+1);
		
		//compute the responses to the shocks at t=0
		mat rotated_shocks = is_rotated ? mat(rotation.slice(r) * shocks) : shocks;
		if (is_factor_model) {
			const vec sqrt_V_t = exp(logvar_t.col(r).head(factor_loadings.n_cols) / 2);
			rotated_shocks.each_col() %= sqrt_V_t;
			irf.slice(0) = factor_loadings.slice(r) * rotated_shocks;
		}
		else if (is_cholesky_model) {
			mat U(n_variables, n_variables, fill::eye);
			U(upper_indices) = U_vecs.col(r);

			const vec sqrt_D_t = exp(logvar_t.col(r) / 2.0);
			rotated_shocks.each_col() %= sqrt_D_t;
			irf.slice(0) = solve(trimatl(U.t()), rotated_shocks);
		}
		else {
			irf.slice(0) = rotated_shocks;
		}
		
		// compute how the shocks propagate using the reduced form coeffs
		mat current_predictors(n_shocks, coefficients.n_rows, fill::zeros);
		current_predictors.head_cols(n_variables) = irf.slice(0).t(); //set lag zero
		for (uword t = 1; t <= ahead; t++) {
			const mat new_predictors = current_predictors * coefficients.slice(r);
			irf.slice(t) = new_predictors.t();
			// shift everything and make the predictions the new predictors at lag zero
			shift_and_insert(current_predictors, new_predictors);
		}
		ret(r) = irf;
	}
	
	return ret;
}

// Orders posterior draws by their summed L1 distance (accu(abs(.))) to the
// impulse responses of all other draws, smallest total first.
// NOTE(review): the indices come straight from sort_index and are therefore
// zero-based; confirm the R caller converts them before indexing.
// [[Rcpp::export]]
arma::ivec irf_bayes_optimal_order(arma::field<arma::cube>& irf) {
	const uword n_draws = irf.n_elem;
	vec total_loss(n_draws, fill::zeros);
	for (uword a = 0; a < n_draws; ++a) {
		if (a % 100 == 0) Rcpp::checkUserInterrupt();
		const cube& irf_a = irf(a);
		// each unordered pair is visited once and credited to both members
		for (uword b = 0; b < a; ++b) {
			const double pair_loss = accu(abs(irf_a - irf(b)));
			total_loss(a) += pair_loss;
			total_loss(b) += pair_loss;
		}
	}
	const uvec order = sort_index(total_loss, "ascend");
	return conv_to<ivec>::from(order);
}

// Impulse responses implied by known parameters: the horizon-0 response is
// inv(true_structural_matrix)', and horizons 1..ahead are propagated with
// true_reduced_coeff, starting from a zero predictor matrix.
// Returns a (rows x cols of true_structural_matrix x (ahead + 1)) cube.
// [[Rcpp::export]]
arma::cube irf_from_true_parameters(
	arma::mat true_structural_matrix,
	arma::mat true_reduced_coeff,
	arma::uword ahead
) {
	const uword n_resp = true_structural_matrix.n_rows;
	const uword n_imp = true_structural_matrix.n_cols;

	cube responses(n_resp, n_imp, ahead + 1);
	const mat impact = inv(true_structural_matrix).t();
	responses.slice(0) = impact;

	mat lagged(n_imp, true_reduced_coeff.n_rows, fill::zeros);
	lagged.head_cols(n_resp) = impact.t();
	for (uword h = 1; h <= ahead; ++h) {
		const mat next = lagged * true_reduced_coeff;
		responses.slice(h) = next.t();
		shift_and_insert(lagged, next);
	}
	return responses;
}
