#' @title
#' simdata
#'
#' @description
#' Simulation of a multi-state model with a user-specified number of states.
#'
#' @details This function simulates data for a multi-state model, together with
#' the status corresponding to each state and a set of normally distributed
#' continuous covariates.
#'
#' The k-th elements of \code{lambdas}, \code{gammas} and \code{beta_list}
#' generate the k-th gap time. Gap times are drawn by inverse-transform
#' sampling from the cumulative hazard \eqn{H(t) = \lambda e^{x\beta} t}
#' (exponential), \eqn{\lambda e^{x\beta} t^\gamma} (Weibull) or
#' \eqn{\lambda e^{x\beta} (e^{\gamma t} - 1) / \gamma} (Gompertz).
#'
#' The gap times are cumulated per subject. The k-th cumulative time is used
#' as the entry time into state k + 1. Consequently \code{trans_list} and
#' \code{state_names} must describe exactly \code{length(lambdas) + 1} states.
#'
#' The function calls \code{set.seed(seed)}, which changes the global random
#' number generator state.
#'
#' @param seed Random seed for reproducibility (passed to \code{set.seed})
#' @param n Number of subjects
#' @param dist distribution to follow for baseline hazard ("exponential", "weibull", "gompertz")
#' @param cdist distribution to follow for censoring distribution ("uniform", "exponential", "weibull")
#' @param cparams named list of parameters for the censoring distribution:
#'   \code{min}/\code{max} for "uniform" (defaults 0 and 10), \code{rate} for
#'   "exponential" (default 0.5), \code{shape}/\code{scale} for "weibull"
#'   (defaults 0.5 and 0.2). Entries that are not supplied use these defaults.
#' @param lambdas scale parameters of the baseline distribution, one per simulated gap time
#' @param gammas shape parameters of the baseline distribution, same length as
#'   \code{lambdas}; required unless \code{dist = "exponential"}
#' @param beta_list a list of coefficient vectors, one per element of \code{lambdas};
#'   each vector must contain one coefficient per covariate
#' @param cov_means mean value of each of the covariates
#' @param cov_sds standard deviation of each of the covariates (same length as \code{cov_means})
#' @param trans_list list describing the possible transitions, passed to \code{mstate::transMat}
#' @param state_names names of the states of the multi-state model
#' @return the long-format multi-state data produced by \code{mstate::msprep},
#'   with transition-specific covariates appended by \code{mstate::expand.covs}
#' @import mstate
#' @importFrom stats rnorm runif rexp rweibull
#' @importFrom survival Surv
#'
#' @examples
#' ##
#' msdata_4state <- simdata(seed=123,n=1000,dist="weibull",cdist="exponential",
#'                  cparams=list(rate = 0.1),lambdas=c(0.1, 0.2, 0.3, 0.4),
#'                  gammas=c(1.5, 2, 2.5, 2.6),beta_list=list(c(-0.05, 0.01, 0.5, 0.6),
#'                  c(-0.03, 0.02, 0.07, 0.08),c(-0.04, 0.03, 0.04, -0.03),
#'                  c(-0.05, 0.05, 0.6, 0.8)),cov_means=c(0, 10, 2, 3),cov_sds=c(1,20,5,1.05),
#'                  trans_list=list(c(2, 3, 4, 5),c(3, 4, 5),c(4, 5), c(5), c()),
#'                  state_names=c("Tx", "Rec", "Death", "Reldeath", "srv"))
#' ##
#' @export
#' @author Atanu Bhattacharjee,Gajendra Kumar Vishwakarma,Abhipsa Tripathy

simdata <- function(seed = 123, n = 1000, dist = "weibull", cdist = "exponential",
                    cparams = list(rate = 0.1), lambdas, gammas, beta_list,
                    cov_means, cov_sds, trans_list, state_names) {

  # --- Check inputs (no random numbers are drawn until these pass) ---
  valid_dists <- c("exponential", "weibull", "gompertz")
  if (length(dist) != 1L || !(dist %in% valid_dists)) {
    stop("Invalid baseline distribution", call. = FALSE)
  }
  valid_cdists <- c("uniform", "exponential", "weibull")
  if (length(cdist) != 1L || !(cdist %in% valid_cdists)) {
    stop("Invalid censoring distribution", call. = FALSE)
  }

  num_transitions <- length(lambdas)
  if (num_transitions < 1L) {
    stop("lambdas must contain at least one element", call. = FALSE)
  }
  if (num_transitions != length(beta_list)) {
    stop("Length of lambdas must equal number of elements in beta_list", call. = FALSE)
  }
  if (dist != "exponential") {
    # Weibull and Gompertz both need a shape parameter per gap time.
    if (missing(gammas)) {
      stop("gammas must be supplied for non-exponential distributions", call. = FALSE)
    }
    if (length(gammas) != num_transitions) {
      stop("Length of gammas must equal length of lambdas for non-exponential distributions",
           call. = FALSE)
    }
  }

  num_cov <- length(cov_means)
  if (length(cov_sds) != num_cov) {
    stop("Length of cov_sds must equal length of cov_means", call. = FALSE)
  }
  beta_lengths <- vapply(beta_list, length, integer(1))
  if (any(beta_lengths != num_cov)) {
    stop("Each element of beta_list must have one coefficient per covariate", call. = FALSE)
  }
  # msprep needs one time/status column per state; state 1 has none (NA).
  if (length(trans_list) != num_transitions + 1L) {
    stop("trans_list must describe exactly length(lambdas) + 1 states", call. = FALSE)
  }

  set.seed(seed)

  # --- Generate covariates (independent normals) ---
  cov_names <- paste0("x", seq_len(num_cov))
  cov_df <- data.frame(id = seq_len(n))
  for (i in seq_len(num_cov)) {
    cov_df[[cov_names[i]]] <- rnorm(n, mean = cov_means[i], sd = cov_sds[i])
  }

  # --- Generate censoring times ---
  # Exact-name lookup of a censoring parameter with a fallback default
  # (`[[` avoids the partial matching that `$` performs on lists).
  cparam <- function(name, default) {
    value <- cparams[[name]]
    if (is.null(value)) default else value
  }
  ctime <- switch(cdist,
    uniform = runif(n, min = cparam("min", 0), max = cparam("max", 10)),
    exponential = rexp(n, rate = cparam("rate", 0.5)),
    weibull = rweibull(n, shape = cparam("shape", 0.5), scale = cparam("scale", 0.2))
  )

  # --- Simulate gap times ---
  # drop = FALSE keeps a one-column matrix when there is a single covariate.
  data_matrix <- as.matrix(cov_df[, cov_names, drop = FALSE])
  survival_times <- matrix(0, n, num_transitions)
  status_matrix <- matrix(0, n, num_transitions)

  for (i in seq_len(num_transitions)) {
    lu <- log(runif(n))
    xbeta <- exp(data_matrix %*% beta_list[[i]])

    # Solve H(T) = -log(U) for T, with H the cumulative hazard of `dist`.
    survival_times[, i] <- switch(dist,
      exponential = -lu / (lambdas[i] * xbeta),
      weibull = (-lu / (lambdas[i] * xbeta))^(1 / gammas[i]),
      gompertz = (1 / gammas[i]) * log(1 - ((gammas[i] * lu) / (lambdas[i] * xbeta)))
    )

    # NOTE(review): status compares the gap time (not the cumulative time
    # reported below) with the censoring time, and reported times are not
    # truncated at ctime. Kept unchanged to preserve simulated output --
    # confirm this is the intended censoring mechanism.
    status_matrix[, i] <- as.numeric(survival_times[, i] < ctime)
  }

  # --- Prepare dataset ---
  # apply() returns a (transitions x n) matrix, or a plain vector when there
  # is a single transition; refilling by row gives an n x transitions matrix
  # in both cases.
  cumulative_times <- matrix(apply(survival_times, 1, cumsum),
                             nrow = n, ncol = num_transitions, byrow = TRUE)

  data_list <- list(id = cov_df$id)
  for (i in seq_len(num_transitions)) {
    data_list[[paste0("time", i)]] <- cumulative_times[, i]
    data_list[[paste0("stat", i)]] <- status_matrix[, i]
  }

  # drop = FALSE keeps the covariate column named (e.g. "x1") when num_cov == 1.
  data_final <- cbind(data.frame(data_list), cov_df[, cov_names, drop = FALSE])

  # --- Transition matrix ---
  tmat <- mstate::transMat(x = trans_list, names = state_names)

  # --- Prepare mstate format ---
  # State 1 is the starting state, so it has no time/status column.
  time_vars   <- c(NA, paste0("time", seq_len(num_transitions)))
  status_vars <- c(NA, paste0("stat", seq_len(num_transitions)))

  # Warnings raised by msprep are suppressed.
  msdata <- suppressWarnings(mstate::msprep(
    time   = time_vars,
    status = status_vars,
    data   = data_final,
    trans  = tmat,
    keep   = cov_names
  ))

  mstate::expand.covs(msdata, cov_names, append = TRUE, longnames = FALSE)
}
