#!/usr/bin/env python

import sys, argparse
import re
def parse_g4_coords(g4, g4_coords, num_feat_path, feat_numbers, thr):
    """Read the G4 coordinate file and keep the records that reach the threshold.

    The file is processed line by line:
      * blank lines and lines starting with '#' are skipped;
      * a line starting with '>' opens a new section. The section name is the
        text between '>' and the first '.', and the feature counts for that
        name are loaded from ``num_feat_path + '/' + name`` through
        parse_feat_numbers();
      * every other line is split on tabs and kept only when its third
        column, read as an integer, is not lower than ``thr``.

    Args:
        g4: path of the G4 coordinate file.
        g4_coords: dict updated in place; maps a section name to the list of
            5-tuples built from the first five columns (kept as strings) of
            each retained line.
        num_feat_path: directory holding one feature-number file per section.
        feat_numbers: dict handed to, and returned by, parse_feat_numbers().
        thr: minimum accepted integer value of the third column.

    Returns:
        The tuple (g4_coords, feat_numbers).

    Raises:
        ValueError: if a retained data line appears before any '>' header, so
            that there is no section to attach it to.
    """
    name = None
    # 'with' guarantees the handle is released even if a line fails to parse.
    with open(g4) as fh:
        for line in fh:
            line = line.rstrip()
            if not line or line.startswith('#'):
                continue
            # end if
            if line.startswith('>'):
                name = line.split('.')[0][1:]
                feat_numbers = parse_feat_numbers(num_feat_path + '/' + name, feat_numbers, name)
                continue
            # end if
            columns = line.split('\t')
            if int(columns[2]) < thr:
                continue
            # end if
            if name is None:
                raise ValueError("G4 record found before any '>' header line in %s" % g4)
            # end if
            # Explicit indexing (rather than slicing) keeps the IndexError on
            # lines with fewer than five columns.
            record = (columns[0], columns[1], columns[2], columns[3], columns[4])
            g4_coords.setdefault(name, []).append(record)
        # end for
    # end with
    return(g4_coords, feat_numbers)

# end def

def parse_feat_numbers(feat_num_file, feat_numbers, name):
    """Load the per-feature counts stored in ``feat_num_file``.

    Blank lines and lines starting with '>' are ignored. Every other line is
    split on tabs: the first column is used as the feature key and the second
    column is converted to an integer count.

    Args:
        feat_num_file: path of the feature-number file to read.
        feat_numbers: dict updated in place; ``feat_numbers[name]`` is created
            on the first data line and maps feature key -> integer count.
        name: key under which the counts of this file are stored.

    Returns:
        The updated ``feat_numbers`` dict.

    Raises:
        SystemExit: with status 1, after writing a message to standard error,
            when the file cannot be opened.
        IndexError, ValueError: when a data line has fewer than two columns or
            a non-integer count. These are no longer reported as a missing file.
    """
    # Only the open() call is guarded, so parsing problems are not mistaken
    # for a missing file.
    try:
        fh1 = open(feat_num_file)
    except IOError:
        sys.stderr.write("Unable to find %s file..." % feat_num_file)
        sys.exit(1)
    # end try
    with fh1:
        for line in fh1:
            line = line.rstrip()
            if not line or line.startswith('>'):
                continue
            # end if
            fields = line.split('\t')
            feat_numbers.setdefault(name, {})[fields[0]] = int(fields[1])
        # end for
    # end with
    return(feat_numbers)
# end def

def parse_feature_coords(dir, file, feature_coords, lengths):
    """Read the feature coordinates stored in ``dir + '/' + file``.

    A line starting with '>' provides the accession (the rest of the line).
    Every other non-blank line is split on tabs into: first coordinate, second
    coordinate, feature name and an optional description. A leading '>' or '<'
    on either coordinate is dropped before the integer conversion.

    For each feature the tuple ``(low, high, name, description, strand)`` is
    appended to ``feature_coords``. The strand is 'plus' when the first
    coordinate is the smaller one and 'minus' otherwise. Features whose two
    coordinates are equal are skipped. ``high - low`` is added to
    ``lengths[file][name]``, or to ``lengths[file]['other']`` when the name is
    not a key of ``lengths[file]``.

    The parameter names ``dir`` and ``file`` shadow builtins; they are kept
    so that the signature stays unchanged for existing callers.

    Args:
        dir: directory containing the feature files.
        file: file name inside ``dir``; also the key used in ``lengths``.
        feature_coords: list extended in place with the feature tuples.
        lengths: dict updated in place; ``lengths[file]`` must already exist
            and contain an 'other' key.

    Returns:
        The tuple (feature_coords, accession, lengths).

    Raises:
        SystemExit: with status 1, after writing a message to standard error,
            when the file cannot be opened.
        ValueError: when the file has no '>' header line, or a coordinate is
            not an integer.
        IndexError: when a feature line has fewer than three columns.
    """
    path = dir + '/' + file
    # Only the open() call is guarded, so malformed content is not reported
    # as a missing file.
    try:
        fh = open(path)
    except IOError:
        sys.stderr.write("Unable to find %s file..." % file)
        sys.exit(1)
    # end try

    accession = None
    with fh:
        for line in fh:
            if line.startswith('>'):
                accession = line[1:].rstrip()
                continue
            # end if
            stripped = line.rstrip()
            if not stripped:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            # end if
            fields = stripped.split('\t')
            for idx in (0, 1):
                # Partial-feature markers such as '<10' or '>250'.
                if fields[idx].startswith((">", "<")):
                    fields[idx] = fields[idx][1:]
                # end if
            # end for
            first = int(fields[0])
            second = int(fields[1])
            if first == second:
                continue
            # end if
            strand = 'plus' if first < second else 'minus'
            low = min(first, second)
            high = max(first, second)
            feat_name = fields[2]
            description = fields[3] if len(fields) > 3 else ''
            feature_coords.append((low, high, feat_name, description, strand))

            bucket = feat_name if feat_name in lengths[file] else 'other'
            lengths[file][bucket] += high - low
        # end for
    # end with

    if accession is None:
        raise ValueError("No '>' header line found in %s" % path)
    # end if
    return(feature_coords, accession, lengths)
# end def

def intersection(g4_list, feature_coords, o, counters, default_counters, vir):
    """Report which features overlap each G4 and count hits per category.

    For every G4 record a line with its five columns is written to ``o``. The
    first overlapping feature is appended to that line; each further
    overlapping feature is written on a new line that repeats the G4 columns.
    A G4 without overlapping features is written on its own.

    A G4 and a feature overlap when their closed intervals share at least one
    position. This includes a feature lying entirely inside the G4, which an
    endpoint-only check would miss.

    Each overlapping feature is assigned a category: its own name when that
    name is a key of ``default_counters``, 'not_annotated' when the name is
    empty, 'other' otherwise. ``counters[vir][category]`` is increased by one
    per G4 for every distinct category hit, however many features of that
    category overlap it.

    Args:
        g4_list: iterable of 5-item records; items 0 and 1 are the G4
            coordinates (anything accepted by int()).
        feature_coords: iterable of (start, end, name, description, strand).
        o: object with a write() method receiving the report text.
        counters: dict updated in place; ``counters[vir]`` maps category ->
            count and must contain every category that can be produced.
        default_counters: container of the recognised category names.
        vir: key of ``counters`` to update.

    Returns:
        The updated ``counters`` dict.
    """
    for quadruplex in g4_list:
        g4_columns = "%s\t%s\t%s\t%s\t%s" % (quadruplex[0], quadruplex[1], quadruplex[2], quadruplex[3], quadruplex[4])
        o.write("\n" + g4_columns)

        # Convert once per G4, and normalise the bounds so the overlap test
        # does not depend on the order in which the coordinates are given.
        g4_a = int(quadruplex[0])
        g4_b = int(quadruplex[1])
        g4_low = min(g4_a, g4_b)
        g4_high = max(g4_a, g4_b)

        hits = 0
        overlap_categories = set()
        for feat in feature_coords:
            feat_a = int(feat[0])
            feat_b = int(feat[1])
            if g4_low > max(feat_a, feat_b) or min(feat_a, feat_b) > g4_high:
                continue
            # end if

            feat_columns = "%s\t%s\t%s\t%s\t%s" % (feat[2], feat[0], feat[1], feat[3], feat[4])
            if hits == 0:
                o.write("\t" + feat_columns)
            else:
                o.write("\n" + g4_columns + "\t" + feat_columns)
            # end if
            hits += 1

            if feat[2] in default_counters:
                overlap_categories.add(feat[2])
            elif not feat[2]:
                overlap_categories.add('not_annotated')
            else:
                overlap_categories.add('other')
            # end if
        # end for

        for categ in overlap_categories:
            counters[vir][categ] += 1
        # end for
    # end for
    return(counters)
# end def


def _g4_density(hits, length):
    """Return ``hits`` per 1000 units of ``length``, or 0 when ``length`` is 0."""
    try:
        return float(hits) / length * 1000
    except ZeroDivisionError:
        return 0
    # end try
# end def

def main(args): # use as args['argument_name']
    """Run the G4/feature intersection and write the report and the statistics.

    Steps:
      1. parse the G4 coordinates (and the per-virus feature numbers);
      2. for every virus, parse its feature coordinates and write the
         G4/feature overlaps to ``args['out']``;
      3. write one statistics row per virus to ``args['output_stats']``: for
         each category the number of G4 hits, the number of features and the
         summed feature length, then the combined 5' and 3' figures, then the
         hit densities (hits per 1000 length units, 0 when the length is 0).

    Args:
        args: dict with the keys 'g4', 'number_features', 'features', 'out',
            'output_stats' and 'integer'.

    Raises:
        SystemExit: with status 1 when a virus has an empty coordinate list.
        KeyError: when the feature numbers of a virus lack one of the
            categories written to the statistics file.
    """
    # Categories in the column order of the statistics file.
    categories = ('repeat_region', 'CDS', '5-promoter', '3-promoter', '5-utr', '3-utr', '5-utr_or_untranscribed', '3-utr_or_untranscribed', 'ncRNA', 'other')
    # Categories merged into the combined 5' and 3' columns.
    five_prime = ('5-promoter', '5-utr', '5-utr_or_untranscribed')
    three_prime = ('3-promoter', '3-utr', '3-utr_or_untranscribed')

    default_counters = dict.fromkeys(categories + ('not_annotated',), 0)
    default_lengths = dict.fromkeys(categories, 0)

    counters = {}
    lengths = {}
    (g4_coords, feat_numbers) = parse_g4_coords(args['g4'], {}, args['number_features'], {}, args['integer'])

    # 'with' closes the report even when a parser aborts the run.
    with open(args['out'], 'w') as o:
        o.write("#Start\tEnd\t%conserved\tN seq with G4\tTotal sequences\tFeature name\tFeature start\tFeature end\tFeature description\tFeature strand")
        for virus in g4_coords:
            if not g4_coords[virus]:
                sys.stderr.write("No coords for virus %s" % virus)
                sys.exit(1)
            # end if
            counters[virus] = default_counters.copy()
            lengths[virus] = default_lengths.copy()
            (feature_coords, accession, lengths) = parse_feature_coords(args['features'], virus, [], lengths)
            o.write("\n\n>%s\t%s" % (virus, accession))
            counters = intersection(g4_coords[virus], feature_coords, o, counters, default_counters, virus)
        # end for
    # end with

    with open(args['output_stats'], 'w') as stats:
        stats.write("Virus\tG4_on_repeat_region\tTotal_repeat_region\tTotal_repeat_len\tG4_on_CDS\tTotal_CDS\tTotal_CDS_len\tG4_on_5-promoter\tTotal_5-promoter\tTotal_5-promoter_len\tG4_on_3-promoter\tTotal_3-promoter\tTotal_3-promoter-len\tG4_on_5-utr\tTotal_5-utr\tTotal_5-utr_len\tG4_on_3-utr\tTotal_3-utr\tTotal_3-utr_len\tG4_on_5-utr_or_untranscribed\tTotal_5-utr_or_untranscribed\tTotal_5-utr_or_untranscribed_len\tG4_on_3-utr_or_untranscribed\tTotal_3-utr_or_untranscribed\tTotal_3-utr_or_untranscribed_len\tG4_on_ncRNA\tTotal_ncRNA\tTotal_ncRNA_len\tG4_on_other\tTotal_other\tTotal_other_len\t#5-utr_or_untranscribed\t#3-utr_or_untranscribed\t#5-utr_or_untranscribed_TOT\t#3-utr_or_untranscribed_TOT\t5-utr_or_untranscribed_len\t3-utr_or_untranscribed_len\tCDS_NORM\tREPEAT_NORM\t5-UTR_OR_UNTRANSCRIBED_NORM\t3-UTR_OR_UNTRANSCRIBED_NORM\tncRNA_NORM\tOTHER_NORM\n")

        for virus in counters:
            hits = counters[virus]
            totals = feat_numbers[virus]
            spans = lengths[virus]

            five_hits = sum(hits[c] for c in five_prime)
            three_hits = sum(hits[c] for c in three_prime)
            five_len = sum(spans[c] for c in five_prime)
            three_len = sum(spans[c] for c in three_prime)

            # Integer columns: three per category, then the combined 5'/3' figures.
            int_columns = []
            for categ in categories:
                int_columns.extend((hits[categ], totals[categ], spans[categ]))
            # end for
            int_columns.extend((
                five_hits,
                three_hits,
                sum(totals[c] for c in five_prime),
                sum(totals[c] for c in three_prime),
                five_len,
                three_len,
            ))

            densities = (
                _g4_density(hits['CDS'], spans['CDS']),
                _g4_density(hits['repeat_region'], spans['repeat_region']),
                _g4_density(five_hits, five_len),
                _g4_density(three_hits, three_len),
                _g4_density(hits['ncRNA'], spans['ncRNA']),
                _g4_density(hits['other'], spans['other']),
            )

            row = ["%s" % virus]
            row.extend("%d" % value for value in int_columns)
            row.extend("%f" % value for value in densities)
            stats.write("\t".join(row) + "\n")
        # end for
    # end with
# end def main


if __name__ == '__main__':

    # Command-line interface: five positional paths plus the mandatory
    # conservation threshold.
    cli = argparse.ArgumentParser(description='Description of your program')

    # Positional arguments in command-line order: (name, metavar, help).
    positionals = (
        ('g4', 'INPUT_FILE', 'g4 coords referred to the reference sequence'),
        ('number_features', 'INPUT_DIR', 'Feature_numbers directory'),
        ('features', 'INPUT_DIR', 'features coord directory, extract_all_coordinates_from_feat_table.pl output'),
        ('out', 'OUTPUT_FILE', 'output file'),
        ('output_stats', 'OUTPUT_FILE_STATS', 'output file of statistics'),
    )
    for arg_name, arg_metavar, arg_help in positionals:
        cli.add_argument(arg_name, metavar=arg_metavar, help=arg_help)
    # end for

    cli.add_argument(
        '-i', '--integer',
        type=int,
        required=True,
        metavar='INT_VALUE',
        help='minimum percentage of required g4 conservation',
    )

    # main() expects a plain dict keyed by argument name.
    main(vars(cli.parse_args()))

# end if
