#!/usr/bin/env python3

import argparse
import multiprocessing
import sys
from functools import partial
from math import log2
import pysam


def get_args():
    """Parse the command line and return every option as one flat tuple.

    Options are declared per genome (1, 2 and 3) in the same order as they
    appear in the ``--help`` output: PMDtools BAM, BAM, genome fasta and
    organism name for each genome, then the sample-level options, the output
    BAM names, the endogenous DNA proportions and the process count.

    Numeric options are converted by argparse itself (``type=``), so an
    invalid value produces a usage error instead of a traceback.

    Returns:
        A 25-item tuple, in this order:
        (abam1, bam1, genome1, organism1,
         abam2, bam2, genome2, organism2,
         abam3, bam3, genome3, organism3,
         name, identity, processes, output,
         output_bam1, output_abam1, output_bam2, output_abam2,
         output_bam3, output_abam3, endo_dna1, endo_dna2, endo_dna3)
        ``identity`` and the ``endo_dna*`` values are floats, ``processes``
        is an int, everything else is a string or None.
    """
    parser = argparse.ArgumentParser(
        prog='normalizedReadCount',
        description='Counts reads aligned to genome, and normalize by genome size')

    genomes = ("1", "2", "3")

    # Inputs, grouped per genome.
    for g in genomes:
        parser.add_argument(
            f'-ab{g}',
            dest=f'abam{g}',
            default=None,
            help=f"PMDTools Bam alignment file on genome {g}. Default = None")
        parser.add_argument(
            f'-b{g}',
            dest=f'bam{g}',
            default=None,
            help=f"Bam alignment file on genome {g}. Default = None")
        parser.add_argument(
            f'-g{g}',
            dest=f'genome{g}',
            default=None,
            help=f"Fasta file of genome {g}. Default = None")
        parser.add_argument(
            f'-r{g}',
            dest=f'organism{g}',
            default=None,
            help=f'Organism {g} name. Example: "Homo_sapiens". Default = None')

    # Sample-level options.
    parser.add_argument(
        '-n',
        dest='name',
        default=None,
        help='Sample name. Default = None')
    parser.add_argument(
        '-i',
        dest='identity',
        type=float,
        default=0.95,
        help='Identity threshold to retain read alignment. Default = 0.95')
    parser.add_argument(
        '-o',
        dest="output",
        default=None,
        help="Output file name. Default = {SAMPLE_NAME}.out")

    # Output BAM names, grouped per genome.
    for g in genomes:
        parser.add_argument(
            f'-ob{g}',
            dest=f"output_bam{g}",
            default=None,
            help=f"Output bam {g} filename. Default = {{BAM{g}_INPUT}}.filtered.bam")
        parser.add_argument(
            f'-aob{g}',
            dest=f"output_abam{g}",
            default=None,
            help=f"Output PMDtools bam {g} filename. "
                 f"Default = {{BAM{g}_INPUT}}.ancient.filtered.bam")

    # Endogenous DNA proportions; the help text states the real default.
    for g in genomes:
        parser.add_argument(
            f'-ed{g}',
            dest=f"endo_dna{g}",
            type=float,
            default=0.01,
            help=f"Proportion of endogenous DNA in microbiome for organism {g}. "
                 "Default = 0.01")

    parser.add_argument(
        '-p',
        dest="processes",
        type=int,
        default=4,
        help="Number of parallel process. Default = 4")

    args = parser.parse_args()

    return (args.abam1,
            args.bam1,
            args.genome1,
            args.organism1,
            args.abam2,
            args.bam2,
            args.genome2,
            args.organism2,
            args.abam3,
            args.bam3,
            args.genome3,
            args.organism3,
            args.name,
            args.identity,
            args.processes,
            args.output,
            args.output_bam1,
            args.output_abam1,
            args.output_bam2,
            args.output_abam2,
            args.output_bam3,
            args.output_abam3,
            args.endo_dna1,
            args.endo_dna2,
            args.endo_dna3)


def getBasename(file_name):
    """Return the file name without its directories and without extensions.

    Everything up to and including the last "/" is dropped, then everything
    from the first "." onwards, e.g. "/data/s1.sorted.bam" gives "s1".
    """
    leaf = file_name.rsplit("/", 1)[-1]
    stem, _, _ = leaf.partition(".")
    return stem


def computeGenomeSize(fasta):
    """Return the total sequence length found in a fasta file.

    Each line is stripped of trailing whitespace; lines starting with ">"
    (sequence headers) are ignored and the lengths of all remaining lines
    are added up.
    """
    with open(fasta, "r") as handle:
        stripped = (raw.rstrip() for raw in handle)
        return sum(len(seq) for seq in stripped if not seq.startswith(">"))


def getNumberMappedReads(bam, id):
    """Sum the matching bases of all alignments passing the identity filter.

    Despite its name, the function returns a number of bases, not of reads:
    for every alignment it computes ``aligned length - NM`` (NM being the
    edit-distance tag) and adds that value to the total when
    ``(aligned length - NM) / read length`` is at least ``id``.

    Args:
        bam: path of the BAM file to read.
        id: minimum identity (0-1) an alignment needs to be counted.

    Returns:
        The summed ``aligned length - NM`` over all retained alignments.
    """
    nb_mapped_reads = 0
    # The context manager closes the BAM handle even if a record is malformed.
    with pysam.AlignmentFile(bam, "rb") as bamfile:
        for read in bamfile:
            # Unmapped records carry no NM tag and cannot be scored.
            if read.is_unmapped:
                continue
            matched = read.query_alignment_length - read.get_tag("NM")
            if matched / read.query_length >= id:
                nb_mapped_reads += matched
    return nb_mapped_reads


def perChromosome(chr, bam, id, commonReads=None):
    """Collect the matching bases per read for one reference sequence.

    For every alignment on ``chr`` the identity
    ``(aligned length - NM) / read length`` is computed; alignments reaching
    ``id`` are stored as ``{read name: aligned length - NM}``. If a read name
    occurs several times, the last alignment seen wins.

    Args:
        chr: name of the reference sequence to fetch (the BAM must be indexed
            for ``fetch`` to work).
        bam: path of the BAM file.
        id: minimum identity (0-1) an alignment needs to be kept.
        commonReads: reads are only collected when this is None or empty;
            with a non-empty value an empty dict is returned.

    Returns:
        dict mapping read names to their number of matching bases.
    """
    resdic = {}
    if commonReads is not None and len(commonReads) != 0:
        return resdic
    # The context manager closes the handle once the region has been read,
    # which matters when one worker processes many reference sequences.
    with pysam.AlignmentFile(bam, "rb") as bamfile:
        for read in bamfile.fetch(chr, multiple_iterators=True):
            # Unmapped records carry no NM tag and cannot be scored.
            if read.is_unmapped:
                continue
            matched = read.query_alignment_length - read.get_tag("NM")
            if matched / read.query_length >= id:
                resdic[read.query_name] = matched
    return resdic


def getNumberMappedReadsMultiprocess(bam, processes, id, commonReads=None):
    """Run perChromosome on every reference of a BAM file in parallel.

    Args:
        bam: path of the BAM file.
        processes: number of worker processes in the pool.
        id: minimum identity (0-1) forwarded to perChromosome.
        commonReads: forwarded to perChromosome; None is treated as an
            empty list.

    Returns:
        dict merging the per-reference ``{read name: matching bases}``
        results. An empty dict is returned when the BAM file cannot be
        opened (pysam raises ValueError).
    """
    resdic = {}
    if commonReads is None:
        commonReads = []
    try:
        # Only the reference names are needed here; close the handle at once.
        with pysam.AlignmentFile(bam, "rb") as bamfile:
            chrs = bamfile.references
    except ValueError:
        return resdic

    perChromosomePartial = partial(
        perChromosome, bam=bam, id=id, commonReads=commonReads)
    # The context manager tears the pool down even when a worker raises.
    with multiprocessing.Pool(processes) as pool:
        result = pool.map(perChromosomePartial, chrs)
    for chrom_result in result:
        resdic.update(chrom_result)
    return resdic


def writeBam(inbam, outbam, commonReads):
    """Copy a BAM file, leaving out the reads listed in commonReads.

    Args:
        inbam: path of the BAM file to read.
        outbam: path of the BAM file to write; the header is copied from
            ``inbam``.
        commonReads: collection of read names to exclude from the output.

    If pysam raises ValueError (e.g. the input is not a valid BAM file), a
    small text file named ``error_readme_<outbam file name>`` is written
    next to the intended output instead of failing.
    """
    import os

    # Build the lookup once: testing against a list costs O(len) per read.
    if isinstance(commonReads, (set, frozenset)):
        excluded = commonReads
    else:
        excluded = set(commonReads)
    try:
        # "wb" writes compressed BAM (plain "w" would emit SAM text), and the
        # context managers guarantee both files are flushed and closed.
        with pysam.AlignmentFile(inbam, "rb") as bamin, \
                pysam.AlignmentFile(outbam, "wb", template=bamin) as bamout:
            for record in bamin:
                if record.query_name not in excluded:
                    bamout.write(record)
    except ValueError:
        # Prefix the file name only, so an output path with directories
        # still yields a valid readme path.
        out_dir, out_name = os.path.split(outbam)
        readme = os.path.join(out_dir, "error_readme_" + out_name)
        with open(readme, "w") as o:
            o.write("Error writing this BAM file, input bam file was erroneous")


def getCommonReads(readsBam1, readsBam2, readsBam3=None):
    """Return the read names present in every supplied collection.

    Args:
        readsBam1: read names found in the first BAM file.
        readsBam2: read names found in the second BAM file.
        readsBam3: read names found in the third BAM file, or None when
            only two genomes are compared. An empty collection is a valid
            value and means no read can be common to all three.

    Returns:
        A list (in arbitrary order) of the names shared by all collections.
    """
    common = set(readsBam1).intersection(readsBam2)
    # Test against None explicitly: an empty third collection must empty the
    # intersection instead of being mistaken for "no third genome".
    if readsBam3 is not None:
        common.intersection_update(readsBam3)
    return list(common)


def get_total_bp(bamres, common):
    """Sum the base counts of all reads that are not in ``common``.

    Args:
        bamres: dict mapping read names to a number of bases.
        common: collection of read names to leave out of the sum.

    Returns:
        The sum of the values of ``bamres`` whose key is not in ``common``
        (0 when nothing is left).
    """
    # Build the lookup once so each membership test is O(1) instead of a
    # scan through a list.
    if isinstance(common, (set, frozenset, dict)):
        excluded = common
    else:
        excluded = set(common)
    return sum(bp for name, bp in bamres.items() if name not in excluded)


def check_endo(endo_dna, organism):
    """Validate a proportion of endogenous DNA, exiting when it is invalid.

    The value is later used as a divisor (``1 / endo_dna``) when the counts
    are normalized, so 0 is rejected along with anything outside (0, 1].

    Args:
        endo_dna: proportion of endogenous DNA to check.
        organism: organism name, only used in the error message.

    Returns:
        True when ``0 < endo_dna <= 1``. Otherwise an error message is
        printed and the program exits with status 1.
    """
    if 0 < endo_dna <= 1:
        return True
    print(f"{endo_dna} is not valid: proportion of endogenous DNA for {organism} should be greater than 0 and at most 1. ")
    sys.exit(1)


if __name__ == "__main__":
    # Unpack the 25 values in the exact order get_args() returns them.
    (ABAM1, BAM1, GENOME1, ORGANAME1,
     ABAM2, BAM2, GENOME2, ORGANAME2,
     ABAM3, BAM3, GENOME3, ORGANAME3,
     SNAME, ID, PROCESSES, OUTFILE,
     OBAM1, AOBAM1, OBAM2, AOBAM2, OBAM3, AOBAM3,
     ENDO1, ENDO2, ENDO3) = get_args()

    # ---- Input validation -------------------------------------------------
    if BAM1 is None or BAM2 is None:
        print("Missing BAM file")
        sys.exit(1)

    # A third genome is analysed when -b3 is given. The PMDtools ("ancient")
    # analysis runs only when both -ab1 and -ab2 are given.
    THREE_GENOMES = bool(BAM3)
    ANCIENT = bool(ABAM1 and ABAM2)

    if GENOME1 is None or GENOME2 is None or (THREE_GENOMES and GENOME3 is None):
        print("Missing genome fasta file")
        sys.exit(1)
    # In ancient mode genome 3 needs both of its BAM files; a partial set
    # would otherwise fail later on an undefined variable.
    if ANCIENT and THREE_GENOMES != bool(ABAM3):
        print("Genome 3 requires both -b3 and -ab3 when -ab1 and -ab2 are given")
        sys.exit(1)

    bam_basename1 = getBasename(BAM1)
    bam_basename2 = getBasename(BAM2)
    check_endo(ENDO1, ORGANAME1)
    check_endo(ENDO2, ORGANAME2)
    if THREE_GENOMES:
        bam_basename3 = getBasename(BAM3)
        check_endo(ENDO3, ORGANAME3)

    if not OUTFILE:
        # Fall back to the first BAM basename when no sample name was given.
        OUTFILE = (SNAME if SNAME else bam_basename1) + ".out"

    # ---- Genome sizes -----------------------------------------------------
    gs1 = computeGenomeSize(GENOME1)
    gs2 = computeGenomeSize(GENOME2)
    if THREE_GENOMES:
        gs3 = computeGenomeSize(GENOME3)

    # ---- Matching bases per read, for every BAM file ----------------------
    bam1_res = getNumberMappedReadsMultiprocess(
        bam=BAM1, processes=PROCESSES, id=ID)
    bam2_res = getNumberMappedReadsMultiprocess(
        bam=BAM2, processes=PROCESSES, id=ID)
    if THREE_GENOMES:
        bam3_res = getNumberMappedReadsMultiprocess(
            bam=BAM3, processes=PROCESSES, id=ID)

    if ANCIENT:
        abam1_res = getNumberMappedReadsMultiprocess(
            bam=ABAM1, processes=PROCESSES, id=ID)
        abam2_res = getNumberMappedReadsMultiprocess(
            bam=ABAM2, processes=PROCESSES, id=ID)
        if THREE_GENOMES:
            abam3_res = getNumberMappedReadsMultiprocess(
                bam=ABAM3, processes=PROCESSES, id=ID)

    # ---- Reads aligned to every genome ------------------------------------
    # Kept as sets so the per-read membership tests below stay cheap.
    if THREE_GENOMES:
        commonReads = set(getCommonReads(
            readsBam1=list(bam1_res),
            readsBam2=list(bam2_res),
            readsBam3=list(bam3_res)))
    else:
        commonReads = set(getCommonReads(
            readsBam1=list(bam1_res), readsBam2=list(bam2_res)))

    if ANCIENT:
        if THREE_GENOMES:
            acommonReads = set(getCommonReads(
                readsBam1=list(abam1_res),
                readsBam2=list(abam2_res),
                readsBam3=list(abam3_res)))
        else:
            acommonReads = set(getCommonReads(
                readsBam1=list(abam1_res), readsBam2=list(abam2_res)))

    # ---- Output BAM names -------------------------------------------------
    outbam1 = OBAM1 if OBAM1 is not None else bam_basename1 + ".filtered.bam"
    outbam2 = OBAM2 if OBAM2 is not None else bam_basename2 + ".filtered.bam"
    if THREE_GENOMES:
        outbam3 = OBAM3 if OBAM3 is not None else bam_basename3 + ".filtered.bam"

    if ANCIENT:
        aoutbam1 = (AOBAM1 if AOBAM1 is not None
                    else bam_basename1 + ".ancient.filtered.bam")
        aoutbam2 = (AOBAM2 if AOBAM2 is not None
                    else bam_basename2 + ".ancient.filtered.bam")
        if THREE_GENOMES:
            aoutbam3 = (AOBAM3 if AOBAM3 is not None
                        else bam_basename3 + ".ancient.filtered.bam")

    # ---- Filtered BAM files (common reads removed) ------------------------
    writeBam(inbam=BAM1, outbam=outbam1, commonReads=commonReads)
    writeBam(inbam=BAM2, outbam=outbam2, commonReads=commonReads)
    # Gate on the input BAM, not on the optional output name, so the default
    # output name is honoured too.
    if THREE_GENOMES:
        writeBam(inbam=BAM3, outbam=outbam3, commonReads=commonReads)

    if ANCIENT:
        writeBam(inbam=ABAM1, outbam=aoutbam1, commonReads=acommonReads)
        writeBam(inbam=ABAM2, outbam=aoutbam2, commonReads=acommonReads)
        if THREE_GENOMES:
            writeBam(inbam=ABAM3, outbam=aoutbam3, commonReads=acommonReads)

    # ---- Base counts ------------------------------------------------------
    nb1 = get_total_bp(bam1_res, commonReads)
    nb2 = get_total_bp(bam2_res, commonReads)
    if THREE_GENOMES:
        nb3 = get_total_bp(bam3_res, commonReads)

    print("nb1_all", nb1)
    print("nb2_all", nb2)
    if THREE_GENOMES:
        print("nb3_all", nb3)

    all_counts = [nb1, nb2]
    if THREE_GENOMES:
        all_counts.append(nb3)

    if ANCIENT:
        anb1 = get_total_bp(abam1_res, acommonReads)
        anb2 = get_total_bp(abam2_res, acommonReads)
        print("nb1_ancient", anb1)
        print("nb2_ancient", anb2)
        ancient_counts = [anb1, anb2]
        if THREE_GENOMES:
            anb3 = get_total_bp(abam3_res, acommonReads)
            print("nb3_ancient", anb3)
            ancient_counts.append(anb3)
        # The ratios are computed on the ancient counts when available.
        used_counts = ancient_counts
    else:
        used_counts = all_counts

    # ---- Normalization ----------------------------------------------------
    organisms = [ORGANAME1, ORGANAME2]
    sizes = [gs1, gs2]
    endos = [ENDO1, ENDO2]
    if THREE_GENOMES:
        organisms.append(ORGANAME3)
        sizes.append(gs3)
        endos.append(ENDO3)

    # Pseudo-count of 1, scaled by genome size and endogenous DNA proportion.
    normalized = [((count + 1) / size) * (1 / endo)
                  for count, size, endo in zip(used_counts, sizes, endos)]
    total = normalized[0] + normalized[1]
    if THREE_GENOMES:
        total = total + normalized[2]
    ratios = [value / total for value in normalized]

    # ---- Report -----------------------------------------------------------
    # Columns: sample name, then one group of per-genome columns for the
    # organism names, genome sizes, base counts, (ancient base counts),
    # normalized counts and normalized read ratios.
    labels = ("1", "2", "3") if THREE_GENOMES else ("1", "2")
    header = ["Sample_name"]
    header += [f"Organism_name{i}" for i in labels]
    header += [f"Genome{i}_size" for i in labels]
    header += [f"nb_bp_aligned_genome{i}" for i in labels]
    if ANCIENT:
        header += [f"nb_ancient_bp_aligned_genome{i}" for i in labels]
        header += [f"normalized_nb_ancient_bp_aligned_genome{i}" for i in labels]
    else:
        header += [f"normalized_nb_bp_aligned_genome{i}" for i in labels]
    header += [f"NormalizedReadRatio_{i}" for i in labels]

    row = [SNAME] + organisms + sizes + all_counts
    if ANCIENT:
        row += ancient_counts
    row += normalized + ratios

    with open(OUTFILE, 'w') as w:
        w.write(",".join(header) + "\n")
        w.write(",".join(str(value) for value in row) + "\n")
