#! /bin/Rscript
library("optparse")

# Command line interface. Every parsed option except the automatically added
# --help flag is forwarded by name to reduce_sequence(). Options with a NULL
# default are meant to fall back to reduce_sequence()'s computed defaults.
# Integer options declare type = "integer" explicitly so the parsed value matches
# the advertised metavar instead of being inferred as double from the default.
option_list <- list(
  make_option(opt_str = c("-i", "--input"), default = NULL, help = "Input bed-file. Last column must be sequences.", metavar = "character"),
  make_option(opt_str = c("-k", "--kmer"), default = 10L, type = "integer", help = "K-mer length. Default = %default", metavar = "integer"),
  make_option(opt_str = c("-m", "--motif"), default = 10L, type = "integer", help = "Estimated motif length. Default = %default", metavar = "integer"),
  make_option(opt_str = c("-o", "--output"), default = "reduced.bed", help = "Output file. Default = %default", metavar = "character"),
  make_option(opt_str = c("-t", "--threads"), default = 1L, type = "integer", help = "Number of threads to use. Use 0 for all available cores. Default = %default", metavar = "integer"),
  make_option(opt_str = c("-c", "--clean"), default = TRUE, type = "logical", help = "Delete all temporary files. Default = %default", metavar = "logical"),
  make_option(opt_str = c("-s", "--min_seq_length"), default = NULL, help = "Remove sequences below this length. Defaults to the maximum value of motif and k-mer and can not be lower.", metavar = "integer", type = "integer"),
  make_option(opt_str = c("-n", "--minoverlap_kmer"), default = NULL, help = "Minimum required overlap between k-mer. Used to create reduced sequence ranges out of merged k-mer. Can not be greater than k-mer length. Default = kmer - 1", metavar = "integer", type = "integer"),
  make_option(opt_str = c("-v", "--minoverlap_motif"), default = NULL, help = "Minimum required overlap between motif and k-mer to consider k-mer significant. Used for k-mer cutoff calculation. Can not be greater than motif and k-mer length. Default = ceiling(motif / 2)", metavar = "integer", type = "integer"),
  make_option(opt_str = c("-f", "--motif_occurrence"), default = 1, type = "double", help = "Define how many motifs are expected per sequence. This value is used during k-mer cutoff calculation. Default = %default meaning that there should be approximately one motif per sequence.", metavar = "double")
)

opt_parser <- OptionParser(option_list = option_list, 
                           description = "Reduces each sequence to its most frequent region.",
                           epilogue = "Author: Hendrik Schultheis <Hendrik.Schultheis@mpi-bn.mpg.de>")

opt <- parse_args(opt_parser)

#' Reduces each sequence to its most frequent region.
#' 
#' Writes the sequences as fasta, counts their k-mer with jellyfish, keeps the
#' most frequent k-mer (cutoff computed by significant_kmer / reduce_kmer),
#' locates them in every sequence and shrinks each bed entry to its best scoring
#' group of overlapping k-mer (find_kmer_regions). Start, end and sequence
#' columns are updated and the result is written to 'output'.
#' 
#' @param input Input bed-file. Last column must be sequences.
#' @param kmer k-mer length. Default = 10
#' @param motif Estimated motif length. Default = 10
#' @param output Output file. Default = reduced.bed
#' @param threads Number of threads to use. Default = NULL (sequential). Use 0 for all cores.
#' @param clean Delete all temporary files (also when a later step fails).
#' @param minoverlap_kmer Minimum required overlap between k-mer. Used to create reduced sequence ranges out of merged k-mer. Can not be greater than k-mer length. Default = kmer - 1
#' @param minoverlap_motif Minimum required overlap between motif and k-mer to consider k-mer significant. Used for k-mer cutoff calculation. Can not be greater than motif and k-mer length. Default = ceiling(motif / 2)
#' @param min_seq_length Remove sequences shorter than this length. Defaults to the maximum value of motif and k-mer and can not be lower.
#' @param motif_occurrence Define how many motifs are expected per sequence. This value is used during k-mer cutoff calculation. Default = 1 meaning that there should be approximately one motif per sequence.
#' 
#' @details If there is a header supplied other than the default data.table naming scheme ('V1', 'V2', etc.) it will be kept.
#' Temporary files ('<input>.fasta', 'mer_count_binary.jf', 'mer_count_table.jf') are written to the working directory.
#' Requires the 'jellyfish' executable on the PATH.
#' 
#' @return Invisibly, the value returned by data.table::fwrite; called for its side effect of writing 'output'.
reduce_sequence <- function(input, kmer = 10, motif = 10, output = "reduced.bed", threads = NULL, clean = TRUE, minoverlap_kmer = kmer - 1, minoverlap_motif = ceiling(motif / 2), min_seq_length = max(c(motif, kmer)), motif_occurrence = 1) {
  if (!nzchar(Sys.which("jellyfish"))) {
    stop("Required program jellyfish not found! Please check whether it is installed.")
  }
  
  if (missing(input)) {
    stop("No input specified! Please forward a valid bed-file.")
  }
  
  # validate parameters before doing any expensive work
  if (min_seq_length < max(c(kmer, motif))) {
    stop("Minimum sequence length must be greater or equal to ", max(c(motif, kmer)), " (maximum value of k-mer and motif).")
  }
  
  # get number of available cores; NULL keeps the sequential default
  if (!is.null(threads) && threads == 0) {
    threads <- parallel::detectCores()
  }
  
  message("Loading bed...")
  # columns: chr, start, end, name, ..., sequence
  bed_table <- data.table::fread(input = input)
  
  if (ncol(bed_table) < 5) {
    stop("Input must have at least 5 columns (chr, start, end, name, ..., sequence)!")
  }
  
  # fread names header-less columns V1, V2, ..., V10, ...; anything else is a user header worth keeping
  keep_col_names <- !any(grepl(pattern = "^V\\d+$", names(bed_table), perl = TRUE))
  col_names <- names(bed_table)
  
  names(bed_table)[1:4] <- c("chr", "start", "end", "name")
  names(bed_table)[ncol(bed_table)] <- "sequence"
  # index
  data.table::setkey(bed_table, name, physical = FALSE)
  
  # check for duplicated names
  if (anyDuplicated(bed_table[["name"]])) {
    warning("Found duplicated names. Making names unique.")
    bed_table[, name := make.unique(name)]
  }
  
  # remove sequences below minimum length (sequences of exactly min_seq_length are kept)
  total_rows <- nrow(bed_table)
  bed_table <- bed_table[nchar(sequence) >= min_seq_length]
  if (total_rows > nrow(bed_table)) {
    message("Removed ", total_rows - nrow(bed_table), " sequence(s) below minimum length of ", min_seq_length)
  }
  
  # temporary files
  fasta_file <- paste0(basename(input), ".fasta")
  count_output_binary <- "mer_count_binary.jf"
  mer_count_table <- "mer_count_table.jf"
  if (clean) {
    # registered now so the files are also removed if a later step fails;
    # unlink() is silent for files that were never created
    on.exit(unlink(c(fasta_file, count_output_binary, mer_count_table)), add = TRUE)
  }
  
  # TODO forward fasta file as parameter so no bed -> fasta conversion is needed.
  message("Writing fasta...")
  seqinr::write.fasta(sequences = as.list(bed_table[["sequence"]]), names = bed_table[["name"]], as.string = TRUE, file.out = fasta_file)
  
  message("Counting k-mer...")
  # jellyfish needs a plain integer; paste() could render large powers in scientific notation
  hashsize <- sprintf("%.0f", 4 ^ kmer)
  jellyfish_call <- paste("jellyfish count", "-m", kmer, "-s", hashsize, "-o", shQuote(count_output_binary), shQuote(fasta_file))
  if (system(command = jellyfish_call, wait = TRUE) != 0) {
    stop("jellyfish count failed: ", jellyfish_call)
  }
  
  jellyfish_dump_call <- paste("jellyfish dump --column --tab --output", shQuote(mer_count_table), shQuote(count_output_binary))
  if (system(command = jellyfish_dump_call, wait = TRUE) != 0) {
    stop("jellyfish dump failed: ", jellyfish_dump_call)
  }
  
  message("Reduce k-mer.")
  # load mer table
  # columns: kmer, count
  kmer_counts <- data.table::fread(input = mer_count_table, header = FALSE)
  # order k-mer descending
  data.table::setorder(kmer_counts, -V2)
  
  # compute number of hits to keep
  keep_hits <- significant_kmer(bed_table, kmer = kmer, motif = motif, minoverlap = minoverlap_motif, motif_occurrence = motif_occurrence)
  
  # reduce k-mer
  reduced_kmer <- reduce_kmer(kmer = kmer_counts, significant = keep_hits)
  
  message("Find k-mer in sequences.")
  # columns: start, end, width, name (start/end are 1-based positions within the sequence)
  ranges_table <- find_kmer_regions(bed = bed_table, kmer_counts = reduced_kmer, minoverlap = minoverlap_kmer, threads = threads)
  names(ranges_table)[1:2] <- c("relative_start", "relative_end")
  
  # merge ranges_table with bed_table + keep column order
  merged <- merge(x = bed_table, y = ranges_table, by = "name", sort = FALSE)[, union(names(bed_table), names(ranges_table)), with = FALSE]
  
  # delete sequences without hit
  merged <- na.omit(merged, cols = c("relative_start", "relative_end"))
  message("Removed ", nrow(bed_table) - nrow(merged), " sequence(s) without hit.")
  
  message("Reduce sequences.")
  # create subsequences (relative positions are 1-based and inclusive)
  merged[, sequence := stringr::str_sub(sequence, relative_start, relative_end)]
  
  # bed is 0-based with an exclusive end: positions [s, e] of a sequence that
  # starts at 'start' span [start + s - 1, start + e). Both right-hand sides
  # are evaluated before assignment, so they use the original 'start'.
  merged[, `:=`(start = start + relative_start - 1, end = start + relative_end)]
  
  # clean table
  merged[, `:=`(relative_start = NULL, relative_end = NULL, width = NULL)]
  
  # keep provided column names
  if (keep_col_names) {
    names(merged) <- col_names
  }
  
  invisible(data.table::fwrite(merged, file = output, sep = "\t", col.names = keep_col_names))
}

#' Predict how many interesting k-mer are possible for the given data.
#' 
#' Each sequence length is capped at the longest stretch that can still hold
#' interesting k-mer. A k-mer overlapping a motif by at least 'minoverlap' bases
#' reaches at most (kmer - minoverlap) bases past either motif end. The
#' capped lengths are converted to k-mer start positions and summed.
#' 
#' @param bed Bed table with sequences in last column
#' @param kmer Length of k-mer
#' @param motif Length of motif
#' @param minoverlap Minimum number of bases overlapping between k-mer and motif. Must be <= motif & <= kmer. Defaults to ceiling(motif / 2).
#' @param motif_occurrence Define how many motifs are expected per sequence. Default = 1
#' 
#' @return Number of interesting k-mer (sequences shorter than 'kmer' contribute 0).
significant_kmer <- function(bed, kmer, motif, minoverlap = ceiling(motif / 2), motif_occurrence = 1) {
  if (minoverlap > kmer || minoverlap > motif) {
    stop("Kmer & motif must be greater or equal to minoverlap!")
  }
  if (motif_occurrence <= 0) {
    stop("Motif_occurrence must be a numeric value above 0!")
  }
  
  # longest sequence stretch that still contains all interesting overlaps
  max_relevant_length <- motif + 2 * (kmer - minoverlap)
  
  # reduce to max interesting length
  seq_lengths <- pmin(nchar(bed[[ncol(bed)]]), max_relevant_length)
  
  # k-mer start positions per sequence; a sequence shorter than k holds none
  kmer_per_sequence <- pmax(seq_lengths - kmer + 1, 0)
  
  sum(kmer_per_sequence) * motif_occurrence
}

#' Orders k-mer table after count descending and keeps all k-mer whose cumulative count is at most the given significance threshold.
#' 
#' The input is sorted and extended (column 'cumsum') by reference, so the
#' caller's data.table is modified as well.
#' 
#' @param kmer K-mer data.table with columns kmer, count. Column names are irrelevant; the second column is used as count.
#' @param significant Value from significant_kmer function.
#' 
#' @return reduced data.table (including the added 'cumsum' column).
reduce_kmer <- function(kmer, significant) {
  count_col <- names(kmer)[2]
  data.table::setorderv(kmer, cols = count_col, order = -1)
  
  # running total of counts, most frequent k-mer first; as.numeric avoids the
  # NA that integer cumsum returns on overflow
  kmer[, cumsum := cumsum(as.numeric(.SD[[1L]])), .SDcols = count_col]
  
  # a single-symbol i is evaluated in calling scope, not as a column name
  keep_rows <- kmer[["cumsum"]] <= significant
  kmer[keep_rows]
}

#' create list of significant ranges (one for each bed entry)
#' 
#' For every sequence all occurrences of the given k-mer are located. Hits that
#' overlap by at least 'minoverlap' bases are grouped into connected
#' components, each component is scored by the summed counts of its hits, and
#' the best component (first one on ties) is merged into a single range.
#' 
#' @param bed Data.table of bed with sequence in last column and name in column 4.
#' @param kmer_counts Data.table of counted k-mer. Column1 = kmer, column2 = count.
#' @param minoverlap Minimum overlapping nucleotides between k-mers to be merged. Positive integer. Must be smaller than k-mer length.
#' @param threads Number of threads (forwarded to pbapply::pblapply as 'cl').
#' 
#' @return Data.table with relative, 1-based positions, width and the bed name (start, end, width, name). Sequences without any hit get NA ranges.
#' 
#' TODO Include number of motifs per sequence (aka motif_occurrence). Attempt to keep best 2 regions for occurrence = 2? Probably high impact on performance.
find_kmer_regions <- function(bed, kmer_counts, minoverlap = 1, threads = NULL) {
  if (nrow(kmer_counts) < 1) {
    stop("No k-mer left to search for! Check the k-mer cutoff.")
  }
  if (nchar(kmer_counts[[1]][1]) <= minoverlap) {
    stop("Minoverlap must be smaller than k-mer length!")
  }
  
  names(kmer_counts)[1:2] <- c("kmer", "count")
  data.table::setorder(kmer_counts, -count)
  
  # hoist table lookups out of the per-sequence loop
  patterns <- kmer_counts[["kmer"]]
  pattern_counts <- kmer_counts[["count"]]
  sequences <- bed[[ncol(bed)]]
  seq_names <- bed[[4]]
  
  seq_ranges <- pbapply::pblapply(seq_len(nrow(bed)), cl = threads, function(x) {
    sequence <- sequences[x]
    seq_name <- seq_names[x]
    
    #### locate ranges
    # one start/end matrix per pattern (an NA row when the pattern is absent), stacked
    hits <- data.table::data.table(do.call(rbind, stringi::stri_locate_all_fixed(sequence, pattern = patterns)))
    hits <- na.omit(hits, cols = c("start", "end"))
    
    if (nrow(hits) < 1) {
      return(data.table::data.table(start = NA, end = NA, width = NA, name = seq_name))
    }
    
    # k-mer sequence and its global count for every hit
    hits[, sub_seq := stringr::str_sub(sequence, start, end)]
    hits[, count := pattern_counts[match(sub_seq, patterns)]]
    
    #### reduce ranges
    hit_ranges <- IRanges::IRanges(start = hits[["start"]], end = hits[["end"]], names = hits[["sub_seq"]])
    
    # overlap edges; drop.self = FALSE keeps self-hits so every hit becomes a graph node
    edge_list <- as.matrix(IRanges::findOverlaps(hit_ranges, minoverlap = minoverlap, drop.self = FALSE, drop.redundant = TRUE))
    
    # get components (groups of connected ranges); node ids are row indices of 'hits'
    graph <- igraph::graph_from_edgelist(edge_list, directed = FALSE)
    membership <- igraph::components(graph)$membership
    node_membership <- split(seq_along(membership), membership)
    
    # component score = summed counts of the k-mer hits in that component
    score <- vapply(node_membership, FUN.VALUE = numeric(1), function(nodes) {
      sum(as.numeric(hits[["count"]][nodes]))
    })
    
    # best component, first one on ties
    selected <- node_membership[[which.max(score)]]
    
    # merge the selected hits into one range and tag it with the bed name
    best <- data.table::as.data.table(IRanges::reduce(hit_ranges[selected]))
    best[, name := seq_name]
    
    best
  })
  
  # create ranges table
  data.table::rbindlist(seq_ranges)
}

# call function with given parameter if not in interactive context (e.g. run from shell)
if (!interactive()) {
  # show apply progressbar with timer; invisible() suppresses printing of the returned old options
  invisible(pbapply::pboptions(type = "timer"))
  # drop optparse's help flag by name (not by position); the remaining options
  # are forwarded as named arguments of reduce_sequence
  params <- opt[names(opt) != "help"]
  invisible(do.call(reduce_sequence, args = params))
}