#===============================================================================
# Temporal Forest - Core Functions
# Author: Sisi Shao
# This file contains the main implementation of the TemporalTree_time algorithm
# and its internal helper functions (e.g. get_split_names).
#===============================================================================

#' Extract partykit Node Rules (Internal)
#'
#' Thin accessor for partykit's non-exported `.list.rules.party()`. The
#' function is looked up with `utils::getFromNamespace()` rather than the
#' `:::` operator and is then applied to `tree` unchanged.
#'
#' @param tree A partykit tree object.
#' @return Whatever `.list.rules.party()` returns for `tree`; callers in this
#'   file treat it as character rule strings.
#' @noRd
.list.rules.party_safe <- function(tree) {
    list_rules <- utils::getFromNamespace(".list.rules.party", "partykit")
    list_rules(tree)
}
#' Get Split Variable Names from a partykit Tree (Internal)
#'
#' Extracts the names of the variables a fitted partykit tree actually splits
#' on, by parsing the textual node rules produced by partykit, e.g.
#' `x > 1.5 & x <= 3` (numeric split) or `f %in% c("a", "b")` (factor split).
#'
#' @param tree A partykit tree object (e.g. the `$tree` component of a
#'   `glmertree::lmertree()` fit). `NULL` yields an empty result.
#' @param data A data frame whose column names are the candidate split
#'   variables; only names present in `names(data)` can be returned.
#' @return A character vector of unique variable names used in splits, in
#'   the order of `names(data)`. Empty when the tree has no splits.
#' @keywords internal
get_split_names <- function(tree, data) {
    # Never executed: references partykit so package checks see the dependency.
    if (FALSE) { partykit::ctree(NULL) }
    if (is.null(tree) || inherits(tree, "partynone")) return(character(0))
    vnames <- names(data)
    if (length(vnames) == 0) return(character(0))

    # One rule string per terminal node; individual conditions are joined
    # by " & ".
    paths <- .list.rules.party_safe(tree)
    conditions <- trimws(unlist(strsplit(as.character(paths), " & ", fixed = TRUE),
                                use.names = FALSE))
    if (length(conditions) == 0) return(character(0))

    # A condition reads "<var> <op> <value>". Matching the variable name as a
    # fixed-string prefix (not a regex) keeps names containing metacharacters
    # such as "." from matching other variables (e.g. "x.1" vs "x11"), and
    # works for names that begin or end with non-word characters. `%in%`
    # covers factor splits.
    is_split <- vapply(vnames, function(var) {
        hits <- conditions[startsWith(conditions, var)]
        if (length(hits) == 0) return(FALSE)
        rest <- trimws(substring(hits, nchar(var) + 1L), which = "left")
        any(grepl("^(<|>|==|!=|%in%)", rest))
    }, logical(1), USE.NAMES = FALSE)

    unique(vnames[is_split])
}





#' Core Temporal Forest Algorithm (Internal)
#'
#' This is the internal workhorse function called by the `temporal_forest` wrapper.
#' It is not intended for direct use by end-users.
#'
#' The algorithm has three stages:
#' 1. Features are grouped into modules by average-linkage hierarchical
#'    clustering of `A_combined` followed by dynamic tree cutting.
#' 2. Screening: within each module, `glmertree::lmertree()` is fitted on
#'    `n_boot_screen` cluster-level resamples, and the module's most frequent
#'    split variables are kept.
#' 3. Selection: the kept candidates are pooled, the same resampling is
#'    repeated `n_boot_select` times, and variables are ranked by how often
#'    they are split on.
#'
#' @param data Data frame containing the response column `y` (hard-coded in
#'   the model formula), the `cluster` column, any `fixed_regress` columns
#'   and the candidate split features.
#' @param A_combined Feature dissimilarities: a `dist` object or anything
#'   `as.dist()` accepts (e.g. a square matrix).
#' @param fixed_regress Optional names of regressors in the node-level model;
#'   names not found in `data` are dropped.
#' @param fixed_split Currently unused.
#' @param var_select Feature names used to label the module assignments when
#'   they come back unnamed (assumed to follow the order of `A_combined`). If
#'   `NULL`, the labels of the `dist` object are used instead.
#' @param cluster Name of the column identifying clusters; used both in the
#'   `lmertree` formula and as the resampling unit.
#' @param maxdepth_factor_screen,maxdepth_factor_select Tree depth is
#'   `max(2, ceiling(factor * number of candidate features))`.
#' @param use_fuzzy,minsize_multiplier,alpha_predict Currently unused.
#' @param alpha_screen,alpha_select `alpha` passed to `lmertree()` in the
#'   screening and selection stages.
#' @param minClusterSize Minimum module size passed to `cutreeDynamic()`.
#' @param n_boot_screen,n_boot_select Number of resamples per stage.
#' @param keep_fraction_screen Each module keeps its top
#'   `max(1, floor(fraction * module size))` splitters.
#' @param number_selected_final Maximum number of features returned.
#' @return A list with `final_selection` (features ranked by selection
#'   frequency, at most `number_selected_final`), `second_stage_splitters`
#'   (every distinct feature split on in any selection resample) and
#'   `status` (`"Completed"`).
#'
#' @keywords internal
#' @importFrom dynamicTreeCut cutreeDynamic
#' @importFrom WGCNA labels2colors
#' @importFrom glmertree lmertree
#' @importFrom stats as.formula as.dist
TemporalTree_time <- function(data, A_combined, fixed_regress=NULL, fixed_split=NULL, var_select=NULL, cluster=NULL,
                              maxdepth_factor_screen=0.04, maxdepth_factor_select=0.5,
                              use_fuzzy=TRUE, minsize_multiplier=1, alpha_screen=0.2, alpha_select=0.05,
                              alpha_predict=0.05, minClusterSize=4,
                              n_boot_screen = 50,
                              keep_fraction_screen = 0.2,
                              n_boot_select = 100,
                              number_selected_final = 10
) {
    # ---- Validate inputs before the (expensive) clustering step ----------
    valid_cluster <- cluster[cluster %in% names(data)]
    if (length(valid_cluster) == 0) stop("Cluster variable not found in data.")
    if (length(valid_cluster) > 1) stop("Only a single cluster variable is supported.")
    cluster_var <- valid_cluster
    # Without `y` every lmertree() fit would error inside tryCatch and the
    # function would silently return an empty selection.
    if (!"y" %in% names(data)) stop("Response variable 'y' not found in data.")
    valid_regress <- fixed_regress[fixed_regress %in% names(data)]

    # ---- Module detection -------------------------------------------------
    A_combined_dist <- if (inherits(A_combined, "dist")) A_combined else as.dist(A_combined)
    geneTree <- flashClust::hclust(A_combined_dist, method = "average")
    dynamicMods <- suppressMessages(cutreeDynamic(dendro = geneTree, distM = as.matrix(A_combined),
                                                  deepSplit = 2, pamRespectsDendro = FALSE,
                                                  minClusterSize = minClusterSize))
    colors <- labels2colors(dynamicMods)
    if (is.null(names(colors))) {
        names(colors) <- if (!is.null(var_select)) var_select else attr(A_combined_dist, "Labels")
    }
    if (is.null(names(colors))) {
        warning("No feature names available (supply `var_select` or a labelled `A_combined`); ",
                "no features can be selected.", call. = FALSE)
    }
    module_names <- unique(colors)
    module_dic <- lapply(module_names, function(m) names(colors)[colors == m])
    names(module_dic) <- module_names

    # ---- Formula pieces shared by both stages -----------------------------
    # lmertree formula: y ~ <regressors> | <cluster> | <split candidates>
    formula_lhs <- "y ~"
    formula_regress_str <- if (length(valid_regress) > 0) paste(valid_regress, collapse = "+") else "1"
    formula_cluster_str <- paste("|", cluster_var, "|")
    # Backticks protect non-syntactic feature names inside the formula.
    backtick_sum <- function(vars) paste(paste0("`", vars, "`"), collapse = "+")
    cluster_ids_global <- unique(data[[cluster_var]])

    # ---- Stage 1: robust screening within modules -------------------------
    screened_features_list <- list()
    for (name in module_names) {
        split_var_module <- module_dic[[name]]
        valid_split_module <- split_var_module[split_var_module %in% names(data)]
        if (length(valid_split_module) == 0) next

        maxdepth <- max(2, ceiling(maxdepth_factor_screen * length(valid_split_module)))
        formula_screen_mod <- as.formula(paste(formula_lhs, formula_regress_str,
                                               formula_cluster_str, backtick_sum(valid_split_module)))
        module_splitters <- .tt_bootstrap_splitters(formula_screen_mod, data, cluster_var,
                                                    cluster_ids_global, n_boot = n_boot_screen,
                                                    alpha = alpha_screen, maxdepth = maxdepth,
                                                    candidates = split_var_module)
        ranked_module_splitters <- .tt_rank_splitters(module_splitters)
        if (length(ranked_module_splitters) > 0) {
            num_to_keep_screen <- max(1, floor(length(split_var_module) * keep_fraction_screen))
            num_to_keep_screen <- min(num_to_keep_screen, length(ranked_module_splitters))
            screened_features_list[[name]] <- ranked_module_splitters[seq_len(num_to_keep_screen)]
        }
    }
    screened_candidates <- unique(unlist(screened_features_list))

    # ---- Stage 2: robust selection from screened candidates ---------------
    final_selection <- character(0)
    second_stage_splitters <- character(0)
    valid_split_select <- screened_candidates[screened_candidates %in% names(data)]

    if (length(valid_split_select) > 0) {
        maxdepth_select_calc <- max(2, ceiling(maxdepth_factor_select * length(valid_split_select)))
        formula_select <- as.formula(paste(formula_lhs, formula_regress_str,
                                           formula_cluster_str, backtick_sum(valid_split_select)))
        valid_splitters_list <- .tt_bootstrap_splitters(formula_select, data, cluster_var,
                                                        cluster_ids_global, n_boot = n_boot_select,
                                                        alpha = alpha_select,
                                                        maxdepth = maxdepth_select_calc,
                                                        candidates = screened_candidates)
        second_stage_splitters <- unique(valid_splitters_list)
        ranked_splitters <- .tt_rank_splitters(valid_splitters_list)
        # seq_len() (not 1:n) so number_selected_final = 0 selects nothing.
        num_to_select <- min(number_selected_final, length(ranked_splitters))
        final_selection <- ranked_splitters[seq_len(num_to_select)]
    }

    return(list(final_selection = final_selection,
                second_stage_splitters = second_stage_splitters,
                status = "Completed"))
}

# Draw one cluster-level resample of `data`.
# Cluster ids are drawn with replacement, but rows are then selected with
# `%in%`, so a cluster drawn several times is included only once; the result
# is a random subset of the clusters rather than a bootstrap with duplicated
# clusters. NOTE(review): confirm this subsampling behaviour is intended.
# Indexing by sample.int() consumes the RNG exactly as sample(cluster_ids, ...)
# did, but avoids sample()'s `1:k` behaviour when there is a single numeric id.
.tt_bootstrap_sample <- function(data, cluster_var, cluster_ids) {
    n_ids <- length(cluster_ids)
    drawn_ids <- cluster_ids[sample.int(n_ids, size = n_ids, replace = TRUE)]
    data[which(data[[cluster_var]] %in% drawn_ids), ]
}

# Fit glmertree::lmertree() on `n_boot` cluster resamples and return the
# concatenated split variables (restricted to `candidates`) of all fits.
# A fit that errors contributes nothing instead of aborting the procedure.
.tt_bootstrap_splitters <- function(formula, data, cluster_var, cluster_ids,
                                    n_boot, alpha, maxdepth, candidates) {
    per_boot <- vector("list", n_boot)
    for (b in seq_len(n_boot)) {
        boot_data <- .tt_bootstrap_sample(data, cluster_var, cluster_ids)
        per_boot[[b]] <- tryCatch({
            fit <- glmertree::lmertree(formula, data = boot_data, alpha = alpha,
                                       maxdepth = maxdepth, minsize = 10)
            intersect(get_split_names(fit$tree, boot_data), candidates)
        }, error = function(e) { character(0) })
    }
    as.character(unlist(per_boot, use.names = FALSE))
}

# Rank variable names by occurrence count in `splitters`, most frequent first.
# Tied counts stay in table()'s level order.
.tt_rank_splitters <- function(splitters) {
    if (length(splitters) == 0) return(character(0))
    names(sort(table(splitters), decreasing = TRUE))
}