#! /bin/Rscript
library("optparse")

# Command line options.
# Numeric options declare `type = "integer"` (with integer defaults) so optparse
# parses them as integers, matching their metavar, instead of inferring
# "double" from a numeric default.
option_list <- list(
  make_option(
    opt_str = c("-i", "--input"), type = "character", default = NULL,
    help = "Input bed-file. Last column must be sequences.",
    metavar = "character"
  ),
  make_option(
    opt_str = c("-k", "--kmer"), type = "integer", default = 10L,
    help = "Kmer length. Default = %default",
    metavar = "integer"
  ),
  make_option(
    opt_str = c("-m", "--motif"), type = "integer", default = 10L,
    help = "Estimated motif length. Default = %default",
    metavar = "integer"
  ),
  make_option(
    opt_str = c("-o", "--output"), type = "character", default = "reduced.bed",
    help = "Output file. Default = %default",
    metavar = "character"
  ),
  make_option(
    opt_str = c("-t", "--threads"), type = "integer", default = 1L,
    help = "Number of threads to use. Use -1 for all available cores. Default = %default",
    metavar = "integer"
  ),
  make_option(
    opt_str = c("-c", "--clean"), type = "logical", default = TRUE,
    help = "Delete all temporary files. Default = %default",
    metavar = "logical"
  ),
  # NULL default: option is absent from the parsed list unless given, so
  # reduce_bed() falls back to its own default (motif length)
  make_option(
    opt_str = c("-s", "--min_seq_length"), type = "integer", default = NULL,
    help = "Remove sequences below this length. Defaults to motif length.",
    metavar = "integer"
  )
  # TODO more args
)

opt_parser <- OptionParser(
  option_list = option_list,
  description = "Reduce sequences to frequent regions."
)

opt <- parse_args(opt_parser)

#' Reduce bed file to conserved regions
#'
#' Counts all kmer of the bed sequences with jellyfish, keeps the most frequent
#' kmer, locates them in every sequence and shrinks each bed entry to the region
#' covered by its best-scoring group of overlapping kmer. The result is written
#' to `output` as a header-less, tab-separated bed file.
#'
#' @param input bed file without header. Columns: chr, start, end, name, ..., sequence.
#'   Needs at least 5 columns; the last one must hold the sequences.
#' @param kmer Length of kmer.
#' @param motif Estimated motif length.
#' @param output Output file
#' @param threads Number of threads. -1 uses all available cores; NULL disables
#'   parallelisation.
#' @param clean Delete all temporary files (also when a later step fails).
#' @param minoverlap_kmer Minimum required overlap between kmers. # TODO (currently unused)
#' @param minoverlap_motif Minimum required overlap between motif and kmer. # TODO (currently unused)
#' @param min_seq_length Sequences shorter than this are removed; sequences of
#'   exactly this length are kept. Default = motif.
#'
#' @return reduced bed as data.table (invisibly); the same table is written to `output`.
reduce_bed <- function(input, kmer = 10, motif = 10, output = "reduced.bed", threads = NULL, clean = TRUE, minoverlap_kmer, minoverlap_motif, min_seq_length = motif) {
  # -1 = all available cores; NULL (no parallelisation) is passed through unchanged
  if (!is.null(threads) && threads == -1) {
    threads <- parallel::detectCores()
  }

  message("Loading bed...")
  # columns: chr, start, end, name, ..., sequence
  bed_table <- data.table::fread(input = input, header = FALSE)
  if (ncol(bed_table) < 5) {
    stop("Bed file needs at least 5 columns: chr, start, end, name, ..., sequence.")
  }
  data.table::setnames(bed_table, c(1:4, ncol(bed_table)), c("chr", "start", "end", "name", "sequence"))
  # fread may guess a numeric type for the names; make.unique() requires character
  bed_table[, name := as.character(name)]
  # index
  data.table::setkey(bed_table, name, physical = FALSE)

  # check for duplicated names
  if (anyDuplicated(bed_table[["name"]]) > 0) {
    warning("Found duplicated names. Making names unique.")
    bed_table[, name := make.unique(name)]
  }

  # remove sequences below minimum length (exactly min_seq_length is kept)
  total_rows <- nrow(bed_table)
  bed_table <- bed_table[nchar(sequence) >= min_seq_length]
  if (total_rows > nrow(bed_table)) {
    message("Removed ", total_rows - nrow(bed_table), " sequence(s) below minimum length of ", min_seq_length)
  }
  if (nrow(bed_table) == 0) {
    stop("No sequences left after length filtering.")
  }

  # temporary files (written to the working directory)
  fasta_file <- paste0(basename(input), ".fasta")
  count_output_binary <- "mer_count_binary.jf"
  mer_count_table <- "mer_count_table.jf"
  if (clean) {
    # registered before the files are written so they are removed even if a later step fails
    on.exit(unlink(c(fasta_file, count_output_binary, mer_count_table)), add = TRUE)
  }

  # TODO forward fasta file as parameter so no bed -> fasta conversion is needed.
  message("Writing fasta...")
  seqinr::write.fasta(sequences = as.list(bed_table[["sequence"]]), names = bed_table[["name"]], as.string = TRUE, file.out = fasta_file)

  message("Counting kmer...")
  # the number of kmer positions bounds the number of distinct kmer, so it is a
  # sufficient initial hash size and avoids requesting 4^kmer slots for large kmer
  n_kmer_positions <- sum(pmax(nchar(bed_table[["sequence"]]) - kmer + 1, 0))
  hashsize <- min(4 ^ kmer, max(n_kmer_positions, 1))
  count_kmer_jellyfish(fasta_file, kmer = kmer, hashsize = hashsize, binary_file = count_output_binary, table_file = mer_count_table)

  message("Reduce kmer.")
  # columns: kmer (V1), count (V2)
  kmer_counts <- data.table::fread(input = mer_count_table, header = FALSE)
  # order kmer descending
  data.table::setorder(kmer_counts, -V2)

  # compute number of hits to keep
  keep_hits <- significant_kmer(bed_table, kmer = kmer, motif = motif)

  # reduce kmer
  reduced_kmer <- reduce_kmer(kmer = kmer_counts, keep_hits)

  message("Find kmer in sequences.")
  # TODO minoverlap as parameter
  # columns: start, end, width, name (start/end are 1-based positions within the sequence)
  ranges_table <- find_kmer_regions(bed = bed_table, kmer_counts = reduced_kmer, minoverlap = kmer - 1, threads = threads)
  data.table::setnames(ranges_table, 1:2, c("relative_start", "relative_end"))

  # merge ranges_table with bed_table + keep column order
  merged <- merge(x = bed_table, y = ranges_table, by = "name", sort = FALSE)[, union(names(bed_table), names(ranges_table)), with = FALSE]

  # delete sequences without hit
  merged <- na.omit(merged, cols = c("relative_start", "relative_end"))
  message("Removed ", nrow(bed_table) - nrow(merged), " sequence(s) without hit.")

  message("Reduce sequences.")
  # create subsequences
  merged[, sequence := stringr::str_sub(sequence, relative_start, relative_end)]

  # relative positions are 1-based & inclusive, bed is 0-based & half-open:
  # sequence position p sits at bed coordinate start + p - 1, so the kept region
  # is [start + relative_start - 1, start + relative_end). Both right-hand sides
  # are evaluated before assignment, i.e. with the original start.
  merged[, `:=`(start = start + relative_start - 1L, end = start + relative_end)]

  # clean table
  merged[, `:=`(relative_start = NULL, relative_end = NULL, width = NULL)]

  data.table::fwrite(merged, file = output, sep = "\t", col.names = FALSE)

  invisible(merged)
}

# Count kmer of `fasta_file` with jellyfish and dump the counts as a
# tab-separated two-column table (kmer, count) to `table_file`.
# `binary_file` receives jellyfish's binary count output. Stops with an error
# if jellyfish is not on the PATH or exits with a non-zero status.
count_kmer_jellyfish <- function(fasta_file, kmer, hashsize, binary_file, table_file) {
  if (!nzchar(Sys.which("jellyfish"))) {
    stop("jellyfish not found. Please install jellyfish and add it to the PATH.")
  }

  run_jellyfish <- function(args) {
    # system2() does not quote its args, so file paths are shQuote()d by the caller
    status <- system2("jellyfish", args = args)
    if (status != 0) {
      stop("jellyfish failed with exit status ", status, ": jellyfish ", paste(args, collapse = " "))
    }
  }

  # format() keeps large hash sizes out of scientific notation (e.g. "1.1e+12")
  run_jellyfish(c("count", "-m", kmer, "-s", format(hashsize, scientific = FALSE),
                  "-o", shQuote(binary_file), shQuote(fasta_file)))
  run_jellyfish(c("dump", "--column", "--tab", "--output", shQuote(table_file), shQuote(binary_file)))

  invisible(table_file)
}

#' returns sum of top x kmer frequencies to keep
#'
#' Counts the kmer positions of all sequences, with every sequence capped to
#' `motif + 2 * (kmer - minoverlap)` bases (a motif plus the flanks a kmer can
#' reach while still overlapping it by `minoverlap` bases). Sequences shorter
#' than `kmer` hold no kmer and contribute 0.
#'
#' @param bed Bed table with sequences in last column
#' @param kmer Length of kmer
#' @param motif Length of motif
#' @param minoverlap Minimum number of bases overlapping between kmer and motif. Must be <= motif & <= kmer. Defaults to ceiling(motif / 2).
#'
#' @return Number of interesting kmer.
significant_kmer <- function(bed, kmer, motif, minoverlap = ceiling(motif / 2)) {
  if (minoverlap > kmer || minoverlap > motif) {
    stop("Kmer & motif must be greater or equal than minoverlap!")
  }

  # minimum sequence length to get all interesting overlaps
  min_seq_length <- motif + 2 * (kmer - minoverlap)

  # reduce to max interesting length
  seq_lengths <- pmin(nchar(bed[[ncol(bed)]]), min_seq_length)

  # kmer per sequence; clamp at 0 so sequences shorter than kmer do not
  # subtract from the total
  kmer_per_seq <- pmax(seq_lengths - kmer + 1, 0)

  sum(kmer_per_seq)
}

#' Keep the most frequent kmer
#'
#' Keeps every row whose running total of counts (column `V2`) does not exceed
#' `significant`. Expects `kmer` to be sorted by descending count.
#' Side effect: adds a `cumsum` column to `kmer` by reference.
#'
#' @param kmer Kmer table (data.table) with counts in column `V2`.
#' @param significant Upper limit for the cumulative kmer count.
#'
#' @return Subset of `kmer` (including the `cumsum` column).
reduce_kmer <- function(kmer, significant) {
  running_total <- cumsum(kmer[["V2"]])
  kmer[, cumsum := running_total]

  within_limit <- running_total <= significant
  kmer[within_limit]
}

#' create list of significant ranges (one for each bed entry)
#'
#' For every sequence all occurrences of the given kmer are located. Occurrences
#' overlapping by at least `minoverlap` bases are linked into connected groups;
#' the group with the highest summed kmer count wins and is merged into one range.
#'
#' @param bed Data.table (or data.frame) of bed with names in column 4 and sequence in last column
#' @param kmer_counts Data.table of counted kmer. Column1 = kmer, column2 = count. Not modified.
#' @param minoverlap Minimum overlapping nucleotides between kmers to be merged. Positive integer. Must be smaller than kmer length.
#' @param threads Number of threads (forwarded as `cl` to `pbapply::pblapply`).
#'
#' @return Data.table with relative positions and width (start, end, width) plus
#'   the bed `name`; one row per bed entry, NA positions where no kmer was found.
find_kmer_regions <- function(bed, kmer_counts, minoverlap = 1 , threads = NULL) {
  if (nrow(kmer_counts) < 1) {
    stop("No kmer to search for!")
  }
  if (nchar(kmer_counts[[1]][1]) <= minoverlap) {
    stop("Minoverlap must be smaller than kmer length!")
  }

  # most frequent kmer first (stable order); plain vectors so the caller's
  # table is neither renamed nor reordered
  ord <- order(kmer_counts[[2]], decreasing = TRUE, method = "radix")
  kmers <- as.character(kmer_counts[[1]][ord])
  counts <- kmer_counts[[2]][ord]

  # best-scoring kmer region of one sequence as a one-row data.table
  # (start, end, width); all NA if none of the kmer occurs in the sequence
  best_region <- function(sequence) {
    #### locate ranges
    # overlap = TRUE also reports overlapping occurrences (e.g. "AAA" twice in
    # "AAAA"), consistent with counting every kmer position
    locations <- stringi::stri_locate_all_fixed(sequence, pattern = kmers, overlap = TRUE)
    ranges <- data.table::data.table(do.call(rbind, locations))
    ranges <- na.omit(ranges, cols = c("start", "end"))

    if (nrow(ranges) < 1) {
      return(data.table::data.table(start = NA, end = NA, width = NA))
    }

    # kmer count of every located occurrence (aligned with the rows of `ranges`)
    hit_seqs <- substring(sequence, ranges[["start"]], ranges[["end"]])
    hit_counts <- counts[match(hit_seqs, kmers)]

    #### reduce ranges
    hit_ranges <- IRanges::IRanges(start = ranges[["start"]], end = ranges[["end"]], names = hit_seqs)

    # pairs of occurrences overlapping by >= minoverlap bases; self hits are kept
    # (width > minoverlap is checked above) so every occurrence is a graph vertex
    edge_list <- as.matrix(IRanges::findOverlaps(hit_ranges, minoverlap = minoverlap, drop.self = FALSE, drop.redundant = TRUE))

    # component id per occurrence (vertex ids = row indices of `ranges`)
    graph <- igraph::graph_from_edgelist(edge_list, directed = FALSE)
    membership <- igraph::components(graph)$membership

    # component score = summed kmer count of its own occurrences; first maximum wins
    score <- tapply(hit_counts, membership, sum)
    best <- as.integer(names(score)[which.max(score)])

    # reduce selected ranges
    data.table::as.data.table(IRanges::reduce(hit_ranges[membership == best]))
  }

  seq_ranges <- pbapply::pblapply(bed[[ncol(bed)]], best_region, cl = threads)

  # create ranges table
  conserved_regions_table <- data.table::rbindlist(seq_ranges)
  conserved_regions_table[, name := bed[[4]]]

  return(conserved_regions_table)
}

# call function with given parameter if not in interactive context (e.g. run from shell)
if (!interactive()) {
  # --input has no default, so fail early with the usage text instead of
  # erroring deep inside reduce_bed()
  if (is.null(opt[["input"]])) {
    print_help(opt_parser)
    stop("Missing required option --input.", call. = FALSE)
  }

  # show apply progressbar
  pbapply::pboptions(type = "timer")

  # drop optparse's help flag by name; reduce_bed() has no such parameter
  params <- opt[names(opt) != "help"]
  invisible(do.call(reduce_bed, args = params))
}