import torch
import pytorch_lightning as pl
from mydpr.model.biencoder import MyEncoder
from mydpr.dataset.cath35 import PdDataModule
from absl import logging
import sys
import os
import argparse
import faiss
import time
import math
import pandas as pd
import numpy as np
import phylopandas.phylopandas as ph
import pathlib

# Directory that contains this script. Bundled resources are resolved against
# it, so their paths are absolute and do not depend on the working directory.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# Encoder checkpoint shipped next to the script.
# (Previously the relative path "cpu_model/fastmsa-cpu.ckpt"; modified by
# Sheng Wang on 2022.06.14.)
ckpt_path = os.path.join(_SCRIPT_DIR, "cpu_model/fastmsa-cpu.ckpt")
# QJackHMMER executable shipped next to the script.
qjackhmmer = os.path.join(_SCRIPT_DIR, "bin/qjackhmmer")

# Default for -i/--input_path: FASTA file with the query sequences.
input_path = "./input_test.fasta"
# Default for -o/--output_a3m_path.
out_path = "./testout/"
# Number of query embeddings sent to the index per search call.
search_batch = 10
# Default for -n/--num: how many neighbours to retrieve per query.
tar_num = 400000
# Default for -r/--iters.
iter_num = 1


def gen_query(fasta_file_path, out_dir):
    """Write every record of a FASTA file to its own file under ``out_dir/seq``.

    The input is loaded with ``phylopandas.read_fasta`` and each row is saved
    through its ``phylo`` accessor as ``<out_dir>/seq/<row id>.fasta``.  The
    ``seq`` directory is created when it does not exist yet.

    Args:
        fasta_file_path: Path of the FASTA file to split.
        out_dir: Directory that receives the ``seq`` sub-directory.
    """
    seq_dir = os.path.join(out_dir, "seq")
    os.makedirs(seq_dir, exist_ok=True)

    records = ph.read_fasta(fasta_file_path, use_uids=False)
    for row in range(len(records)):
        record = records.iloc[row]
        # The record id doubles as the output file name.
        target = os.path.join(seq_dir, record.id + '.fasta')
        record.phylo.to_fasta(target, id_col='id')


def my_align(out_dir, iter_num):
    """Run QJackHMMER once for every query FASTA file in ``<out_dir>/seq``.

    For each ``<qid>.fasta`` found in ``<out_dir>/seq`` the executable named by
    the module-level ``qjackhmmer`` path is invoked with ``-N iter_num`` and
    ``-B <out_dir>/res/<qid>.a3m``.  The tool's standard output is discarded
    and its exit status is not checked.

    The previous implementation only built a command string, with three
    conversion specifiers for five arguments, so it raised ``TypeError`` on the
    first file and never ran anything.

    Args:
        out_dir: Working directory containing the ``seq`` sub-directory; a
            ``res`` sub-directory is created in it if missing.
        iter_num: Iteration count passed to QJackHMMER via ``-N``.
    """
    import subprocess  # local import: only this function needs it

    seq_dir = os.path.join(out_dir, "seq")
    res_dir = os.path.join(out_dir, "res")
    os.makedirs(res_dir, exist_ok=True)

    for fp in os.listdir(seq_dir):
        # Drop only the final extension so that ids containing dots
        # (e.g. "1abc.A.fasta") are not truncated.
        qid = os.path.splitext(fp)[0]
        cmd = [
            qjackhmmer,
            "-B", os.path.join(res_dir, "%s.a3m" % qid),
            "--noali",
            "--incE", "1e-3",
            "-E", "1e-3",
            "--cpu", "32",
            "-N", "%d" % iter_num,
            "--F1", "0.0005",
            "--F2", "5e-05",
            "--F3", "5e-07",
            # Positional arguments: query file, then target database file.
            os.path.join(seq_dir, "%s.fasta" % qid),
            # NOTE(review): the original passed "res/<qid>.fasta" here, while
            # retrieve_result() writes its hits to "db/<name>.fasta" --
            # confirm which directory is intended.
            os.path.join(res_dir, "%s.fasta" % qid),
        ]
        # An argument list (no shell) keeps paths with spaces intact; stdout
        # is dropped, matching the original "> /dev/null".
        subprocess.run(cmd, stdout=subprocess.DEVNULL)


def retrieve_result(input_path, out_path, tar_num, iter_num, idx_path, dm_path):
    """Retrieve nearest-neighbour database entries for every query sequence.

    Steps, each timed and logged:

    1. Split the queries into ``<out_path>/seq`` with ``gen_query``.
    2. Load the FAISS index ``idx_path`` and the pickled DataFrame ``dm_path``.
    3. Encode the queries with the ``MyEncoder`` checkpoint at the module-level
       ``ckpt_path`` through ``pl.Trainer.predict``.
    4. Search the index in batches of ``search_batch`` queries for ``tar_num``
       neighbours each, and write the DataFrame rows selected by the returned
       positions to ``<out_path>/db/<query name>.fasta``.

    Args:
        input_path: FASTA file with the query sequences.
        out_path: Working directory; ``seq`` and ``db`` are created inside it.
        tar_num: Number of neighbours requested per query.
        iter_num: Unused here; kept so existing callers keep working.
        idx_path: Path handed to ``faiss.read_index``.
        dm_path: Path of the pickled DataFrame that the search results index
            into by position.
    """
    s0 = time.time()

    logging.info("Start mkdir!!!")
    gen_query(input_path, out_path)
    s1 = time.time()
    logging.info("Mkdir output cost: %f s" % (s1 - s0))

    index = faiss.read_index(idx_path)
    s2 = time.time()
    logging.info("Load index cost: %f s" % (s2 - s1))
    # NOTE: unpickling can execute arbitrary code -- only point dm_path at
    # database files from a trusted source.
    df = pd.read_pickle(dm_path)

    model = MyEncoder.load_from_checkpoint(checkpoint_path=ckpt_path)
    # NOTE(review): 40 is presumably the batch size -- confirm against
    # PdDataModule.
    ds = PdDataModule(input_path, 40, model.alphabet)

    s3 = time.time()
    logging.info("Load ckp cost: %f s" % (s3 - s2))
    trainer = pl.Trainer(gpus=1)  # gpus=[0])
    ret = trainer.predict(model, datamodule=ds, return_predictions=True)
    # NOTE(review): this rewrites the checkpoint file that was just loaded, on
    # every call. Kept for behavioural compatibility -- confirm it is wanted
    # in an inference-only script.
    trainer.save_checkpoint(ckpt_path)
    s4 = time.time()
    logging.info("Train predict cost: %f s" % (s4 - s3))

    # Each prediction batch is a (names, embeddings) pair; flatten both.
    names = []
    embeddings = []
    for batch_names, batch_ebd in ret:
        names += batch_names
        # as_tensor avoids re-copying an existing tensor; detach()/cpu() make
        # the numpy conversion valid for any device or autograd state.
        embeddings.append(torch.as_tensor(batch_ebd).detach().cpu().numpy())
    encoded = np.concatenate(embeddings, axis=0)

    logging.info("prepared model")
    s5 = time.time()
    logging.info("Encode model cost: %f s" % (s5 - s4))

    db_dir = os.path.join(out_path, "db")
    os.makedirs(db_dir, exist_ok=True)
    for start in range(0, encoded.shape[0], search_batch):
        _scores, idxes = index.search(encoded[start:start + search_batch], tar_num)
        for offset, tar_idx in enumerate(idxes):
            # FAISS pads with label -1 when it finds fewer than tar_num
            # neighbours; iloc would read -1 as "last row", so drop them.
            tar_idx = tar_idx[tar_idx >= 0]
            res = df.iloc[tar_idx]
            res.phylo.to_fasta_dev(os.path.join(db_dir, names[start + offset] + '.fasta'))

    end = time.time()
    logging.info("Time for predict %d : %f s" % (tar_num, end - s5))
    logging.info("============================================")


def filter_and_save(input_file, output_file):
    """Copy a FASTA/A3M-style file, dropping lower-case characters from sequences.

    Lines starting with ``>`` are copied unchanged.  All other lines belonging
    to one record are stripped of lower-case characters and joined into a
    single output line.  A record whose filtered sequence is empty produces no
    sequence line at all.

    Args:
        input_file: Path of the file to read.
        output_file: Path of the file to write (overwritten if it exists).
    """
    with open(input_file, 'r') as src, open(output_file, 'w') as dst:
        pieces = []  # filtered fragments of the record currently being read

        for raw in src:
            if not raw.startswith('>'):
                # Sequence line: keep everything that is not lower-case.
                kept = ''.join(ch for ch in raw if not ch.islower())
                pieces.append(kept.rstrip('\n'))
                continue

            # Header line: flush the previous record first, then copy it.
            sequence = ''.join(pieces)
            if sequence:
                dst.write(sequence + '\n')
            dst.write(raw)
            pieces = []

        # Flush the last record of the file.
        sequence = ''.join(pieces)
        if sequence:
            dst.write(sequence + '\n')

def main():
    """Command-line entry point: retrieve homologs and build the final MSA.

    1. Run ``retrieve_result`` on every database chunk that has no result
       file for the query yet.
    2. Concatenate the per-chunk result files into ``tmp/homologs.fasta``.
    3. Run QJackHMMER on the query against that file, unless
       ``tmp/homologs_temp.a3m`` already exists.
    4. Strip lower-case characters with ``filter_and_save`` and write the
       result to the requested output path.

    The ``tmp`` directory lives next to the output file.

    Raises:
        ValueError: If the first line of the input file is empty.
        subprocess.CalledProcessError: If QJackHMMER exits with an error.
    """
    import shutil      # local imports: only the script entry point needs them
    import subprocess

    parser = argparse.ArgumentParser(description='fastMSA do homolog retrieval.')
    parser.add_argument("-i", "--input_path", default=input_path,
                        help="path of the fasta file containing query sequences")
    parser.add_argument("-d", "--database_path", default="./output/agg/",
                        help="path of dir containing database embedding and db converted to DataFrame")
    parser.add_argument("-o", "--output_a3m_path", default=out_path, help="path to output msas")
    parser.add_argument("-n", "--num", default=tar_num, type=int, help="retrieve num")
    parser.add_argument("-r", "--iters", default=iter_num, type=int, help="num of iters by QJackHMMER")
    args = parser.parse_args()

    s_init = time.time()

    output_path = pathlib.Path(args.output_a3m_path).resolve().parent

    temp_output_path = os.path.join(output_path, 'tmp')
    homologs_fasta_path = os.path.join(temp_output_path, "homologs.fasta")
    homologs_a3m_temp_path = os.path.join(temp_output_path, "homologs_temp.a3m")

    # The per-chunk result files are named after the first header line of the
    # query file (without the leading '>'). Read it once, up front, so the
    # "already done" check and the concatenation look at the same file.
    with open(args.input_path) as query_file:
        targetname = query_file.readline().rstrip('\n')[1:]
    if not targetname:
        raise ValueError("no FASTA header found on the first line of %s" % args.input_path)

    # Run retrieval on each chunk sequentially to generate homologs.
    chunk_list = ["UR_A", "UR_B", "UR_C", "UR_D", "UR_E", "UR_F", "UR_G"]
    chunk_results = []

    for chunk in chunk_list:
        chunk_out_path = os.path.join(temp_output_path, chunk)
        idx_path = os.path.join(args.database_path, chunk, "index-ebd.index")
        dm_path = os.path.join(args.database_path, chunk, "df-ebd.pkl")
        chunk_result = os.path.join(chunk_out_path, "db", targetname + '.fasta')
        chunk_results.append(chunk_result)
        print(chunk_result)
        if os.path.exists(chunk_result):
            # Result left by an earlier run: skip the expensive retrieval.
            continue
        retrieve_result(args.input_path, chunk_out_path, args.num, args.iters, idx_path, dm_path)

    # Concatenate the chunks of results into one fasta file, streaming each
    # file instead of loading it into memory.
    with open(homologs_fasta_path, 'w') as out_file:
        for chunk_result in chunk_results:
            with open(chunk_result, 'r') as in_file:
                shutil.copyfileobj(in_file, out_file)

    # Run qjackhmmer to generate the MSA from the retrieved homologs.
    # NOTE(review): the iteration count is hard-coded to 3 here, so -r/--iters
    # has no effect on this step -- confirm whether args.iters should be used.
    if not os.path.exists(homologs_a3m_temp_path):
        subprocess.run(
            [
                qjackhmmer,
                "-B", homologs_a3m_temp_path,
                "--noali",
                "--incE", "1e-3",
                "-E", "1e-3",
                "--cpu", "32",
                "-N", "3",
                "--F1", "0.0005",
                "--F2", "5e-05",
                "--F3", "5e-07",
                args.input_path,
                homologs_fasta_path,
            ],
            # Stop here on failure instead of failing later on a missing file.
            check=True,
        )

    # Filter the a3m to remove lower-case letters from the sequences.
    filter_and_save(homologs_a3m_temp_path, args.output_a3m_path)
    s_final = time.time()
    logging.info("Total Cost: %f s" % (s_final - s_init))


if __name__ == "__main__":
    main()
