import sys
from pathlib import Path

import pandas as pd

from snakemake.utils import Paramspace
from snakemake.utils import min_version


# abort early when the installed Snakemake is older than this workflow expects
min_version("7.10.0")


# workflow configuration
configfile: "config/config.yaml"


# parameter space, built from the table referenced in the config
# (lines of that table starting with "#" are skipped)
df_params = pd.read_csv(
    config["params_path"],
    comment="#",
)
paramspace = Paramspace(
    df_params,
    filename_params="*",
    filename_sep="__",
)

# method selection: an explicit list in the config is used as-is, a null
# entry means "every method definition script found on disk"
method_list = config["method_list"]
if method_list is None:
    (method_list,) = glob_wildcards(
        srcdir("../resources/method_definitions/{method}.py")
    )

print(method_list, file=sys.stderr)


# group methods into local (produces VCF) and global (produces FASTA)
def _read_method_group(fname_method, group_prefix="# GROUP:"):
    """Return the group declared on the first `# GROUP:` line of a script.

    The file is scanned line by line; the text following the prefix on the
    first matching line is returned with surrounding whitespace stripped.
    Returns None when no line starts with the prefix.
    """
    with open(fname_method) as fd:
        for line in fd:
            if line.startswith(group_prefix):
                return line[len(group_prefix) :].strip()
    return None


method_list_local = []
method_list_global = []

for method in method_list:
    fname_method = srcdir(f"../resources/method_definitions/{method}.py")

    # looked up afresh for every method, so a script without annotation can
    # never inherit the group of the previously processed method
    group = _read_method_group(fname_method)

    if group is None:
        raise RuntimeError(
            f"No '# GROUP:' line found for method '{method}' in '{fname_method}'"
        )

    if group == "local":
        method_list_local.append(method)
    elif group == "global":
        method_list_global.append(method)
    else:
        raise RuntimeError(f"Invalid group '{group}' for method '{method}'")


# misc setup
# only the `seq_mode` entry with index label 0 of the parameter table is used
# TODO: Allow different sequencing_modes in parameter file
sequencing_mode = paramspace["seq_mode"][0]
if sequencing_mode.startswith("amplicon"):
    # keep only the part in front of the first ":"
    sequencing_mode = sequencing_mode.partition(":")[0]


# helper functions
def _collect_tagged_lines(fname, prefixes):
    """Return, per prefix, the stripped remainder of each line starting with it.

    The file is read once; the order of matching lines is preserved.
    """
    found = {prefix: [] for prefix in prefixes}
    with open(fname) as fd:
        for line in fd:
            for prefix in prefixes:
                if line.startswith(prefix):
                    found[prefix].append(line[len(prefix) :].strip())
                    break
    return found


def get_generated_conda_env(wildcards, input):
    """Generate a conda environment file from annotations in a method script.

    Lines of `input.script` starting with `# CONDA:` are collected as conda
    dependencies, lines starting with `# PIP:` as pip dependencies. The
    resulting environment definition is printed to stderr and written to
    `results/envs/<wildcards.method>.yaml` (relative to the current working
    directory).

    Args:
        wildcards: object providing the attribute `method`.
        input: object providing the attribute `script` (path of the script).

    Returns:
        The resolved path of the written environment file as a string.
    """
    conda_dep_prefix = "# CONDA:"
    pip_dep_prefix = "# PIP:"

    # retrieve conda and pip dependencies from the script in a single pass
    deps = _collect_tagged_lines(input.script, (conda_dep_prefix, pip_dep_prefix))
    conda_dep_list = deps[conda_dep_prefix]
    pip_dep_list = deps[pip_dep_prefix]

    # format conda env file
    conda_env = """channels:
  - hcc
  - broad-viral
  - conda-forge
  - bioconda
  - defaults
dependencies:"""

    if not conda_dep_list and not pip_dep_list:
        # explicit empty list keeps the YAML valid without any dependency
        conda_env += " []"
    else:
        if conda_dep_list:
            conda_env += "\n  - " + "\n  - ".join(conda_dep_list)

        if pip_dep_list:
            # pip packages additionally pull in a Python interpreter and pip
            conda_env += "\n  - python=3.9\n"
            conda_env += "  - pip\n"
            conda_env += "  - pip:\n    - " + "\n    - ".join(pip_dep_list)

    print(f'Generated conda env for "{input.script}":', file=sys.stderr)
    print(conda_env, file=sys.stderr)

    # save conda env
    conda_prefix = Path("results/envs/")
    conda_prefix.mkdir(parents=True, exist_ok=True)

    conda_env_path = conda_prefix / f"{wildcards.method}.yaml"
    conda_env_path.write_text(conda_env)

    return str(conda_env_path.resolve())


# rule definitions
rule all:
    input:
        # performance reports are requested only for method groups that
        # actually contain at least one method
        "results/performance_measures/local/" if method_list_local else [],
        "results/performance_measures/global/" if method_list_global else [],
        "results/haplo_stats/summary.csv",
        #"results/haplotype_populations/" if config['haplotype_generation'] == 'distance' else None,
        # read statistics for every parameter instance and replicate
        expand(
            "results/read_statistics/{params}/replicates/{replicate}/",
            params=paramspace.instance_patterns,
            replicate=range(config["replicate_count"]),
        ),


# Download the P6C4 model file from the pbsim2 GitHub repository into
# results/pbsim2/. The shell command creates the target directory first and
# then lets wget write directly to the declared output path.
rule download_pbsim2_model:
    output:
        fname_pbsim2_model="results/pbsim2/P6C4.model",
    shell:
        """
        dname="$(dirname {output.fname_pbsim2_model})"
        mkdir -p "$dname"

        wget \
            -O "{output.fname_pbsim2_model}" \
            https://raw.githubusercontent.com/yukiteruono/pbsim2/master/data/P6C4.model
        """


# directory receiving all files of one parameter instance and replicate
_sim_dir_haplotypes = (
    f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}"
)


rule generate_haplotypes:
    output:
        fname_reference=f"{_sim_dir_haplotypes}/reference.fasta",
        fname_groundtruth=f"{_sim_dir_haplotypes}/ground_truth.csv",
        fname_fasta=f"{_sim_dir_haplotypes}/haplotypes.fasta",
        dname_work=directory(f"{_sim_dir_haplotypes}/work/"),
    # negative lookahead: values starting with "real_data" never match, so
    # this rule is not considered for those parameter instances
    wildcard_constraints:
        seq_mode=r"(?!real_data).*",
        seq_mode_param=r"(?!real_data).*",
    params:
        params=paramspace.instance,
        haplotype_generation=config["haplotype_generation"],
        master_seq_path=config["master_seq_path"],
    conda:
        "envs/simulate_reads.yaml"
    script:
        "scripts/generate_haplotypes.py"


# directory holding the haplotype work files and the simulated shotgun reads
_sim_dir_shotgun = (
    f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}"
)


rule shotgun_simulation:
    input:
        dname_work=f"{_sim_dir_shotgun}/work/",
        pbsim2_model="results/pbsim2/P6C4.model",
    output:
        fname_fastq=f"{_sim_dir_shotgun}/reads.shotgun.fastq",
        fname_bam=f"{_sim_dir_shotgun}/reads.shotgun.bam",
    resources:
        mem_mb=5_000,
        runtime=4 * 60,
    params:
        params=paramspace.instance,
        haplotype_generation=config["haplotype_generation"],
    conda:
        "envs/simulate_reads.yaml"
    script:
        "scripts/shotgun_simulation.py"


# this rule only exists when haplotypes are generated in "distance" mode
if config["haplotype_generation"] == "distance":

    rule collect_haplo_populations_visualizations:
        input:
            # work directory of every parameter instance and replicate
            dname_work=[
                f"results/simulated_reads/{instance}/replicates/{rep_idx}/work/"
                for instance in paramspace.instance_patterns
                for rep_idx in range(config["replicate_count"])
            ],
        output:
            dname_out=directory("results/haplotype_populations/"),
        script:
            "scripts/collect_haplo_populations.py"


# directory receiving the simulated scheme of one instance and replicate
_scheme_dir_simulated = f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/scheme"


rule simulate_scheme:
    input:
        fname_reference=rules.generate_haplotypes.output.fname_reference,
    output:
        fname_insert_bed=f"{_scheme_dir_simulated}/reference.insert.bed",
        dname_out=directory(f"{_scheme_dir_simulated}/"),
    # values starting with "real_data" are excluded via negative lookahead
    wildcard_constraints:
        seq_mode=r"(?!real_data).*",
        seq_mode_param=r"(?!real_data).*",
    params:
        params=paramspace.instance,
    conda:
        "envs/split.yaml"
    script:
        "scripts/simulate_scheme.py"


# directory receiving the simulated amplicon reads of one instance/replicate
_sim_dir_amplicon = (
    f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}"
)


rule amplicon_simulation:
    input:
        dname_work=rules.generate_haplotypes.output.dname_work,
        fname_insert_bed=rules.simulate_scheme.output.fname_insert_bed,
    output:
        fname_fastq_R1=f"{_sim_dir_amplicon}/reads.amplicon.R1.fastq",
        fname_fastq_R2=f"{_sim_dir_amplicon}/reads.amplicon.R2.fastq",
        fname_bam=f"{_sim_dir_amplicon}/reads_merged.amplicon.bam",
    # values starting with "real_data" are excluded via negative lookahead
    wildcard_constraints:
        seq_mode=r"(?!real_data).*",
        seq_mode_param=r"(?!real_data).*",
    params:
        params=paramspace.instance,
        haplotype_generation=config["haplotype_generation"],
    conda:
        "envs/simulate_reads.yaml"
    script:
        "scripts/amplicon_simulation.py"


# Copy the experimental SARS-CoV-2 two-strain mix data (alignment, expected
# variants, reference and insert BED) from resources/ into the results tree.
# The wildcards are reinterpreted here: `read_length` carries the data source
# name, `coverage` the subsample fraction (only echoed) and `genome_size`
# the sample name.
rule provide_2_SARS_CoV_2_mix:
    output:
        fname_cram=f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/reads.real_data.cram",
        fname_reference=f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/reference.fasta",
        fname_groundtruth=f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/ground_truth.csv",
        fname_insert_bed=f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/scheme/reference.insert.bed",
        dname_work=directory(
            f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/scheme/"
        ),
    resources:
        mem_mb=5_000,
        runtime=60 * 4,
    # the rule only applies to parameter instances naming this data source
    wildcard_constraints:
        read_length="2-SARS-CoV-2-mix",
    # note: the wildcard values are inserted by Snakemake before bash runs,
    # hence no "$" in front of the {wildcards...} placeholders
    shell:
        """
        source="{wildcards.read_length}"
        subsample_frac="{wildcards.coverage}"
        sample="{wildcards.genome_size}"
        echo "Using $subsample_frac of $source of sample $sample"

        fname_cram="resources/experimental_data/SARS-CoV-2-2-strain-mix/"$sample"/"$sample".cram"
        fname_expected_variants="resources/experimental_data/SARS-CoV-2-2-strain-mix/"$sample"/"$sample"_expected_variants.csv"
        fname_reference="resources/experimental_data/SARS-CoV-2-2-strain-mix/reference.fasta"
        fname_insert_bed="resources/experimental_data/SARS-CoV-2-2-strain-mix/nCoV-2019.insert.bed"

        mkdir -p "{output.dname_work}"

        if [[ "$source" == "2-SARS-CoV-2-mix" ]]; then
            cp "$fname_cram" "{output.fname_cram}"
            cp "$fname_expected_variants" "{output.fname_groundtruth}"
            cp "$fname_reference" "{output.fname_reference}"
            cp "$fname_insert_bed" "{output.fname_insert_bed}"
        fi
        """


# directory holding the provided CRAM and receiving the extracted read pairs
_sim_dir_split = (
    f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}"
)


rule split:
    input:
        fname_aln=f"{_sim_dir_split}/reads.real_data.cram",
    output:
        fname_fastq_R1=f"{_sim_dir_split}/reads.amplicon.R1.fastq",
        fname_fastq_R2=f"{_sim_dir_split}/reads.amplicon.R2.fastq",
    # only applies to parameter instances naming this data source
    wildcard_constraints:
        read_length="2-SARS-CoV-2-mix",
    conda:
        "envs/split.yaml"
    shell:
        "samtools collate -O {input.fname_aln} | samtools fastq -1 {output.fname_fastq_R1} -2 {output.fname_fastq_R2} -"


# directory holding the paired reads; merged reads go to its flash/ subfolder
_sim_dir_flash = (
    f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}"
)


rule flash:
    input:
        fname_fastq_R1=f"{_sim_dir_flash}/reads.amplicon.R1.fastq",
        fname_fastq_R2=f"{_sim_dir_flash}/reads.amplicon.R2.fastq",
    output:
        fname_fastq_merged=f"{_sim_dir_flash}/flash/reads_merged.extendedFrags.fastq",
        dname_out=directory(f"{_sim_dir_flash}/flash/"),
    conda:
        "envs/split.yaml"
    shell:
        """
        flash \
          {input.fname_fastq_R1} \
          {input.fname_fastq_R2} \
          -r 250 \
          -f 400 \
          -s 40 \
          --allow-outies \
          --output-prefix=reads_merged \
          --output-directory={output.dname_out}
        """


# Align the merged amplicon reads to the reference with bwa and store the
# coordinate-sorted result. The bwa output is piped directly into
# `samtools sort`, so the output file is written exactly once and is never
# used as input and output of the same command.
rule alignment_merged:
    input:
        fname_reference=f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/reference.fasta",
        fname_fastq_merged=rules.flash.output.fname_fastq_merged,
    output:
        fname_bam=f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/reads.amplicon.bam",
    conda:
        "envs/split.yaml"
    shell:
        """
        bwa index {input.fname_reference}
        bwa mem {input.fname_reference} {input.fname_fastq_merged} \
            | samtools sort -o {output.fname_bam} -
        """


# Provide real sequencing data in the same file layout as the simulation
# rules. The `seq_mode_param` wildcard is split at "@" into a source name and
# a subsample fraction. Only the source "5-virus-mix" is handled by the shell
# script below: it downloads references, haplotypes and reads, aligns the
# reads with bwa, sorts the alignment and subsamples it, using the replicate
# number as seed. For any other source the script only creates the work
# directory.
rule provide_real_data:
    output:
        fname_bam=f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/reads.real_data.bam",
        fname_reference=f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/reference.fasta",
        fname_groundtruth=f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/ground_truth.csv",
        fname_fasta=f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/haplotypes.fasta",
        dname_work=directory(
            f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/work/"
        ),
    conda:
        "envs/real_data.yaml"
    resources:
        mem_mb=5_000,
        runtime=60 * 4,
    # counterpart of the "(?!real_data).*" constraints on the simulation rules
    wildcard_constraints:
        seq_mode="real_data",
    # NOTE(review): the script uses {threads}, but the rule has no `threads`
    # directive -- confirm the intended thread count
    shell:
        """
        tmp="{wildcards.seq_mode_param}"
        parts=(${{tmp//@/ }})
        source="${{parts[0]}}"
        subsample_frac="${{parts[1]}}"
        echo "Using $subsample_frac of $source"

        mkdir -p "{output.dname_work}"

        if [[ "$source" == "5-virus-mix" ]]; then
            # download all references
            curl --output-dir "{output.dname_work}" -O "https://raw.githubusercontent.com/cbg-ethz/5-virus-mix/master/data/REF.fasta"

            # only keep HXB2 reference
            sed -ne '3p;4p' "{output.dname_work}/REF.fasta" > "{output.fname_reference}"

            # download true haplotypes and set frequencies (source: table 1 in https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4132706/)
            curl --output-dir "{output.dname_work}" -O "https://raw.githubusercontent.com/cbg-ethz/5-virus-mix/master/data/5VM.fasta"
            mv "{output.dname_work}/5VM.fasta" "{output.fname_fasta}"

            sed \
              -i \
              -e "s/>89.6/>89.6 freq:0.221/" \
              -e "s/>HXB2/>HXB2 freq:0.112/" \
              -e "s/>JRCSF/>JRCSF freq:0.28/" \
              -e "s/>NL43/>NL43 freq:0.273/" \
              -e "s/>YU2/>YU2 freq:0.111/" \
              "{output.fname_fasta}"

            # create ground truth information (TODO)
            touch {output.fname_groundtruth}

            # download fastq
            curl --output-dir "{output.dname_work}" -O "ftp://ftp.sra.ebi.ac.uk/vol1/fastq/SRR961/SRR961514/SRR961514_1.fastq.gz"
            curl --output-dir "{output.dname_work}" -O "ftp://ftp.sra.ebi.ac.uk/vol1/fastq/SRR961/SRR961514/SRR961514_2.fastq.gz"

            # align reads to reference
            bwa index {output.fname_reference}
            bwa mem \
                -t {threads} \
                {output.fname_reference} \
                "{output.dname_work}/SRR961514_1.fastq.gz" "{output.dname_work}/SRR961514_2.fastq.gz" \
                > {output.fname_bam}
            samtools sort -o {output.fname_bam} {output.fname_bam}

            # subsample
            tmpfile="{output.dname_work}/tmp.bam"
            samtools view \
              --subsample "$subsample_frac" \
              --subsample-seed "{wildcards.replicate}" \
              -o "$tmpfile" \
              {output.fname_bam}
            mv "$tmpfile" {output.fname_bam}
        fi
        """


# alignment file to be indexed; the index is written next to it as "<bam>.bai"
_fname_bam_to_index = f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/reads.{{seq_mode}}.bam"


rule samtools_index:
    input:
        _fname_bam_to_index,
    output:
        f"{_fname_bam_to_index}.bai",
    wrapper:
        "v1.3.2/bio/samtools/index"


# alignment file whose reads are summarised by this rule
_fname_bam_for_stats = f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/reads.{{seq_mode}}.bam"


rule read_statistics:
    input:
        fname_bam=_fname_bam_for_stats,
        fname_bam_index=f"{_fname_bam_for_stats}.bai",
    output:
        outdir=directory(
            f"results/read_statistics/{paramspace.wildcard_pattern}/replicates/{{replicate}}/"
        ),
    conda:
        "envs/read_statistics.yaml"
    script:
        "scripts/read_statistics.py"


# Run one method definition script on the reads of one parameter instance
# and replicate. The `method` wildcard is restricted (via the regex embedded
# in the status and benchmark patterns) to the methods grouped as "local".
# The conda environment is generated from the annotations in the script.
rule run_method_local:
    input:
        script=srcdir("../resources/method_definitions/{method}.py"),
        # every sequencing mode reads the same BAM pattern, so a plain
        # pattern is sufficient and no input function is needed
        fname_bam=f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/reads.{{seq_mode}}.bam",
        fname_bam_index=f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/reads.{{seq_mode}}.bam.bai",
        fname_reference=f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/reference.fasta",
        # the insert BED file is only requested in amplicon mode
        fname_insert_bed=lambda wildcards: f"results/simulated_reads/{paramspace.wildcard_pattern}/replicates/{{replicate}}/scheme/reference.insert.bed"
        if wildcards.seq_mode == "amplicon"
        else [],
    output:
        fname_result=f"results/method_runs/{paramspace.wildcard_pattern}/{{method}}/replicates/{{replicate}}/snvs.vcf",
        # 'markertoavoidemptyregex' keeps the alternation non-empty when no
        # local method is selected
        fname_status=touch(
            f"results/method_runs/{paramspace.wildcard_pattern}/{{method,{'|'.join(['markertoavoidemptyregex'] + method_list_local)}}}/replicates/{{replicate}}/status.txt"
        ),
        dname_work=directory(
            f"results/method_runs/{paramspace.wildcard_pattern}/{{method}}/replicates/{{replicate}}/work/"
        ),
    benchmark:
        f"results/method_runs/{paramspace.wildcard_pattern}/{{method,{'|'.join(['markertoavoidemptyregex'] + method_list_local)}}}/replicates/{{replicate}}/benchmark.tsv"
    params:
        script_path=lambda wildcards, input: input.script,
    conda:
        get_generated_conda_env
    threads: 10
    resources:
        mem_mb=2_000,
        runtime=60 * 24,
    script:
        "{params.script_path}"


# regex alternation of all global methods; the marker entry keeps the
# alternation non-empty when no global method is selected
_global_method_regex = "|".join(["markertoavoidemptyregex"] + method_list_global)
_global_run_root = f"results/method_runs/{paramspace.wildcard_pattern}"


# same as `run_method_local`, but for methods grouped as "global": different
# result file, more memory (scaled by attempt) and one retry
use rule run_method_local as run_method_global with:
    output:
        fname_result=f"{_global_run_root}/{{method}}/replicates/{{replicate}}/haplotypes.fasta",
        fname_status=touch(
            f"{_global_run_root}/{{method,{_global_method_regex}}}/replicates/{{replicate}}/status.txt"
        ),
        dname_work=directory(
            f"{_global_run_root}/{{method}}/replicates/{{replicate}}/work/"
        ),
    benchmark:
        f"{_global_run_root}/{{method,{_global_method_regex}}}/replicates/{{replicate}}/benchmark.tsv"
    retries: 1
    threads: 10
    resources:
        mem_mb=lambda wildcards, threads, attempt: int(15_000 * attempt / threads),
        runtime=24 * 60,


# path fragment identifying one parameter instance and replicate
_haplo_stats_instance = f"{paramspace.wildcard_pattern}/replicates/{{replicate}}"


rule haplotypes_stats:
    input:
        fname_reference=f"results/simulated_reads/{_haplo_stats_instance}/reference.fasta",
        fname_groundtruth=f"results/simulated_reads/{_haplo_stats_instance}/ground_truth.csv",
    output:
        fname=f"results/haplo_stats/{_haplo_stats_instance}/haplotypes_stats.csv",
    conda:
        "envs/haplotype_stats.yaml"
    script:
        "scripts/haplotypes_stats.py"


rule collect_haplotypes_stats:
    input:
        # one statistics table per parameter instance and replicate
        haplostats_list=[
            f"results/haplo_stats/{instance}/replicates/{rep_idx}/haplotypes_stats.csv"
            for instance in paramspace.instance_patterns
            for rep_idx in range(config["replicate_count"])
        ],
    output:
        fname_haplostats_summary="results/haplo_stats/summary.csv",
    run:
        import pandas as pd

        # stack all per-replicate tables and write them as a single CSV
        per_replicate_tables = [
            pd.read_csv(fname_table) for fname_table in input.haplostats_list
        ]
        summary_table = pd.concat(per_replicate_tables)
        summary_table.to_csv(output.fname_haplostats_summary)


rule performance_measures_local:
    input:
        vcf_list=[
            f"results/method_runs/{instance}/{name}/replicates/{rep_idx}/snvs.vcf"
            for instance in paramspace.instance_patterns
            for name in method_list_local
            for rep_idx in range(config["replicate_count"])
        ],
        # iterating over `method_list_local` as well (although it is not part
        # of the path) keeps this list aligned with `vcf_list`
        groundtruth_list=[
            f"results/simulated_reads/{instance}/replicates/{rep_idx}/ground_truth.csv"
            for instance in paramspace.instance_patterns
            for name in method_list_local
            for rep_idx in range(config["replicate_count"])
        ],
        benchmark_list=[
            f"results/method_runs/{instance}/{name}/replicates/{rep_idx}/benchmark.tsv"
            for instance in paramspace.instance_patterns
            for name in method_list_local
            for rep_idx in range(config["replicate_count"])
        ],
        # one entry per parameter instance and replicate (no method loop)
        haplostats_list=[
            f"results/haplo_stats/{instance}/replicates/{rep_idx}/haplotypes_stats.csv"
            for instance in paramspace.instance_patterns
            for rep_idx in range(config["replicate_count"])
        ],
    output:
        dname_out=directory("results/performance_measures/local/"),
    conda:
        "envs/performance_measures.yaml"
    script:
        "scripts/performance_measures_local.py"


rule performance_measures_global:
    input:
        # lists that include a method loop have one entry per parameter
        # instance, global method and replicate
        predicted_haplos_list=[
            f"results/method_runs/{instance}/{name}/replicates/{rep_idx}/haplotypes.fasta"
            for instance in paramspace.instance_patterns
            for name in method_list_global
            for rep_idx in range(config["replicate_count"])
        ],
        # the next two lists have one entry per parameter instance and
        # replicate only
        true_haplos_list=[
            f"results/simulated_reads/{instance}/replicates/{rep_idx}/haplotypes.fasta"
            for instance in paramspace.instance_patterns
            for rep_idx in range(config["replicate_count"])
        ],
        haplostats_list=[
            f"results/haplo_stats/{instance}/replicates/{rep_idx}/haplotypes_stats.csv"
            for instance in paramspace.instance_patterns
            for rep_idx in range(config["replicate_count"])
        ],
        runstatus_list=[
            f"results/method_runs/{instance}/{name}/replicates/{rep_idx}/status.txt"
            for instance in paramspace.instance_patterns
            for name in method_list_global
            for rep_idx in range(config["replicate_count"])
        ],
        benchmark_list=[
            f"results/method_runs/{instance}/{name}/replicates/{rep_idx}/benchmark.tsv"
            for instance in paramspace.instance_patterns
            for name in method_list_global
            for rep_idx in range(config["replicate_count"])
        ],
    output:
        dname_out=directory("results/performance_measures/global/"),
    params:
        method_list_global=method_list_global,
    threads: 10
    resources:
        mem_mb=500,
        runtime=24 * 60,
    conda:
        "envs/performance_measures.yaml"
    script:
        "scripts/performance_measures_global.py"
