#!/usr/bin/env python
#
# Copyright (C) 2011 W. Trevor King <wking@drexel.edu>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this program.  If not, see
# <http://www.gnu.org/licenses/>.

# The twenty standard amino acids, keyed by full name.  Each value is a
# list holding the three-letter abbreviation followed by the one-letter
# code; the command-line code picks one with [0] or [-1].
# NOTE: 'Argenine' and 'Cystine' are misspellings (the usual names are
# Arginine and Cysteine).  They are kept as-is because the same strings
# are used as keys in CODE and appear in program output.
AMINO_ACIDS = {
    full_name: [three_letter, one_letter]
    for full_name, three_letter, one_letter in (
        ('Alanine',        'Ala', 'A'),
        ('Argenine',       'Arg', 'R'),
        ('Asparagine',     'Asn', 'N'),
        ('Aspartic acid',  'Asp', 'D'),
        ('Cystine',        'Cys', 'C'),
        ('Glutamic acid',  'Glu', 'E'),
        ('Glutamine',      'Gln', 'Q'),
        ('Glycine',        'Gly', 'G'),
        ('Histidine',      'His', 'H'),
        ('Isoleucine',     'Ile', 'I'),
        ('Leucine',        'Leu', 'L'),
        ('Lysine',         'Lys', 'K'),
        ('Methionine',     'Met', 'M'),
        ('Phenylalanine',  'Phe', 'F'),
        ('Proline',        'Pro', 'P'),
        ('Serine',         'Ser', 'S'),
        ('Threonine',      'Thr', 'T'),
        ('Tryptophan',     'Trp', 'W'),
        ('Tyrosine',       'Tyr', 'Y'),
        ('Valine',         'Val', 'V'),
        )
    }

# Nucleotide bases keyed by full name.  Each value is a list of
# abbreviations -- here just the one-letter code, which is the name's
# initial.  The command-line code prints the last entry ([-1]).
NUCLEOTIDES = {
    base: [base[0]]
    for base in ('Adenine', 'Cytosine', 'Guanine', 'Thymine', 'Uracil')
    }

# Watson-Crick base pairing: each nucleotide maps to the base it pairs
# with.  Adenine pairs with Thymine (Uracil in RNA) and Cytosine pairs
# with Guanine.  Adenine maps to the DNA partner (Thymine) because the
# complement is printed as DNA.
NUCLEOTIDE_COMPLEMENT = {
    'Adenine': 'Thymine',
    'Cytosine': 'Guanine',
    'Guanine': 'Cytosine',
    'Thymine': 'Adenine',
    'Uracil': 'Adenine',
    }

# The standard genetic code for mRNA.  Keys are codons written as
# 5'->3' tuples of three nucleotide names; values are the full amino
# acid names (spelled as in AMINO_ACIDS), or a 'STOP (...)' label for
# the three stop codons.  The table is listed one row per
# (first, second) base pair; each row gives the result for a third base
# of Adenine, Cytosine, Guanine and Uracil, in that order.
CODE = dict(
    ((first, second, third), amino_acid)
    for (first, second), by_third in (
        (('Adenine', 'Adenine'),
         ('Lysine', 'Asparagine', 'Lysine', 'Asparagine')),
        (('Adenine', 'Cytosine'), ('Threonine',) * 4),
        (('Adenine', 'Guanine'),
         ('Argenine', 'Serine', 'Argenine', 'Serine')),
        (('Adenine', 'Uracil'),
         ('Isoleucine', 'Isoleucine', 'Methionine', 'Isoleucine')),
        (('Cytosine', 'Adenine'),
         ('Glutamine', 'Histidine', 'Glutamine', 'Histidine')),
        (('Cytosine', 'Cytosine'), ('Proline',) * 4),
        (('Cytosine', 'Guanine'), ('Argenine',) * 4),
        (('Cytosine', 'Uracil'), ('Leucine',) * 4),
        (('Guanine', 'Adenine'),
         ('Glutamic acid', 'Aspartic acid',
          'Glutamic acid', 'Aspartic acid')),
        (('Guanine', 'Cytosine'), ('Alanine',) * 4),
        (('Guanine', 'Guanine'), ('Glycine',) * 4),
        (('Guanine', 'Uracil'), ('Valine',) * 4),
        (('Uracil', 'Adenine'),
         ('STOP (Ochre)', 'Tyrosine', 'STOP (Amber)', 'Tyrosine')),
        (('Uracil', 'Cytosine'), ('Serine',) * 4),
        (('Uracil', 'Guanine'),
         ('STOP (Opal)', 'Cystine', 'Tryptophan', 'Cystine')),
        (('Uracil', 'Uracil'),
         ('Leucine', 'Phenylalanine', 'Leucine', 'Phenylalanine')),
        )
    for third, amino_acid in zip(
        ('Adenine', 'Cytosine', 'Guanine', 'Uracil'), by_third)
    )


# Cache of inverted abbreviation tables used by unabbreviate(), keyed by
# id() of the forward table.  Each value is a (table, inverse) pair.
# Holding a reference to the table keeps it alive, so its id() cannot be
# recycled for a different dict while the entry exists; unabbreviate()
# also checks identity before trusting an entry.
_INVERSE_DICTS = {}

def unabbreviate(abbreviations, abbreviation):
    """
    Return the full name that `abbreviation` stands for.

    `abbreviations` maps full names to lists of abbreviations (like
    AMINO_ACIDS or NUCLEOTIDES).  The inverse mapping is built the first
    time a given table is used and then cached.  Changes made to the
    table afterwards are not picked up.  Raises KeyError when
    `abbreviation` is not listed in the table.

    >>> unabbreviate(AMINO_ACIDS, 'Ala')
    'Alanine'
    >>> unabbreviate(AMINO_ACIDS, 'A')
    'Alanine'
    >>> unabbreviate(NUCLEOTIDES, 'A')
    'Adenine'
    """
    key = id(abbreviations)
    cached = _INVERSE_DICTS.get(key)
    # A bare id() can be reused once its object is freed, so only trust an
    # entry that was built from this very table.
    if cached is None or cached[0] is not abbreviations:
        inverse = {}
        for name, abbrevs in abbreviations.items():
            for abbrev in abbrevs:
                inverse[abbrev] = name
        cached = (abbreviations, inverse)
        _INVERSE_DICTS[key] = cached
    return cached[1][abbreviation]

def decode_sequence(abbreviations, sequence):
    """
    Yield the full name for each abbreviated character in `sequence`.

    Whitespace is skipped.  Every other character is upper-cased and
    looked up with unabbreviate(), so an unknown character raises
    KeyError when the generator reaches it.

    >>> list(decode_sequence(NUCLEOTIDES, 'ACG TU'))
    ['Adenine', 'Cytosine', 'Guanine', 'Thymine', 'Uracil']
    """
    for char in sequence:
        if not char.isspace():
            yield unabbreviate(abbreviations, char.upper())

def transcribe_to_mRNA(nucleotides):
    """
    Yield `nucleotides` with every 'Thymine' replaced by 'Uracil'.

    All other names pass through unchanged, so input that is already
    mRNA comes out as-is.

    >>> list(transcribe_to_mRNA(['Adenine', 'Cytosine', 'Guanine', 'Thymine']))
    ['Adenine', 'Cytosine', 'Guanine', 'Uracil']
    """
    for nucleotide in nucleotides:
        yield 'Uracil' if nucleotide == 'Thymine' else nucleotide

def split_into_codons(nucleotides, length=3):
    """
    Group `nucleotides` into consecutive tuples of `length` items.

    Whatever is left over at the end is yielded as a final, shorter
    tuple, so the last tuple may have fewer than `length` items.  An
    empty tuple is never yielded.

    >>> sequence = 'AGC TTC ATG CGT CCG AAG CC'
    >>> nucleotides = decode_sequence(NUCLEOTIDES, sequence)
    >>> codons = split_into_codons(nucleotides)
    >>> print('\\n'.join(str(c) for c in codons))
    ('Adenine', 'Guanine', 'Cytosine')
    ('Thymine', 'Thymine', 'Cytosine')
    ('Adenine', 'Thymine', 'Guanine')
    ('Cytosine', 'Guanine', 'Thymine')
    ('Cytosine', 'Cytosine', 'Guanine')
    ('Adenine', 'Adenine', 'Guanine')
    ('Cytosine', 'Cytosine')
    """
    pending = []
    for nucleotide in nucleotides:
        pending.append(nucleotide)
        if len(pending) != length:
            continue
        yield tuple(pending)
        pending = []
    # Flush the incomplete remainder, if any.
    if pending:
        yield tuple(pending)

def translate_to_amino_acids(nucleotides):
    """
    Yield the CODE entry for each successive codon in `nucleotides`.

    The input sequence should be mRNA nucleotides read from 5' to 3'.
    CODE only has Uracil-based codons, so DNA containing Thymine must go
    through transcribe_to_mRNA() first (otherwise KeyError is raised).
    Stop codons do not end translation: their 'STOP (...)' label is
    yielded and decoding continues with the next codon.  If the sequence
    length is not a multiple of three, the trailing partial codon cannot
    be translated.  It is dropped with a warning instead of raising an
    opaque KeyError.

    >>> sequence = 'AUG AGC UUC AUG CGU CCG AAG'
    >>> nucleotides = decode_sequence(NUCLEOTIDES, sequence)
    >>> amino_acids = translate_to_amino_acids(nucleotides)
    >>> print('\\n'.join(amino_acids))
    Methionine
    Serine
    Phenylalanine
    Methionine
    Argenine
    Proline
    Lysine

    The leading Methionine is also the "start codon".  There are
    other possible start codon sequences (e.g. GUG) used in
    prokaryotes such as E. coli.
    """
    import warnings  # only needed for the partial-codon case

    for codon in split_into_codons(nucleotides):
        if len(codon) < 3:
            # split_into_codons only yields a short tuple for the final,
            # incomplete codon, which has no entry in CODE.
            warnings.warn(
                'ignoring incomplete trailing codon {!r}'.format(codon))
            return
        yield CODE[codon]


if __name__ == '__main__':
    import argparse
    import sys

    p = argparse.ArgumentParser(
        description='Translate DNA/mRNA to an amino acid sequence.')
    p.add_argument(
        'sequence', metavar='ACGTU', nargs='*',
        help="Genetic sequence to translate (5' to 3')")
    p.add_argument(
        '-s', '--short', action='store_true', default=False,
        help='Use single-letter amino acid abbreviations')
    p.add_argument(
        '-m', '--match',
        help='Match a protein sequence in the source mRNA')
    p.add_argument(
        '-c', '--complement', action='store_true', default=False,
        help='Print the complement DNA and exit')
    p.add_argument(
        '--count', action='store_true', default=False,
        help='Print amino acid, nucleotide, and codon counts and exit')
    p.add_argument(
        '--table', action='store_true', default=False,
        help='Print translation tables and exit')

    args = p.parse_args()

    if args.count:
        print('amino acids: {:d}'.format(len(AMINO_ACIDS)))
        print('nucleotides: {:d}'.format(len(NUCLEOTIDES)))
        print('codons:      {:d}'.format(len(CODE)))
        sys.exit()
    elif args.table:
        # Codon -> amino acid, with bases iterated in U, C, A, G order.
        print('RNA   Amino acid')
        print('===   ==========')
        order = ('Uracil', 'Cytosine', 'Adenine', 'Guanine')
        for x_ in order:
            x = NUCLEOTIDES[x_][-1]
            for y_ in order:
                y = NUCLEOTIDES[y_][-1]
                for z_ in order:
                    z = NUCLEOTIDES[z_][-1]
                    aa_ = CODE[(x_, y_, z_)]
                    try:
                        aa = ' '.join(AMINO_ACIDS[aa_])
                    except KeyError:  # stop codons have no abbreviations
                        aa = aa_
                    print('{}{}{}   {}'.format(x, y, z, aa))
        print('')
        # Amino acid -> every codon that encodes it.
        print('AA      RNA')
        print('=====   ===')
        for aa_ in sorted(AMINO_ACIDS.keys()):
            aa = ' '.join(AMINO_ACIDS[aa_])
            codons = sorted(''.join(NUCLEOTIDES[x][-1] for x in k)
                            for k,v in CODE.items() if v == aa_)
            print('{}   {}'.format(aa, ' '.join(codons)))
        sys.exit()

    mRNA = ' '.join(args.sequence)
    if args.complement:
        for n in mRNA:
            try:
                # Upper-case for the lookup, as decode_sequence does.
                nucleotide = unabbreviate(NUCLEOTIDES, n.upper())
                complement = NUCLEOTIDE_COMPLEMENT[nucleotide]
                c = NUCLEOTIDES[complement][-1]
            except KeyError:
                c = n  # not a nucleotide symbol (e.g. a space): echo it
            else:
                if n.islower():
                    c = c.lower()  # preserve the input's case
            sys.stdout.write(c)
        sys.stdout.write('\n')
        sys.exit()

    try:
        nucleotides = decode_sequence(NUCLEOTIDES, mRNA)
        nucleotides = list(transcribe_to_mRNA(nucleotides))  # no-op on mRNA
        amino_acids = list(translate_to_amino_acids(nucleotides))
    except KeyError as e:
        # Unknown symbol or untranslatable codon: report it as a usage
        # error rather than dumping a traceback.
        p.error('cannot translate sequence: unrecognized {}'.format(e))

    if args.match:
        try:
            match = list(decode_sequence(AMINO_ACIDS, args.match))
        except KeyError as e:
            p.error('unrecognized amino acid in --match: {}'.format(e))
        if not match:
            p.error('--match needs at least one amino acid')
        # The +1 makes the last possible start position (a match ending on
        # the final amino acid) part of the search.
        for start in range(len(amino_acids) - len(match) + 1):
            fragment = amino_acids[start:start+len(match)]
            if fragment == match:
                start_n = start*3
                stop_n = start_n + 3*len(match)
                print('matched nucleotides {:d} through {:d}'.format(
                        start_n, stop_n-1))
                print(''.join(NUCLEOTIDES[n][-1]
                              for n in nucleotides[start_n:stop_n]))
    else:
        # Short mode prints one-letter codes ([-1]); long mode prints
        # hyphen-joined three-letter codes ([0]).  Stop codons print '!'.
        if args.short:
            sep = ''
            i = -1
        else:
            sep = '-'
            i = 0
        print(sep.join(AMINO_ACIDS.get(aa, ['!'])[i] for aa in amino_acids))
