#! /usr/bin/env python

# -----------------------------------------------------------------------
# Convert a macro prototype to a LaTeX \newcommand
# By Scott Pakin <scott+nc@pakin.org>
# -----------------------------------------------------------------------
# Copyright (C) 2008 Scott Pakin, scott+nc@pakin.org
# 
# This package may be distributed and/or modified under the conditions
# of the LaTeX Project Public License, either version 1.3c of this
# license or (at your option) any later version.  The latest version of
# this license is in:
# 
#     http://www.latex-project.org/lppl.txt
# 
# and version 1.3c or later is part of all distributions of LaTeX version
# 2006/05/20 or later.
# -----------------------------------------------------------------------

from spark import GenericScanner, GenericParser, GenericASTTraversal
import re
import copy

class Token:
    """
        A single lexed token.

        type is the token's kind (for example 'command' or 'ident'),
        attr is the text stored with the token (None when not given) and
        charOffset is the input offset recorded for the token.
    """

    def __init__ (self, type, charOffset, attr=None):
        self.charOffset = charOffset
        self.type, self.attr = type, attr

    def __cmp__ (self, o):
        # Only the token's type takes part in a comparison; attr and
        # charOffset are ignored.
        tokenType = self.type
        return cmp (tokenType, o)


class AST:
    """
        Represent an abstract syntax tree.

        type        the node's kind (for example 'decl' or 'arg')
        charOffset  input offset associated with the node
        attr        node-specific payload, None when not given
        kids        the node's children: any object that supports
                    indexing and len(), i.e. a list or another AST
                    whose children are then shared

        Indexing a node and taking its length are forwarded to kids.
    """

    def __init__ (self, type, charOffset, attr=None, kids=None):
        self.type = type
        self.charOffset = charOffset
        self.attr = attr
        # A literal [] default would be one list shared by every node
        # created without explicit children, so build a new one here.
        if kids is None:
            kids = []
        self.kids = kids

    def __getitem__ (self, child):
        return self.kids[child]

    def __len__ (self):
        return len (self.kids)


def _appendToken (scanner, kind, text):
    """
        Appends a Token of the given kind, holding text, to scanner.rv
        and advances scanner.charOffset past the text.

        Kept at module level so that the namespace of CmdScanner itself,
        which holds the t_* token methods, is exactly what it was.
    """
    scanner.rv.append (Token (type=kind,
                              attr=text,
                              charOffset=scanner.charOffset))
    scanner.charOffset = scanner.charOffset + len (text)


class CmdScanner (GenericScanner):
    """
        Defines a lexer for macro prototypes.

        The raw-string "docstring" of each t_* method is the regular
        expression for that token, not documentation; notes about those
        methods are therefore written as "#" comments.
    """

    def __init__ (self):
        GenericScanner.__init__ (self)
        self.charOffset = 0

    def tokenize (self, input):
        "Returns the list of Tokens found in the string input."
        self.rv = []
        # Offsets are relative to the start of this input, so start over
        # on every call; otherwise a reused scanner would keep counting
        # from the end of the previous input.
        self.charOffset = 0
        GenericScanner.tokenize (self, input)
        return self.rv

    # Whitespace produces no token but still counts towards the offset.
    def t_whitespace (self, whiteSpace):
        r' [\s\r\n]+ '
        self.charOffset = self.charOffset + len (whiteSpace)

    # The keyword that introduces a prototype.
    def t_command (self, cmd):
        r' MACRO '
        _appendToken (self, 'command', cmd)

    # The keyword that introduces an optional argument.
    def t_argument_type (self, arg):
        r' OPT '
        _appendToken (self, 'argtype', arg)

    # A formal parameter: "#" followed by a single digit.
    def t_argument (self, arg):
        r' \#\d '
        _appendToken (self, 'argument', arg)

    # The token's type is the matched text itself ("="), which is how
    # the grammar rule "defval ::= argument = quoted" refers to it.
    def t_equal (self, equal):
        r' = '
        _appendToken (self, equal, equal)

    # Brace-delimited text; the braces are kept in the token's attr.
    def t_quoted (self, quoted):
        r' \{[^\}]*\} '
        _appendToken (self, 'quoted', quoted)

    def t_identifier (self, ident):
        r' [A-Za-z]+ '
        _appendToken (self, 'ident', ident)

    # A single parenthesis or square bracket.
    def t_delimiter (self, delim):
        r' [()\[\]] '
        _appendToken (self, 'delim', delim)

    # Any run of characters other than delimiters, braces, "#" and
    # whitespace.
    def t_other (self, other):
        r' [^()\[\]\{\}\#\s\r\n]+ '
        _appendToken (self, 'other', other)


class CmdParser (GenericParser):
    """
        Defines a parser for macro prototypes.

        The docstring of every p_* method is a grammar rule, not
        documentation, so notes about those methods are written as "#"
        comments.  Each p_* method receives the matched right-hand-side
        items in args and returns the AST node for the rule.
    """

    class parseError (Exception):
        """
            Raised by error() for an unexpected token.  The exception's
            arguments are (message, position), where position is one
            more than the offending token's charOffset.
        """

    def __init__ (self, start='decl'):
        GenericParser.__init__ (self, start)

    def error (self, token):
        # An exception instance is raised because a string exception
        # (parseError used to be the string "parseError") cannot be
        # raised by Python 2.6 or later.  Handlers written as
        # "except CmdParser.parseError, (message, pos):" keep working.
        raise self.parseError ('"%s" was unexpected' % token.attr,
                               1 + token.charOffset)

    # attr holds the opening and closing delimiters.  kids is the
    # defvals node itself rather than a list, so indexing the optarg
    # node reaches the children of that defvals node.
    def p_optarg (self, args):
        ' optarg ::= argtype delim defvals delim '
        return AST (type='optarg',
                    charOffset=args[0].charOffset,
                    attr=[args[1].attr, args[3].attr],
                    kids=args[2])

    def p_rawtext (self, args):
        ' rawtext ::= other '
        return AST (type='rawtext',
                    charOffset=args[0].charOffset,
                    attr=args[0].attr)

    # attr is [parameter, default value]; the default value keeps its
    # surrounding braces.
    def p_defval (self, args):
        ' defval ::= argument = quoted '
        return AST (type='defval',
                    charOffset=args[0].charOffset,
                    attr=[args[0].attr, args[2].attr])

    def p_defvals_1 (self, args):
        ' defvals ::= defval '
        return AST (type='defvals',
                    charOffset=args[0].charOffset,
                    kids=args)

    # attr is the separator found between the earlier defvals and the
    # new defval; a defvals node built from a lone defval has attr None.
    def p_defvals_2 (self, args):
        '''
            defvals ::= defvals rawtext defval
            defvals ::= defvals ident defval
            defvals ::= defvals quoted defval
        '''
        return AST (type='defvals',
                    charOffset=args[0].charOffset,
                    attr=args[1],
                    kids=[args[0],args[2]])

    # Top-level macro argument
    # attr is [type of the matched item, attr of the matched item].
    def p_arg_1 (self, args):
        '''
            arg ::= rawtext
            arg ::= quoted
            arg ::= argument
        '''
        return AST (type='arg',
                    charOffset=args[0].charOffset,
                    attr=[args[0].type, args[0].attr])

    # Same attr layout as above; the optarg's kids are carried over.
    def p_arg_2 (self, args):
        ' arg ::= optarg '
        return AST (type='arg',
                    charOffset=args[0].charOffset,
                    attr=[args[0].type, args[0].attr],
                    kids=args[0].kids)

    def p_arglist_1 (self, args):
        '''
            arglist ::= arg
            arglist ::= arg arglist
        '''
        return AST (type='arglist',
                    charOffset=args[0].charOffset,
                    kids=args)

    # An empty argument list has no token to take an offset from.
    def p_arglist_2 (self, args):
        ' arglist ::= '
        return AST (type='arglist', charOffset=0)

    # attr is [command token, ident token]; kids is the arglist node.
    def p_decl (self, args):
        ' decl ::= command ident arglist '
        return AST (type='decl',
                    charOffset=args[0].charOffset,
                    attr=args[:2],
                    kids=args[2])


class CheckAST (GenericASTTraversal):
    """
        Performs semantic analysis on an AST.

        Creating an instance runs the whole check as a postorder walk of
        the tree; semanticError is raised for the first problem found.
    """

    class semanticError (Exception):
        """
            Raised when a prototype is well-formed but inconsistent.
            The exception's arguments are (message, position).
        """

    _paramPat = re.compile (r"(#\d)")

    def __init__ (self, ast):
        GenericASTTraversal.__init__ (self, ast)
        self._paramList = [0]      # Sentinel, to avoid special first case
        self.postorder()

    def _appendParam (self, paramStr, charOffset):
        """
            Checks that a parameter number is correct and either
            appends it to _paramList if it is or raises an exception
            if it's not.

            paramStr is the digit of the parameter as a string and
            charOffset the offset to report if the number is wrong.
        """
        paramNum = int (paramStr)
        nextParam = self._paramList[-1] + 1
        if paramNum != nextParam:
            # An exception instance is raised because string exceptions
            # cannot be raised by Python 2.6 or later.  Handlers written
            # as "except CheckAST.semanticError, (message, pos):" keep
            # working.
            raise self.semanticError (
                'Saw parameter #%s when parameter #%s was expected' %
                (paramStr, nextParam),
                1 + charOffset)
        self._paramList.append (paramNum)

    def n_defval (self, node):
        """
            Checks a "parameter=default" pair: the default may only
            mention parameters defined so far, and the parameter itself
            must be the next one in sequence.
        """
        for param in self._paramPat.findall (node.attr[1]):
            replArg = int (param[1])
            if replArg > self._paramList[-1]:
                # The offset will be incorrect if there are spaces around
                # the "=" sign in the "parameter=quoted" expression.
                raise self.semanticError (
                    'Parameter #%d was used before being defined' % replArg,
                    1 + node.charOffset + len (node.attr[1]))
        self._appendParam (node.attr[0][1], node.charOffset)

    def n_arg (self, node):
        """
            Checks a top-level argument: quoted text must not contain a
            formal parameter, and a parameter must be the next one in
            sequence.
        """
        if node.attr[0] == "quoted":
            badParam = self._paramPat.search (node.attr[1])
            if badParam:
                raise self.semanticError (
                    "The formal parameter `%s' should not be quoted" %
                    badParam.group(1),
                    1 + node.charOffset + badParam.start(1))
        elif node.attr[0] == "argument":
            self._appendParam (node.attr[1][1], node.charOffset)


class GenerateLaTeX (GenericASTTraversal):
    """
        Generates a LaTeX macro from a given AST.

        Creating an instance does all the work: a preorder walk of the
        tree fills in _funcInfo and _argTypes, after which outputLaTeX()
        prints the LaTeX code.

        _funcInfo holds one dictionary per macro to be defined:
            name       macro name without the leading backslash
            args       parameter text that follows the name in a def
            curlyargs  text used to hand the parameters collected so
                       far to the next helper macro
            body       replacement text, or None if the user is to
                       supply it
            close      text printed after the "%" that ends the body

        _argTypes holds the type of every top-level argument, in order.
        A simple optional argument (square brackets around a single
        default) is stored as "optarg_simple" followed by its default
        value.
    """

    # Map an opening delimiter to a closing delimiter.
    closingDelim = {"[" : "]", "(" : ")"}

    def __init__ (self, ast):
        GenericASTTraversal.__init__ (self, ast)
        self._funcInfo = []
        self._argTypes = []
        self.preorder()
        self.outputLaTeX()

    def outputLaTeX (self):
        """
            Outputs a block of LaTeX code that matches the original
            prototype.

            Every print below passes one parenthesized string, a form
            that reads the same to Python 2 and Python 3.
        """

        def simpleNextArg (arg):
            "Says whether a given argument can follow newcommand arguments."
            return arg in ("argument", "optarg", "optarg_simple")

        # Work on a copy: the entries are consumed below, and _argTypes
        # must survive so that the method can be called more than once.
        argTypes = list (self._argTypes)
        trivial = not argTypes

        # Determine the number of arguments that LaTeX's \newcommand
        # can handle by itself.
        newcommandArgs = 0
        defaultArgs = 0
        simple_optarg = None
        if argTypes and argTypes[0]=="optarg_simple" and \
           (len(argTypes)==2 or simpleNextArg(argTypes[2])):
            newcommandArgs = 1
            defaultArgs = 1
            simple_optarg = argTypes[1]
            del argTypes[:2]
        while argTypes and argTypes[0]=="argument" and \
              (len(argTypes)==1 or simpleNextArg(argTypes[1])):
            newcommandArgs = newcommandArgs + 1
            argTypes.pop (0)

        # When \newcommand takes care of a leading optional argument,
        # the first _funcInfo entry is not needed.
        funcInfo = self._funcInfo[defaultArgs:]

        # Determine if we need to insert \makeatletter and \makeatother.
        atSigns = len ([func for func in funcInfo
                        if func["body"] and "@" in func["body"]])
        if atSigns > 0:
            print ("\\makeatletter")

        # Peel off the first newcommandArgs entries, and let LaTeX
        # deal with those within the \newcommand.
        suppressDef = newcommandArgs > 0 or trivial
        if suppressDef:
            newcommand = "\\newcommand{\\%s}" % self._funcInfo[0]["name"]
            if newcommandArgs > 0:
                newcommand = newcommand + "[%d]" % newcommandArgs
                if defaultArgs == 1:
                    newcommand = newcommand + "[%s]" % simple_optarg
            print ("%s{%%" % newcommand)

        # Handle the remaining arguments the hard way.  The first macro
        # needs no \def line if \newcommand has already opened it.
        for func in funcInfo:
            if suppressDef:
                suppressDef = False
            else:
                print ("\\def\\%s%s{%%" % (func["name"], func["args"]))
            if func["body"]:
                print ("  %s%%%s" % (func["body"], func["close"]))
            else:
                print ("  % Put your code here.")
            print ("}")
        if atSigns > 0:
            print ("\\makeatother")

    def _nextMacroName (self):
        """
            Returns a LaTeX macro name based on the size and first
            element of _funcInfo.
        """
        return self._funcInfo[0]["name"] + "@" + \
               ["i", "ii", "iii", "iv", "v",
                "vi", "vii", "viii", "ix", "x"][len (self._funcInfo) - 1]

    class DefvalsArguments (GenericASTTraversal):
        """
            Returns the formal or actual arguments of a defvals.

            In either case, quoted and rawtext parameters are included, too.
            The results are left in formalsList and actualsList.
        """

        def __init__ (self, ast):
            GenericASTTraversal.__init__ (self, ast)
            self.formalsList = []
            self.actualsList = []
            self.postorder()

        def n_defval (self, node):
            "Records the parameter and its default value without braces."
            self.formalsList.append (node.attr[0])
            self.actualsList.append (node.attr[1][1:-1])

        def n_defvals (self, node):
            "Puts the separator, if any, in front of the latest entry."
            if node.attr is not None:
                penultimate = len (self.formalsList) - 1
                self.formalsList.insert (penultimate, node.attr.attr)
                self.actualsList.insert (penultimate, node.attr.attr)

    def n_arg (self, node):
        "Specify a new function based on each argument."

        argType = node.attr[0]
        if argType != "optarg":
            # A required argument only extends the newest macro.
            self._argTypes.append (argType)
            current = self._funcInfo[-1]
            if argType in ("argument", "rawtext"):
                current["args"] = current["args"] + node.attr[1]
                current["curlyargs"] = current["curlyargs"] + \
                                       "{" + node.attr[1] + "}"
            elif argType == "quoted":
                unquoted = node.attr[1][1:-1]
                current["args"] = current["args"] + unquoted
                current["curlyargs"] = current["curlyargs"] + unquoted
            return

        # Store the type, making a special case for simple optargs.
        lDelim, rDelim = node.attr[1]
        if lDelim=="[" and rDelim=="]" and len(node.kids)==1:
            self._argTypes.extend (["optarg_simple",
                                    node.kids[0].attr[1][1:-1]])
        else:
            self._argTypes.append (argType)

        # An optional argument turns the newest macro into a dispatcher
        # that calls a new helper macro, supplying the default values
        # itself when the opening delimiter does not come next.
        macroName = self._nextMacroName()
        prevInfo = self._funcInfo[-1]
        defvalsArguments = self.DefvalsArguments (node.kids)
        formals = lDelim + "".join (defvalsArguments.formalsList) + rDelim
        defaultArgs = lDelim + "".join (defvalsArguments.actualsList) + rDelim
        prevInfo["body"] = "\\@ifnextchar%s{\\%s%s}{\\%s%s%s}" % \
                           (lDelim, macroName, prevInfo["curlyargs"],
                            macroName, prevInfo["curlyargs"], defaultArgs)
        # A delimiter without an entry in closingDelim gets no closer.
        prevInfo["close"] = self.closingDelim.get (lDelim, "")
        self._funcInfo.append ({"name": macroName,
                                "body": None,
                                "close": prevInfo["close"],
                                "args": prevInfo["args"] + formals,
                                "curlyargs": prevInfo["curlyargs"] + formals})

    def n_decl (self, node):
        "Prepares the top-level LaTeX macro."
        self._funcInfo = [{"name": node.attr[1].attr,
                           "args": "",
                           "curlyargs": "",
                           "close": "",
                           "body": None}]


# The buck starts here.
if __name__ == '__main__':
    import sys
    import string

    def processLine():
        "Parses the current value of oneLine."
        global oneLine
        try:
            sys.stdout.softspace = 0        # Cancel the %$#@! space.
            oneLine = string.strip (oneLine)
            if oneLine=="" or oneLine[0]=="%":
                return
            if not isStdin:
                print prompt, oneLine
            scanner = CmdScanner()
            parser = CmdParser()
            tokens = scanner.tokenize (oneLine)
            ast = parser.parse (tokens)
            CheckAST (ast)
            GenerateLaTeX (ast)
        except CmdParser.parseError, (message, pos):
            sys.stderr.write ((" "*(len(prompt)+pos)) + "^\n")
            sys.stderr.write ("%s: %s.\n" % (sys.argv[0], message))
        except CheckAST.semanticError, (message, pos):
            sys.stderr.write ((" "*(len(prompt)+pos)) + "^\n")
            sys.stderr.write ("%s: %s.\n" % (sys.argv[0], message))
        if isStdin:
            print ""

    prompt = "% Prototype:"
    if len (sys.argv) <= 1:
        isStdin = 1
        print prompt + " ",
        while 1:
            oneLine = sys.stdin.readline()
            if not oneLine:
                break
            processLine()
            print prompt + " ",
    else:
        isStdin = 0
        oneLine = string.join (sys.argv[1:])
        processLine()