/*  Deadwood: Outlier Detection via Minimum Spanning Trees
 *
 *  Copyleft (C) 2025-2026, Marek Gagolewski <https://www.gagolewski.com>
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License
 *  Version 3, 19 November 2007, published by the Free Software Foundation.
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU Affero General Public License Version 3 for more details.
 *  You should have received a copy of the License along with this program.
 *  If this is not the case, refer to <https://www.gnu.org/licenses/>.
 */


#ifndef __c_deadwood_h
#define __c_deadwood_h


#include "c_common.h"
#include "c_argfuns.h"
#include "c_kneedle.h"
#include <algorithm>
#include <cmath>
#include <memory>
#include <stdexcept>
#include <vector>


/*! Reorders x w.r.t. a factor c (a stable counting sort by group)
 *
 * y[ind[j]],...,y[ind[j+1]-1] give all x[i]s, in their original relative order,
 * for which c[i]==j.
 *
 * Elements corresponding to c[i] < 0 are put at the start of y,
 * i.e., in y[0],...,y[ind[0]-1].
 *
 * c[i] >= k is disallowed.  If assertions are compiled out, such elements
 * are ignored (neither counted nor copied); the tail of y from y[ind[k]]
 * onwards is then left untouched.
 *
 * @param x [in] array of size n
 * @param n
 * @param c [in] array of size n with elements in {...,0,1,..,k-1}
 * @param k
 * @param y [out] array of size n
 * @param ind [out] array of size k+1
 */
template <class FLOAT>
void Csort_groups(
    const FLOAT* x, Py_ssize_t n, const Py_ssize_t* c, Py_ssize_t k,
    FLOAT* y, Py_ssize_t* ind
) {
    for (Py_ssize_t j=0; j<=k; ++j) ind[j] = 0;

    // pass 1 (count): ind[0] counts the negative labels,
    // ind[j+1] counts the elements with c[i]==j
    for (Py_ssize_t i=0; i<n; ++i) {
        DEADWOOD_ASSERT(c[i] < k);
        if (c[i] < 0)
            ++ind[0];
        else if (c[i] < k)
            ++ind[c[i]+1];
    }

    // exclusive prefix sums: ind[0]==0 is where the negative-label block
    // starts, and ind[j+1] is where group j starts
    Py_ssize_t u = ind[0];
    ind[0] = 0;
    for (Py_ssize_t j=1; j<=k; ++j) {
        Py_ssize_t v = ind[j];
        ind[j] = u;  // sum of the original ind[0]..ind[j-1]
        u += v;
    }

    // pass 2 (scatter): each write position is advanced past the element
    // just stored; at the end, ind[j] is the start of group j (j<k)
    // and ind[k] is the end of group k-1
    for (Py_ssize_t i=0; i<n; ++i) {
        if (c[i] < 0)
            y[ind[0]++] = x[i];
        else if (c[i] < k)  // same guard as in pass 1: no out-of-bounds writes
            y[ind[c[i]+1]++] = x[i];
    }
}


/*! Identifies which MST edges must be skipped to obtain a forest whose
 *  connected components match a given partition.  If this is not possible,
 *  a more fine-grained split is generated.
 *
 *  This is as easy as finding all MST edges {u,v} for which c[u]≠c[v].
 *
 *  @param mst_i c_contiguous matrix of size m*2,
 *     where {mst_i[k,0], mst_i[k,1]} specifies the k-th (undirected) edge
 *     in the spanning tree
 *  @param m number of rows in mst_i (edges)
 *  @param c [in] array of length n, where
 *      c[i] denotes the cluster ID of the i-th object
 *  @param n length of c and the number of vertices in the spanning tree
 *  @param skip [out] array of length m, indicating which edges
 *      of the tree must be skipped to create a subpartition of c
 *
 *  @return s number of edges in skip;  ideally, s=k-1, where k is the
 *      number of classes in c
 */
inline Py_ssize_t Cget_skip_edges(  // inline: defined in a header
    const Py_ssize_t* mst_i,  // size m*2 [in]
    Py_ssize_t m,
    const Py_ssize_t* c,  // size n [in]
    Py_ssize_t n,
    bool* skip  // size m [out]
) {
    Py_ssize_t s = 0;

    for (Py_ssize_t i=0; i<m; ++i) {
        Py_ssize_t u = mst_i[2*i+0];
        Py_ssize_t v = mst_i[2*i+1];
        DEADWOOD_ASSERT(u >= 0 && u < n);
        DEADWOOD_ASSERT(v >= 0 && v < n);

        skip[i] = (c[u] != c[v]);  // the edge joins two different classes
        if (skip[i]) ++s;
    }

    return s;
}

/*! Decode indexes based on a skip array.
 *
 * If `skip=[False, True, False, False, True, False, False]`,
 * then the indexes in `ind` are mapped in such a way that:
 * 0 → 0,
 * 1 → 2,
 * 2 → 3,
 * 3 → 5,
 * 4 → 6.
 *
 * This function might be useful if we apply a method on `X[~skip,:]`
 * (a subset of rows in `X`), obtain a vector of indexes `ind` relative to
 * the indexes of rows in `X[~skip,:]` as a result, and wish to translate `ind`
 * back to the original row space of `X[:,:]`.
 *
 * For instance, `unskip_indexes([0, 2, 1], [True, False, True, False, False])`
 * yields `[1, 4, 3]`.
 *
 * @param ind [in/out] array of m indexes in 0..k-1 to translate
 * @param m size of ind
 * @param skip Boolean array of size n with k elements equal to False
 * @param n size of skip
 */
inline void Cunskip_indexes(  // inline: defined in a header
    Py_ssize_t* ind, Py_ssize_t m,
    const bool* skip, Py_ssize_t n
) {
    if (m <= 0) return;
    DEADWOOD_ASSERT(n > 0);

    // o[j] is the j-th non-skipped index (only the first k entries are used)
    std::vector<Py_ssize_t> o(n);
    Py_ssize_t k = 0;
    for (Py_ssize_t i=0; i<n; ++i) {
        if (!skip[i]) o[k++] = i;
    }

    for (Py_ssize_t i=0; i<m; ++i) {
        DEADWOOD_ASSERT(ind[i] >= 0 && ind[i] < k);
        ind[i] = o[ind[i]];
    }
}


/*! Encode indexes based on a skip array.
 *
 * If `skip=[False, True, False, False, True, False, False]`,
 * then the indexes in `ind` are mapped in such a way that:
 * 0 ← 0,
 * 1 ← 2,
 * 2 ← 3,
 * 3 ← 5,
 * 4 ← 6,
 * i.e., the indexes for which `skip` is False are mapped
 * to consecutive integers.  All other indexes are assigned the value -1.
 *
 * For instance, `skip_indexes([1, 4, 3], [True, False, True, False, False])`
 * yields `[0, 2, 1]`.
 *
 * @param ind [in/out] array of m indexes in 0..n-1 to translate
 * @param m size of ind
 * @param skip Boolean array of size n
 * @param n size of skip
 */
inline void Cskip_indexes(  // inline: defined in a header
    Py_ssize_t* ind, Py_ssize_t m,
    const bool* skip, Py_ssize_t n
) {
    if (m <= 0) return;
    DEADWOOD_ASSERT(n > 0);

    // o[i]: the position of i among the non-skipped indexes, or -1 if skipped
    std::vector<Py_ssize_t> o(n);
    Py_ssize_t k = 0;
    for (Py_ssize_t i=0; i<n; ++i) {
        if (skip[i]) o[i] = -1;
        else o[i] = (k++);
    }

    for (Py_ssize_t i=0; i<m; ++i) {
        DEADWOOD_ASSERT(ind[i] >= 0 && ind[i] < n);
        ind[i] = o[ind[i]];
    }
}


/*! Counts the elements equal to true in a Boolean array
 *
 * @param x array of length n
 * @param n length of x; n <= 0 yields 0
 *
 * @return the number of indexes i for which x[i] is true
 */
inline Py_ssize_t Csum_bool(const bool* x, Py_ssize_t n)  // inline: defined in a header
{
    if (n <= 0) return 0;
    return static_cast<Py_ssize_t>(std::count(x, x+n, true));
}


/*! Compute the degree of each vertex in an undirected graph
 *  over a vertex set {0,...,n-1}.
 *
 * @param ind c_contiguous matrix of size m*2,
 *     where {ind[i,0], ind[i,1]} is the i-th edge with ind[i,j] < n
 * @param m number of edges (rows in ind)
 * @param n number of vertices
 * @param deg [out] array of size n, where
 *     deg[i] will give the degree of the i-th vertex.
 *
 * @throws std::domain_error if an edge endpoint is negative or >= n,
 *     or if an edge is a self-loop (deg may then be partially filled)
 */
inline void Cgraph_vertex_degrees(  // inline: defined in a header
    const Py_ssize_t* ind,
    const Py_ssize_t m,
    const Py_ssize_t n,
    Py_ssize_t* deg /*out*/
) {
    std::fill_n(deg, n, static_cast<Py_ssize_t>(0));  // no-op for n <= 0

    for (Py_ssize_t i=0; i<m; ++i) {
        Py_ssize_t u = ind[2*i+0];
        Py_ssize_t v = ind[2*i+1];

        if (u < 0 || v < 0)
            throw std::domain_error("All elements must be >= 0");
        else if (u >= n || v >= n)
            throw std::domain_error("All elements must be < n");
        else if (u == v)
            throw std::domain_error("Self-loops are not allowed");

        deg[u]++;
        deg[v]++;
    }
}


/*! Compute the incidence list of each vertex in an undirected graph
 *  over a vertex set {0,...,n-1}.
 *
 *  @param ind c_contiguous matrix of size m*2,
 *      where {ind[i,0], ind[i,1]} is the i-th edge with ind[i,j] < n
 *  @param m number of edges (rows in ind)
 *  @param n number of vertices
 *  @param cumdeg [out] array of size n+1, where cumdeg[i] is the sum of
 *      the degrees of vertices 0,...,i-1 (hence cumdeg[0]==0
 *      and cumdeg[n]==2*m)
 *  @param inc [out] array of size 2*m; inc[cumdeg[i]]..inc[cumdeg[i+1]-1]
 *      gives the (indexes of the) edges incident on the i-th vertex,
 *      in increasing order
 *
 *  @throws std::domain_error  see Cgraph_vertex_degrees
 */
inline void Cgraph_vertex_incidences(  // inline: defined in a header
    const Py_ssize_t* ind,
    const Py_ssize_t m,
    const Py_ssize_t n,
    Py_ssize_t* cumdeg,
    Py_ssize_t* inc
) {
    // step 1: cumdeg[i+1] := deg(i)
    cumdeg[0] = 0;
    Cgraph_vertex_degrees(ind, m, n, cumdeg+1);

    // step 2: cumdeg[i+1] := deg(0)+...+deg(i-1), i.e., the position in inc
    // where the incidences of vertex i start
    Py_ssize_t cd = 0;
    for (Py_ssize_t i=1; i<n+1; ++i) {
        Py_ssize_t this_deg = cumdeg[i];
        cumdeg[i] = cd;
        cd += this_deg;
    }

    // step 3: scatter the edge indexes; cumdeg[i+1] serves as the write
    // position for vertex i, so it ends up as deg(0)+...+deg(i)
    for (Py_ssize_t e=0; e<m; ++e) {
        Py_ssize_t u = ind[2*e+0];
        Py_ssize_t v = ind[2*e+1];

        inc[cumdeg[u+1]++] = e;
        inc[cumdeg[v+1]++] = e;
    }

    DEADWOOD_ASSERT(cumdeg[0] == 0);
    DEADWOOD_ASSERT(cumdeg[n] == 2*m);
}




/* ************************************************************************** */


class CMSTProcessorBase
{
protected:
    const Py_ssize_t* mst_i;  // size m*2, elements in [0,n)
    const Py_ssize_t m;  // preferably == n-1; number of edges in mst_i
    const Py_ssize_t n;  // number of vertices

    Py_ssize_t* c;  // nullable or length n; cluster IDs of the vertices

    const Py_ssize_t* cumdeg;  // nullable or length n+1
    const Py_ssize_t* inc;     // nullable or length 2*m
    const bool* skip_edges;    // nullable or length m

    std::vector<Py_ssize_t> _cumdeg;  // data buffer for cumdeg (optional)
    std::vector<Py_ssize_t> _inc;     // data buffer for inc (optional)


public:

    CMSTProcessorBase(
        const Py_ssize_t* mst_i,
        const Py_ssize_t m,
        const Py_ssize_t n,
        Py_ssize_t* c=nullptr,
        const Py_ssize_t* cumdeg=nullptr,
        const Py_ssize_t* inc=nullptr,
        const bool* skip_edges=nullptr
    ) :
        mst_i(mst_i), m(m), n(n), c(c),
        cumdeg(cumdeg), inc(inc), skip_edges(skip_edges)
    {
        if (!cumdeg) {
            DEADWOOD_ASSERT(!inc);
            _cumdeg.resize(n+1);
            _inc.resize(2*m);
            Cgraph_vertex_incidences(mst_i, m, n, _cumdeg.data(), _inc.data());
            this->cumdeg = _cumdeg.data();
            this->inc = _inc.data();
        }
        else {
            DEADWOOD_ASSERT(inc);
        }
    }
};



/* ************************************************************************** */



/** See Cmst_get_cluster_sizes below.
 */
class CMSTClusterSizeGetter : public CMSTProcessorBase
{
private:

    Py_ssize_t max_k;
    Py_ssize_t* s;  // NULL or of size max_k >= k, where k is the number of clusters
    Py_ssize_t k;   // the number of connected components identified


    Py_ssize_t visit(Py_ssize_t v, Py_ssize_t e)
    {
        Py_ssize_t w;

        if (e < 0) {
            w = v;
        }
        else if (skip_edges && skip_edges[e])
            return 0;
        else {
            Py_ssize_t iv = (Py_ssize_t)(mst_i[2*e+1]==v);
            w = mst_i[2*e+(1-iv)];
        }

        DEADWOOD_ASSERT(c[w] < 0);
        c[w] = k;

        Py_ssize_t curs = 1;

        for (const Py_ssize_t* pe = inc+cumdeg[w]; pe != inc+cumdeg[w+1]; pe++) {
            if (*pe != e) curs += visit(w, *pe);
        }

        return curs;
    }


public:
    CMSTClusterSizeGetter(
        const Py_ssize_t* mst_i,
        Py_ssize_t m,
        Py_ssize_t n,
        Py_ssize_t* c,
        Py_ssize_t max_k,
        Py_ssize_t* s=nullptr,
        const Py_ssize_t* cumdeg=nullptr,
        const Py_ssize_t* inc=nullptr,
        const bool* skip_edges=nullptr
    ) : CMSTProcessorBase(mst_i, m, n, c, cumdeg, inc, skip_edges), max_k(max_k), s(s), k(-1)
    {
        DEADWOOD_ASSERT(this->c);
        DEADWOOD_ASSERT(this->cumdeg);
        DEADWOOD_ASSERT(this->inc);
    }


    Py_ssize_t process()
    {
        for (Py_ssize_t v=0; v<n; ++v) c[v] = -1;
        for (Py_ssize_t i=0; i<max_k; ++i) s[i] = 0;

        k = 0;
        for (Py_ssize_t v=0; v<n; ++v) {
            if (c[v] >= 0) continue;  // already visited -> skip

            if (s) {
                DEADWOOD_ASSERT(k<max_k);
                s[k] = visit(v, -1);
            }
            else
                visit(v, -1);

            k++;
        }

        return k;
    }

};


/*! Labels connected components in a spanning forest (where skip_edges
 *  designate the edges omitted from the tree) and fetches their sizes
 *
 *  @param mst_i c_contiguous matrix of size m*2,
 *     where {mst_i[k,0], mst_i[k,1]} specifies the k-th (undirected) edge
 *     in the spanning tree
 *  @param m number of rows in mst_i (edges)
 *  @param n length of c and the number of vertices in the spanning tree
 *  @param c [out] array of length n, where
 *      c[i] denotes the cluster ID (in {0, 1, ..., k-1} for some k)
 *      of the i-th object, i=0,...,n-1
 *  @param max_k the actual size of s (a safeguard)
 *  @param s [out] array of length max_k >= k, where k is the number of connected
 *      components in the forest; s[i] gives the size of the i-th cluster;
 *      pass NULL to get only the cluster labels;
 *      obviously, k<=n; e.g., if m==n-1, then k=sum(skip_edges)+1
 *  @param mst_cumdeg an array of length n+1 or NULL; see Cgraph_vertex_incidences
 *  @param mst_inc an array of length 2*m or NULL; see Cgraph_vertex_incidences
 *  @param mst_skip Boolean array of length m or NULL; indicates the edges to skip
 *
 *  @return k, the number of connected components
 */
inline Py_ssize_t Cmst_cluster_sizes(  // inline: defined in a header
    const Py_ssize_t* mst_i,
    Py_ssize_t m,
    Py_ssize_t n,
    Py_ssize_t* c,
    Py_ssize_t max_k=0,
    Py_ssize_t* s=nullptr,
    const Py_ssize_t* mst_cumdeg=nullptr,
    const Py_ssize_t* mst_inc=nullptr,
    const bool* mst_skip=nullptr
) {
    CMSTClusterSizeGetter get(mst_i, m, n, c, max_k, s, mst_cumdeg, mst_inc, mst_skip);
    return get.process();  // modifies c (and s) in place
}


/* ************************************************************************** */




/** See Cmst_impute_missing_labels below.
 */
class CMSTMissingLabelsImputer : public CMSTProcessorBase
{
private:

    void visit(Py_ssize_t v, Py_ssize_t e)
    {
        if (skip_edges && skip_edges[e]) return;

        Py_ssize_t iv = (Py_ssize_t)(mst_i[2*e+1]==v);
        Py_ssize_t w = mst_i[2*e+(1-iv)];

        DEADWOOD_ASSERT(c[v] >= 0);
        DEADWOOD_ASSERT(c[w] < 0);

        c[w] = c[v];

        for (const Py_ssize_t* pe = inc+cumdeg[w]; pe != inc+cumdeg[w+1]; pe++) {
            if (*pe != e) visit(w, *pe);
        }
    }


public:
    CMSTMissingLabelsImputer(
        const Py_ssize_t* mst_i,
        Py_ssize_t m,
        Py_ssize_t n,
        Py_ssize_t* c,
        const Py_ssize_t* cumdeg=nullptr,
        const Py_ssize_t* inc=nullptr,
        const bool* skip_edges=nullptr
    ) : CMSTProcessorBase(mst_i, m, n, c, cumdeg, inc, skip_edges)
    {
        DEADWOOD_ASSERT(this->c);
        DEADWOOD_ASSERT(this->cumdeg);
        DEADWOOD_ASSERT(this->inc);
    }


    void process()
    {
        for (Py_ssize_t v=0; v<n; ++v) {
            if (c[v] < 0) continue;

            for (const Py_ssize_t* pe = inc+cumdeg[v]; pe != inc+cumdeg[v+1]; pe++) {
                if (skip_edges && skip_edges[*pe]) continue;

                Py_ssize_t iv = (Py_ssize_t)(mst_i[2*(*pe)+1]==v);
                Py_ssize_t w = mst_i[2*(*pe)+(1-iv)];

                if (c[w] < 0) {  // descend into this branch to impute missing values
                    visit(v, *pe);
                }
            }
        }
    }

};


/*! Impute missing labels in all tree branches.
 *  All nodes in branches with class ID of -1 will be assigned their parent node's class.
 *
 *  @param mst_i c_contiguous matrix of size m*2,
 *     where {mst_i[k,0], mst_i[k,1]} specifies the k-th (undirected) edge
 *     in the spanning tree
 *  @param m number of rows in mst_i (edges)
 *  @param n length of c and the number of vertices in the spanning tree
 *  @param c [in/out] c_contiguous vector of length n, where
 *      c[i] denotes the cluster ID (in {-1, 0, 1, ..., k-1} for some k)
 *      of the i-th object, i=0,...,n-1.  Class -1 represents missing values
 *      to be imputed
 *  @param mst_cumdeg an array of length n+1 or NULL; see Cgraph_vertex_incidences
 *  @param mst_inc an array of length 2*m or NULL; see Cgraph_vertex_incidences
 *  @param mst_skip Boolean array of length m or NULL; indicates the edges to skip
 *      (labels are not propagated across them)
 */
inline void Cmst_label_imputer(  // inline: defined in a header
    const Py_ssize_t* mst_i,
    Py_ssize_t m,
    Py_ssize_t n,
    Py_ssize_t* c,
    const Py_ssize_t* mst_cumdeg=nullptr,
    const Py_ssize_t* mst_inc=nullptr,
    const bool* mst_skip=nullptr
) {
    CMSTMissingLabelsImputer imp(mst_i, m, n, c, mst_cumdeg, mst_inc, mst_skip);
    imp.process();  // modifies c in place
}


/* ************************************************************************** */



template<class FLOAT>
void Cget_contamination(
    const FLOAT* mst_d,
    Py_ssize_t m,
    FLOAT max_contamination,
    FLOAT ema_dt,
    FLOAT& contamination,
    Py_ssize_t& threshold_index
) {
    if (max_contamination <= 0.0) {
        contamination = -max_contamination;
        threshold_index =  int(m*(1.0-contamination));
    }
    else {
        Py_ssize_t shift = (int)(m*(1.0-max_contamination));
        Py_ssize_t elbow_index = Ckneedle_increasing(mst_d+shift, m-shift, true, ema_dt);
        if (elbow_index == 0) {
            threshold_index = m;
            contamination = 0.0;
        }
        else {
            threshold_index = shift+elbow_index+1;
            contamination = (m-threshold_index)/(FLOAT)(m+1);
        }
    }
}


/*! The Deadwood outlier detection algorithm
 *
 *  @param mst_d size m - edge weights
 *  @param mst_i c_contiguous matrix of size m*2,
 *     where {mst_i[k,0], mst_i[k,1]} specifies the k-th (undirected) edge
 *     in the spanning tree
 *  @param mst_cut array of size k-1; indexes of cut edges defining a spanning
 *     forest with k connected components
 *  @param m number of rows in mst_i (edges)
 *  @param n length of c and the number of vertices in the spanning tree
 *  @param k number of initial clusters
 *  @param c [out] array of length n, c[i]==1 marks an outlier
 *         and c[i]==0 denotes an inlier
 *  @param max_debris_size connected components of size <= max_debris_size will
 *         be treated as outliers
 *  @param max_contamination maximal contamination level;
 *         negative values will be used as actual contamination levels
 *  @param ema_dt controls the exponential moving average smoothing parameter
 *         alpha = 1-exp(-dt) (in elbow detection)
 *  @param contamination [out] array of length k;
 *         detected contamination levels in each cluster
 *  @param mst_cumdeg an array of length n+1 or NULL; see Cgraph_vertex_incidences
 *  @param mst_inc an array of length 2*m or NULL; see Cgraph_vertex_incidences
 */
template <class FLOAT>
void Cdeadwood(
    const FLOAT* mst_d,  // size m [in]
    const Py_ssize_t* mst_i,  // size m [in]
    const Py_ssize_t* mst_cut,  // size k-1 [in]
    Py_ssize_t m,
    Py_ssize_t n,
    Py_ssize_t k,
    FLOAT max_contamination,
    FLOAT ema_dt,
    Py_ssize_t max_debris_size,
    FLOAT* contamination,  // size k [out]
    Py_ssize_t* c,  // size n [out]
    const Py_ssize_t* mst_cumdeg=nullptr,
    const Py_ssize_t* mst_inc=nullptr
) {
    DEADWOOD_ASSERT(k >= 1 && k <= n);
    DEADWOOD_ASSERT(m == n-1);
    DEADWOOD_ASSERT(n > 1);

    std::vector<Py_ssize_t> sizes(n);  // upper bound for the number of clusters
    std::basic_string<bool> mst_skip(m, false);  // std::vector<bool> has no data()

    CMSTClusterSizeGetter size_getter(mst_i, m, n, c, n, sizes.data(), mst_cumdeg, mst_inc, mst_skip.data());

    DEADWOOD_ASSERT(max_contamination >= -1.0 && max_contamination <= 1.0);

    if (k == 1) {
        Py_ssize_t threshold_index = -1;
        Cget_contamination(
            mst_d, m, max_contamination, ema_dt,
            /*out*/contamination[0], /*out*/threshold_index
        );

        // DEADWOOD_PRINT("%d-%d\n", threshold_index, m);
        DEADWOOD_ASSERT(threshold_index >= 0);
        for (Py_ssize_t i=threshold_index; i<m; ++i)
            mst_skip[i] = true;
    }
    else {
        for (Py_ssize_t i=0; i<k-1; ++i) {
            DEADWOOD_ASSERT(mst_cut[i] >= 0 && mst_cut[i] < m);
            DEADWOOD_ASSERT(!mst_skip[mst_cut[i]]);
            mst_skip[mst_cut[i]] = true;
        }

        Py_ssize_t _k = size_getter.process();  // sets c and sizes based on the current mst_skip
        DEADWOOD_ASSERT(_k == k);

        std::vector<Py_ssize_t> edge_labels(m);
        for (Py_ssize_t i=0; i<m; ++i) {
            if (c[mst_i[2*i+0]]>=0 && c[mst_i[2*i+0]] == c[mst_i[2*i+1]])
                edge_labels[i] = c[mst_i[2*i+0]];
            else
                edge_labels[i] = -1;
        }

        std::vector<FLOAT> mst_d_grp(n);
        std::vector<Py_ssize_t> ind_grp(k+1);
        Csort_groups(mst_d, m, edge_labels.data(), k, mst_d_grp.data(), ind_grp.data());

        std::vector<FLOAT> weight_thresholds(k);
        for (Py_ssize_t i=0; i<k; ++i) {
            Py_ssize_t mi = sizes[i]-1;
            Py_ssize_t threshold_index;
            Cget_contamination(
                mst_d_grp.data()+ind_grp[i], mi, max_contamination, ema_dt,
                /*out*/contamination[i], /*out*/threshold_index
            );
            DEADWOOD_ASSERT(threshold_index>=0);
            if (threshold_index < mi)
                weight_thresholds[i] = mst_d_grp[ind_grp[i]+threshold_index];
            else
                weight_thresholds[i] = INFINITY;
        }

        for (Py_ssize_t i=0; i<m; ++i) {
            if (edge_labels[i] >= 0 && mst_d[i] >= weight_thresholds[edge_labels[i]])
                mst_skip[i] = true;
        }
    }


    size_getter.process();  // sets c and sizes based on the current mst_skip
    // DEADWOOD_PRINT("%d\n", _k);

    for (Py_ssize_t i=0; i<n; ++i) {
        DEADWOOD_ASSERT(c[i] >= 0 && c[i] < n);
        DEADWOOD_ASSERT(sizes[c[i]] > 0);
        if (sizes[c[i]] <= max_debris_size)
            c[i] = 1;
        else
            c[i] = 0;
    }
}


#if 0  /* /remove deprecated Cmst_trim_branches */

/** See Cmst_trim_branches below.  [DEPRECATED]
 *
 *  For every non-skipped edge, determines the number of vertices on each
 *  of its two sides.  Then, for each such edge with weight > min_d, marks
 *  the vertices on its smaller side with c[i] = -1, provided that this side
 *  has at most max_size vertices.
 *
 *  NOTE(review): the visit_* helpers are recursive; the recursion depth
 *  can reach the length of the longest path in the tree.
 */
template <class FLOAT> class CMSTBranchTrimmer : public CMSTProcessorBase
{
private:
    const FLOAT* mst_d;  // edge weights, length m
    const FLOAT min_d;  // only edges with weight > min_d can be cut
    const Py_ssize_t max_size;  // maximal size of a branch to cut


    // size[2*e+j] will give the number of vertices on the side of edge e
    // that contains mst_i[2*e+j]; -1 until determined
    std::vector<Py_ssize_t> size;

    Py_ssize_t clk;   // the number of connected components
    std::vector<Py_ssize_t> clsize;  // clsize[j]: the number of vertices in component j


    // Assigns c[v] to the vertex w at the other end of edge e and,
    // recursively, to all the vertices reachable from w via non-skipped
    // edges other than e; records the size of that part in the slot of
    // size[] corresponding to w; returns it (0 if e is skipped).
    Py_ssize_t visit_get_sizes(Py_ssize_t v, Py_ssize_t e)
    {
        if (skip_edges && skip_edges[e]) return 0;

        Py_ssize_t iv = (Py_ssize_t)(mst_i[2*e+1]==v);
        Py_ssize_t w = mst_i[2*e+(1-iv)];

        DEADWOOD_ASSERT(e >= 0 && e < m);
        DEADWOOD_ASSERT(v >= 0 && v < n);
        DEADWOOD_ASSERT(w >= 0 && w < n);
        DEADWOOD_ASSERT(c[w] < 0);

        Py_ssize_t this_size = 1;
        c[w] = c[v];

        for (const Py_ssize_t* pe = inc+cumdeg[w]; pe != inc+cumdeg[w+1]; pe++) {
            if (*pe != e) this_size += visit_get_sizes(w, *pe);
        }

        size[2*e + (1-iv)] = this_size;
        size[2*e + iv] = -1;  // v's side; computed later in process()

        return this_size;
    }


    // Sets c[w] = -1 for the vertex w at the other end of edge e and,
    // recursively, for the vertices reachable from w via non-skipped edges
    // other than e; stops at vertices already marked with a negative label.
    void visit_mark(Py_ssize_t v, Py_ssize_t e)
    {
        if (skip_edges && skip_edges[e]) return;

        Py_ssize_t iv = (Py_ssize_t)(mst_i[2*e+1]==v);
        Py_ssize_t w = mst_i[2*e+(1-iv)];

        if (c[w] < 0) return;  // already visited

        c[w] = -1;

        for (const Py_ssize_t* pe = inc+cumdeg[w]; pe != inc+cumdeg[w+1]; pe++) {
            if (*pe != e) visit_mark(w, *pe);
        }
    }


public:
    CMSTBranchTrimmer(
        const FLOAT* mst_d,
        FLOAT min_d,
        Py_ssize_t max_size,
        const Py_ssize_t* mst_i,
        Py_ssize_t m,
        Py_ssize_t n,
        Py_ssize_t* c,
        const Py_ssize_t* cumdeg=nullptr,
        const Py_ssize_t* inc=nullptr,
        const bool* skip_edges=nullptr
    ) :
        CMSTProcessorBase(mst_i, m, n, c, cumdeg, inc, skip_edges),
        mst_d(mst_d), min_d(min_d), max_size(max_size),
        size(2*m, -1)
    {
        DEADWOOD_ASSERT(this->c);
        DEADWOOD_ASSERT(this->cumdeg);
        DEADWOOD_ASSERT(this->inc);

        DEADWOOD_ASSERT(m == n-1);

        // the number of connected components:
        // (in a tree, each skipped edge adds one component)
        clk = 1;
        if (skip_edges) {
            for (Py_ssize_t e = 0; e < m; ++e)
                if (skip_edges[e]) clk++;
        }
        clsize.resize(clk);
    }


    // Performs the trimming; modifies c in place.
    void process()
    {
        for (Py_ssize_t v=0; v<n; ++v) c[v] = -1;
        for (Py_ssize_t i=0; i<clk; ++i) clsize[i] = 0;

        // phase 1: label the components 0..clk-1, computing their sizes
        // and, for each edge, the size of the side away from the root
        Py_ssize_t lastc = 0;
        for (Py_ssize_t v=0; v<n; ++v) {
            if (c[v] >= 0) continue;

            c[v] = lastc;
            Py_ssize_t this_size = 1;

            for (const Py_ssize_t* pe = inc+cumdeg[v]; pe != inc+cumdeg[v+1]; pe++) {
                if (skip_edges && skip_edges[*pe]) continue;
                this_size += visit_get_sizes(v, *pe);
            }
            clsize[lastc] = this_size;

            lastc++;
            if (lastc == clk) break;  // all components found
        }

        DEADWOOD_ASSERT(lastc == clk);
        DEADWOOD_ASSERT(clk > 1 || clsize[0] == n);

        // phase 2: the other side of each edge is the complement
        // within its component
        for (Py_ssize_t e=0; e<m; ++e) {
            if (skip_edges && skip_edges[e]) continue;
            DEADWOOD_ASSERT(size[2*e+0] > 0 || size[2*e+1] > 0);
            DEADWOOD_ASSERT(clsize[c[mst_i[2*e+0]]] == clsize[c[mst_i[2*e+1]]]);
            if (size[2*e+0] > 0)
                size[2*e+1] = clsize[c[mst_i[2*e+0]]] - size[2*e+0];
            else
                size[2*e+0] = clsize[c[mst_i[2*e+1]]] - size[2*e+1];
        }


        // phase 3: for each heavy edge, mark its smaller side (if small enough)
        for (Py_ssize_t e=0; e<m; ++e) {
            if (skip_edges && skip_edges[e]) continue;
            if (mst_d[e] <= min_d) continue;

            Py_ssize_t iv = (size[2*e+0]>=size[2*e+1])?0:1;  // the larger side
            Py_ssize_t v = mst_i[2*e+iv];
            if (c[v] < 0) continue;  // already trimmed
            if (size[2*e+(1-iv)] > max_size) continue;  // smaller side too big
            visit_mark(v, e);  // marks the smaller side
        }
    }

};


/*! Trim tree branches of size <= max_size connected by an edge > min_d  [DEPRECATED]
 *
 *
 *  @param mst_d m edge weights
 *  @param mst_i c_contiguous matrix of size m*2,
 *     where {mst_i[k,0], mst_i[k,1]} specifies the k-th (undirected) edge
 *     in the spanning tree
 *  @param m number of rows in mst_i (edges)
 *  @param c [out] vector of length n; c[i] == -1 marks a trimmed-out point,
 *     whereas c[i] >= 0 denotes a retained one
 *  @param n length of c and the number of vertices in the spanning tree, n == m+1
 *  @param min_d minimal edge weight to be considered trimmable
 *  @param max_size maximal allowable size of a branch to cut
 *  @param mst_cumdeg an array of length n+1 or NULL; see Cgraph_vertex_incidences
 *  @param mst_inc an array of length 2*m or NULL; see Cgraph_vertex_incidences
 *  @param mst_skip Boolean array of length m or NULL; indicates the edges to skip
 */
template <class FLOAT>
void Cmst_trim_branches(
    const FLOAT* mst_d,
    FLOAT min_d,
    Py_ssize_t max_size,
    const Py_ssize_t* mst_i,
    Py_ssize_t m,
    Py_ssize_t n,
    Py_ssize_t* c,
    const Py_ssize_t* mst_cumdeg=nullptr,
    const Py_ssize_t* mst_inc=nullptr,
    const bool* mst_skip=nullptr
) {
    CMSTBranchTrimmer tr(mst_d, min_d, max_size, mst_i, m, n, c, cumdeg, inc, skip_edges);
    tr.process();  // modifies c in place
}


#endif  /* /remove deprecated Cmst_trim_branches */


/* ************************************************************************** */


#endif
