import os
import re
import sys
import tempfile
import pandas as pd
import numpy as np
from pathlib import Path

from snakemake.utils import logger, min_version

# Absolute path of the directory that contains this Snakefile.
workflow_folder = os.path.dirname(os.path.abspath(workflow.snakefile))

# Make the modules in <workflow_folder>/scripts importable (needed for `utils`).
sys.path.append(os.path.join(workflow_folder, "scripts"))
import utils


# Load the default config shipped next to the workflow folder.
# A config file passed on the command line adds the user's settings.
configfile: os.path.join(workflow_folder, "..", "config", "default_config.yaml")


# add default values from python (TODO: replace this)
from atlas.make_config import update_config as atlas_update_config

# Replace the loaded config with the version returned by atlas' update_config.
config = atlas_update_config(config)
from atlas.default_values import *  # LOAD HARDCODED values and EGGNOG_HEADER


# minimum required snakemake version
min_version("6.1")


# Global `container:` directive: image Snakemake uses for rules when
# containerized execution is enabled.
container: "docker://continuumio/miniconda3:4.4.10"


# Restrict the {binner} wildcard to one or more ASCII letters.
wildcard_constraints:
    binner="[A-Za-z]+",


# Rule modules of the workflow. The trailing comments note known cross-file
# dependencies, so keep the order in mind when adding or moving includes.
include: "rules/sample_table.smk"
include: "rules/download.smk"  # contains hard coded variables
include: "rules/qc.smk"
include: "rules/screen.smk"  # expects function get_input_fastq defined in qc
include: "rules/assemble.smk"
include: "rules/binning.smk"
include: "rules/derep.smk"
include: "rules/genomes.smk"
include: "rules/dram.smk"
include: "rules/genecatalog.smk"
include: "rules/sra.smk"
include: "rules/gtdbtk.smk"
include: "rules/cobinning.smk"
include: "rules/strains.smk"
include: "rules/patch.smk"


# Assigned after the includes, so this value replaces the one set in download.smk.
CONDAENV = "envs"  # overwrite definition in download.smk


# Rules listed here are executed locally by Snakemake instead of being
# submitted as cluster/cloud jobs.
# NOTE(review): the other aggregate rules of this file (genecatalog, binning,
# strains, quantify_genomes, screen) are not listed here — confirm whether
# that is intentional or handled in an included file.
localrules:
    all,
    qc,
    assembly_one_sample,
    assembly,
    genomes,

# Default target (first rule of the main Snakefile): requests the completion
# marker of every workflow stage plus genomes/annotations/gene2genome.parquet.
rule all:
    input:
        "finished_QC",
        "finished_assembly",
        "finished_binning",
        "finished_genomes",
        "finished_genecatalog",
        "genomes/annotations/gene2genome.parquet",


def get_gene_catalog_input():
    """Return the input files required by the gene catalog target.

    The result always starts with "Genecatalog/counts/median_coverage.h5",
    followed by one entry per key listed under ``gene_annotations`` in the
    config. The ``single_copy`` entry is itself a list (one file per domain),
    so the returned list may be nested.

    Returns:
        list: file paths (and possibly nested lists of file paths).

    Raises:
        IOError: if ``gene_annotations`` is not a list of known annotation
            names; the underlying lookup error is chained as the cause.
    """
    annotation_file_names = {
        "eggNOG": "Genecatalog/annotations/eggNOG.parquet",
        "dram": "Genecatalog/annotations/dram",
        "single_copy": expand(
            "Genecatalog/annotation/single_copy_genes_{domain}.tsv",
            domain=["bacteria", "archaea"],
        ),
    }

    # A missing or empty (null) config entry means "no extra annotations".
    annotations_requested = config.get("gene_annotations") or []

    try:
        requested_files = [annotation_file_names[key] for key in annotations_requested]
    except (KeyError, TypeError) as e:
        # KeyError: unknown annotation name; TypeError: malformed config value
        # (not iterable, or an unhashable item).
        raise IOError(
            "Error in gene_annotations requested, check config file 'gene_annotations' "
            "(valid values: {})".format(", ".join(annotation_file_names))
        ) from e

    return ["Genecatalog/counts/median_coverage.h5"] + requested_files


# Target rule for the gene catalog: requires the catalog sequences, the ORF
# info table and the files returned by get_gene_catalog_input() (evaluated
# when the Snakefile is parsed), then touches a temporary marker file.
rule genecatalog:
    input:
        "Genecatalog/gene_catalog.fna",
        "Genecatalog/gene_catalog.faa",
        "Genecatalog/clustering/orf_info.parquet",
        get_gene_catalog_input(),
    output:
        temp(touch("finished_genecatalog")),


# Target rule without outputs: requests strains/comparison.
rule strains:
    input:
        "strains/comparison",


def get_genome_annotations():
    """Return the genome-level files required by the genomes target.

    The result always starts with "genomes/genome_quality.tsv", followed by
    one file per key listed under ``annotations`` in the config.

    Returns:
        list[str]: file paths.

    Raises:
        KeyError: if the config has no ``annotations`` entry at all.
        IOError: if ``annotations`` is not a list of known annotation names;
            the underlying lookup error is chained as the cause.
    """
    annotation_file_names = {
        "gtdb_tree": "genomes/tree/finished_gtdb_trees",
        "gtdb_taxonomy": "genomes/taxonomy/gtdb_taxonomy.tsv",
        "genes": "genomes/annotations/genes/predicted",
        "kegg_modules": "genomes/annotations/dram/kegg_modules.tsv",
        "dram": "genomes/annotations/dram/distil",
    }

    # An empty (null) config entry means "no extra annotations".
    annotations_requested = config["annotations"] or []

    try:
        requested_files = [annotation_file_names[an] for an in annotations_requested]
    except (KeyError, TypeError) as e:
        # KeyError: unknown annotation name; TypeError: malformed config value
        # (not iterable, or an unhashable item).
        raise IOError(
            "Error in annotations requested, check config file 'annotations' "
            "(valid values: {})".format(", ".join(annotation_file_names))
        ) from e

    return ["genomes/genome_quality.tsv"] + requested_files


# Target rule for the genome set: requires the genome count/coverage tables,
# the bin-to-genome mapping, the mapping report, the files returned by
# get_genome_annotations() (evaluated when the Snakefile is parsed) and the
# binning marker, then touches a temporary marker file.
rule genomes:
    input:
        "genomes/counts/median_coverage_genomes.parquet",
        "genomes/counts/counts_genomes.parquet",
        "genomes/clustering/allbins2genome.tsv",
        "reports/genome_mapping/results.html",
        *get_genome_annotations(),
        "finished_binning",
    output:
        temp(touch("finished_genomes")),


# Target rule without outputs: requests only the genome coverage and count
# tables.
rule quantify_genomes:
    input:
        "genomes/counts/median_coverage_genomes.parquet",
        # plain path: the pattern has no wildcards, so expand() is not needed
        "genomes/counts/counts_genomes.parquet",


# Target rule for binning: requests the per-binner result tables and the HTML
# report for the binner selected by config["final_binner"], requires the
# assembly marker, then touches a temporary marker file.
rule binning:
    input:
        # one expand() over all patterns that depend on the final binner
        expand(
            [
                "Binning/{binner}/genome_similarities.parquet",
                "Binning/{binner}/bins2species.tsv",
                "Binning/{binner}/bin_info.tsv",
                "reports/bin_report_{binner}.html",
            ],
            binner=config["final_binner"],
        ),
        "finished_assembly",
    output:
        temp(touch("finished_binning")),


# Per-sample assembly target: requires whatever get_assembly resolves to
# (defined outside this file; presumably an input function from an included
# rules file — verify), the sample's alignment BAM and the contig statistics,
# then touches the per-sample marker file.
rule assembly_one_sample:
    input:
        get_assembly,
        "{sample}/sequence_alignment/{sample}.bam",
        "{sample}/assembly/contig_stats/postfilter_coverage_stats.txt",
        "{sample}/assembly/contig_stats/final_contig_stats.txt",
    output:
        touch("{sample}/finished_assembly"),


# Assembly target: requires the per-sample assembly marker of every entry in
# SAMPLES (defined outside this file), the assembly report and the QC marker,
# then touches the global assembly marker file.
rule assembly:
    input:
        expand("{sample}/finished_assembly", sample=SAMPLES),
        "reports/assembly_report.html",
        "finished_QC",
    output:
        touch("finished_assembly"),


# QC target: requires the per-sample QC marker of every entry in SAMPLES, the
# read statistics tables and the QC report, then touches the QC marker file.
# SAMPLES and PAIRED_END are defined outside this file.
rule qc:
    input:
        expand("{sample}/sequence_quality_control/finished_QC", sample=SAMPLES),
        read_counts="stats/read_counts.tsv",
        # insert-size statistics are only requested when PAIRED_END is truthy
        read_length_stats=(
            ["stats/insert_stats.tsv", "stats/read_length_stats.tsv"]
            if PAIRED_END
            else "stats/read_length_stats.tsv"
        ),
        report="reports/QC_report.html",
    output:
        touch("finished_QC"),


# Target rule without outputs: requests QC/screen/sketch_comparison.tsv.gz.
rule screen:
    input:
        "QC/screen/sketch_comparison.tsv.gz",


# overwrite commands in rules/download.smk
# Handler run by Snakemake after the workflow finished without errors.
onsuccess:
    print("ATLAS finished")
    print("The last rule shows you the main output files")


# Handler run by Snakemake when the workflow failed: points the user to the
# log file, the documentation and the issue tracker.
onerror:
    print("Note the path to the log file for debugging.")
    print("Documentation is available at: https://metagenome-atlas.readthedocs.io")
    print("Issues can be raised at: https://github.com/metagenome-atlas/atlas/issues")


### populate resources for rules that don't define them

for r in workflow.rules:
    # memory: fall back to the configured default, config["mem"] * 1000
    if "mem_mb" not in r.resources:
        r.resources["mem_mb"] = config["mem"] * 1000

    # add time if not present: config["runtime"]["default"] * 60
    if "time_min" not in r.resources:
        r.resources["time_min"] = config["runtime"]["default"] * 60

    # mirror time_min into "runtime" for rules that do not set it themselves
    if "runtime" not in r.resources:
        r.resources["runtime"] = r.resources["time_min"]
