########################################################################
#
# File Name:   Exslt.py
#
# Docs:        http://docs.4suite.org/Xslt/Exslt.py.html
#
"""
EXSLT 2.0 extension functions and elements.  See http://www.jenitennison.com/xslt/exslt/common/
WWW: http://4suite.org/XSLT        e-mail: support@4suite.org

Copyright (c) 2001 Fourthought Inc, USA.   All Rights Reserved.
See  http://4suite.org/COPYRIGHT  for license and copyright information
"""

import os, re, string, urllib
from xml.dom import Node
from xml.xpath import Util, Compile, CoreFunctions
from xml.xpath import BuiltInExtFunctions, Conversions
from xml.xslt import BuiltInExtElements
from xml.utils import boolean
import xml
from xml.xslt import XSL_NAMESPACE, XsltElement
from xml.xslt import OutputParameters, TextWriter


# Namespace URIs used as the first element of the (namespace, local-name)
# keys in the registration tables at the bottom of this file.
EXSL_COMMON_NS = "http://exslt.org/common"
EXSL_SETS_NS = "http://exslt.org/sets"
EXSL_MATH_NS = "http://exslt.org/math"
EXSL_FUNCTIONS_NS = 'http://exslt.org/functions'

import types
# Python types that the argument checks in this module accept as strings.
# If the types module has no UnicodeType attribute, only plain strings
# are accepted.
try:
    g_stringTypes = [types.StringType, types.UnicodeType]
except AttributeError:
    g_stringTypes = [types.StringType]


def NodeSet(context, rtf):
    """
    Implementation of exsl:node-set: convert a result tree fragment into
    a node set.

    context - the XPath context (not used by this function)
    rtf     - a node set (a Python list), which is returned unchanged, or
              a result tree fragment (an object whose nodeType is
              DOCUMENT_NODE), whose child nodes are returned as a new list

    Raises Exception for any other kind of argument.
    """
    if type(rtf) != type([]):
        # Not a node set already, so it has to be a result tree fragment.
        is_fragment = hasattr(rtf, 'nodeType') and rtf.nodeType == Node.DOCUMENT_NODE
        if not is_fragment:
            raise Exception('Argument to exsl:node-set must be a node set or result tree fragment, got: %s'%(repr(rtf)))
        return list(rtf.childNodes)
    return rtf


def ObjectType(context, obj):
    """
    Implementation of exsl:object-type: name the XPath type of obj.

    context - the XPath context (not used by this function)
    obj     - the object to classify

    Returns one of 'node-set', 'string', 'number', 'RTF' or 'boolean'.
    The checks are made in that order.  Raises Exception when obj matches
    none of them.
    """
    kind = type(obj)
    if kind == type([]):
        return 'node-set'
    if kind in g_stringTypes:
        return 'string'
    if kind in [int, float, long]:
        return 'number'
    if hasattr(obj, 'nodeType') and obj.nodeType == Node.DOCUMENT_NODE:
        # A document node stands for a result tree fragment.
        return 'RTF'
    if obj in [boolean.true, boolean.false]:
        return 'boolean'
    raise Exception('Unknown object type of: %s'%(repr(obj)))


def Difference(context, ns1, ns2):
    """
    The set:difference function returns the difference between two node
    sets - those nodes that are in the node set passed as the first argument
    that are not in the node set passed as the second argument.

    context - the XPath context (not used by this function)
    ns1     - node set (list) to take nodes from
    ns2     - node set (list) of nodes to exclude

    Returns a new list holding the nodes of ns1, in their original order,
    that do not occur in ns2.  Raises Exception if either argument is not
    a list.
    """
    # Each argument is tested on its own: a chained "a != b != c" comparison
    # is only true when BOTH arguments are non-lists, which let a single bad
    # argument slip through.
    if type(ns1) != type([]) or type(ns2) != type([]):
        raise Exception('Both arguments to set:difference must be node sets')
    result = []
    for node in ns1:
        if node not in ns2:
            result.append(node)
    return result
    

def HasSameNode(context, ns1, ns2):
    """
    The set:has-same-node function returns true if the node set passed as
    the first argument shares any nodes with the node set passed as the
    second argument. If there are no nodes that are in both node sets, then
    it returns false.

    context - the XPath context (not used by this function)
    ns1     - first node set (list)
    ns2     - second node set (list)

    Returns boolean.true or boolean.false.  Raises Exception if either
    argument is not a list.
    """
    # Each argument is tested on its own: a chained "a != b != c" comparison
    # is only true when BOTH arguments are non-lists.
    if type(ns1) != type([]) or type(ns2) != type([]):
        raise Exception('Both arguments to set:has-same-node must be node sets')
    # Stop at the first shared node instead of computing the whole difference.
    for node in ns1:
        if node in ns2:
            return boolean.true
    return boolean.false
    

def Intersection(context, ns1, ns2):
    """
    The set:intersection function returns a node set comprising the nodes
    that are within both the node sets passed as arguments to it.

    context - the XPath context (not used by this function)
    ns1     - first node set (list)
    ns2     - second node set (list)

    Returns a new list holding the nodes of ns1, in their original order,
    that also occur in ns2.  Raises Exception if either argument is not
    a list.
    """
    # Each argument is tested on its own: a chained "a != b != c" comparison
    # is only true when BOTH arguments are non-lists.
    if type(ns1) != type([]) or type(ns2) != type([]):
        raise Exception('Both arguments to set:intersection must be node sets')
    result = []
    for node in ns1:
        if node in ns2:
            result.append(node)
    return result


def NumCmp(a, b):
    """
    Comparison function (cmp-style result) that orders two values by the
    result of passing each one through Conversions.NumberValue.
    """
    return cmp(Conversions.NumberValue(a), Conversions.NumberValue(b))


#Note for the set and math functions that follow:
#The 'value' of a node is calculated by evaluating the expression held
#in the string passed as the second argument, with the current node set
#to the node whose value is being calculated.


def Distinct(context, ns, st):
    """
    The set:distinct function returns the nodes within the node set passed
    as the first argument that have different values.

    context - the XPath context; its node/position/size state is changed
              while evaluating and restored before returning
    ns      - node set (list) to filter
    st      - string holding the expression that computes each node's value

    Returns a new list with, for every distinct value, the first node of ns
    that produced it, in the order the nodes appear in ns.  Raises Exception
    if ns is not a list or st is not a string.
    """
    # "or", not "and": either bad argument alone must be rejected.
    if type(ns) != type([]) or (st and type(st) not in g_stringTypes):
        raise Exception('set:distinct takes a node set and an optional string')
    #st = Conversions.StringValue(st)
    expr = Compile(st)
    orig_state = context.copyNodePosSize()
    seen = {}
    result = []
    try:
        for node in ns:
            context.setNodePosSize((node, 1, 1))
            value = expr.evaluate(context)
            # Keep the first node per value and preserve input order; a plain
            # value->node dict kept the last node and lost the ordering.
            if not seen.has_key(value):
                seen[value] = 1
                result.append(node)
    finally:
        # Restore the caller's context even when evaluation fails.
        context.setNodePosSize(orig_state)
    return result


def Leading(context, ns, st):
    """
    The set:leading function returns the nodes in the node set passed as
    the first argument that precede, in document order, the first node in
    the node set whose value evaluates to true.

    context - the XPath context; its node/position/size state is changed
              while evaluating and restored before returning
    ns      - node set (list) to examine
    st      - string holding the expression that computes each node's value

    Returns ns itself when it is empty; otherwise a list of the nodes, in
    the order produced by Util.SortDocOrder, that come before the first node
    whose value is true (all of them when no value is true).  Raises
    Exception if ns is not a list or st is not a string.
    """
    # "or", not "and": either bad argument alone must be rejected.
    if type(ns) != type([]) or (st and type(st) not in g_stringTypes):
        raise Exception('set:leading takes a node set and an optional string')
    if not ns: return ns
    expr = Compile(st)
    orig_state = context.copyNodePosSize()
    # The count below is taken over the sorted nodes, so the slice has to be
    # taken from the same sorted sequence rather than from the raw input.
    ordered = Util.SortDocOrder(ns)
    ctr = 0
    try:
        for node in ordered:
            context.setNodePosSize((node, 1, 1))
            if Conversions.BooleanValue(expr.evaluate(context)):
                break
            ctr = ctr + 1
    finally:
        # The saved state was previously never put back.
        context.setNodePosSize(orig_state)
    return ordered[:ctr]
    

def Trailing(context, ns, st):
    """
    The set:trailing function returns the nodes in the node set passed
    as the first argument that follow, in document order, the first node
    in the node set whose value is true.

    context - the XPath context; its node/position/size state is changed
              while evaluating and restored before returning
    ns      - node set (list) to examine
    st      - string holding the expression that computes each node's value

    Returns ns itself when it is empty; otherwise a list of the nodes, in
    the order produced by Util.SortDocOrder, that come after the first node
    whose value is true (an empty list when no value is true).  Raises
    Exception if ns is not a list or st is not a string.

    NOTE(review): the previous docstring said the matching node itself is
    part of the result, but the code has always skipped it (slice starts at
    ctr+1).  That behaviour is kept; confirm which one is intended.
    """
    # "or", not "and": either bad argument alone must be rejected.
    if type(ns) != type([]) or (st and type(st) not in g_stringTypes):
        raise Exception('set:trailing takes a node set and an optional string')
    if not ns: return ns
    expr = Compile(st)
    orig_state = context.copyNodePosSize()
    # The count below is taken over the sorted nodes, so the slice has to be
    # taken from the same sorted sequence rather than from the raw input.
    ordered = Util.SortDocOrder(ns)
    ctr = 0
    try:
        for node in ordered:
            context.setNodePosSize((node, 1, 1))
            if Conversions.BooleanValue(expr.evaluate(context)):
                break
            ctr = ctr + 1
    finally:
        # The saved state was previously never put back.
        context.setNodePosSize(orig_state)
    return ordered[ctr+1:]
    

def Exists(context, ns, st):
    """
    The set:exists function returns true if the value of any of the nodes
    in the node set passed as the first argument is true.

    context - the XPath context; its node/position/size state is changed
              while evaluating and restored before returning
    ns      - node set (list) to examine
    st      - string holding the expression that computes each node's value

    Returns boolean.true as soon as one node's value is true, otherwise
    boolean.false (also for an empty node set).  Raises Exception if ns is
    not a list or st is not a string.
    """
    # "or", not "and": either bad argument alone must be rejected.
    if type(ns) != type([]) or (st and type(st) not in g_stringTypes):
        raise Exception('set:exists takes a node set and an optional string')
    if not ns: return boolean.false
    expr = Compile(st)
    orig_state = context.copyNodePosSize()
    try:
        for node in ns:
            context.setNodePosSize((node, 1, 1))
            if Conversions.BooleanValue(expr.evaluate(context)):
                return boolean.true
    finally:
        # Runs on every exit path, including the early return and errors.
        context.setNodePosSize(orig_state)
    return boolean.false


def ForAll(context, ns, st):
    """
    The set:for-all function returns true if the value of all the nodes
    in the node set passed as the first argument is true.

    context - the XPath context; its node/position/size state is changed
              while evaluating and restored before returning
    ns      - node set (list) to examine
    st      - string holding the expression that computes each node's value

    Returns boolean.false as soon as one node's value is false, otherwise
    boolean.true.  Raises Exception if ns is not a list or st is not a
    string.

    NOTE(review): an empty node set yields boolean.false, not the vacuous
    true; kept as it was, confirm whether that is intended.
    """
    # "or", not "and": either bad argument alone must be rejected.
    if type(ns) != type([]) or (st and type(st) not in g_stringTypes):
        raise Exception('set:for-all takes a node set and an optional string')
    if not ns: return boolean.false
    expr = Compile(st)
    orig_state = context.copyNodePosSize()
    try:
        for node in ns:
            context.setNodePosSize((node, 1, 1))
            if not Conversions.BooleanValue(expr.evaluate(context)):
                return boolean.false
    finally:
        # Runs on every exit path, including the early return and errors.
        context.setNodePosSize(orig_state)
    return boolean.true
    

def Max(context, ns, st):
    """
    The num:max function returns the maximum value of the nodes in the
    node set passed as the first argument.

    context - the XPath context; its node/position/size state is changed
              while evaluating and restored before returning
    ns      - node set (list) to examine
    st      - string holding the expression that computes each node's value

    Returns the evaluated value (as produced by the expression, not
    converted) that sorts last under NumCmp, or boolean.false for an empty
    node set.  Raises Exception if ns is not a list or st is not a string.
    """
    # "or", not "and": either bad argument alone must be rejected.
    if type(ns) != type([]) or (st and type(st) not in g_stringTypes):
        raise Exception('num:max takes a node set and an optional string')
    if not ns: return boolean.false
    expr = Compile(st)
    results = []
    orig_state = context.copyNodePosSize()
    try:
        for node in ns:
            context.setNodePosSize((node, 1, 1))
            results.append(expr.evaluate(context))
    finally:
        # Restore the caller's context even when evaluation fails.
        context.setNodePosSize(orig_state)
    # Order by numeric value; the largest ends up last.
    results.sort(NumCmp)
    return results[-1]


def Min(context, ns, st):
    """
    The num:min function returns the minimum value of the nodes in the
    node set passed as the first argument.

    context - the XPath context; its node/position/size state is changed
              while evaluating and restored before returning
    ns      - node set (list) to examine
    st      - string holding the expression that computes each node's value

    Returns the evaluated value (as produced by the expression, not
    converted) that sorts first under NumCmp, or boolean.false for an empty
    node set.  Raises Exception if ns is not a list or st is not a string.
    """
    # "or", not "and": either bad argument alone must be rejected.
    if type(ns) != type([]) or (st and type(st) not in g_stringTypes):
        raise Exception('num:min takes a node set and an optional string')
    if not ns: return boolean.false
    expr = Compile(st)
    results = []
    orig_state = context.copyNodePosSize()
    try:
        for node in ns:
            context.setNodePosSize((node, 1, 1))
            results.append(expr.evaluate(context))
    finally:
        # Restore the caller's context even when evaluation fails.
        context.setNodePosSize(orig_state)
    # Order by numeric value; the smallest ends up first.
    results.sort(NumCmp)
    return results[0]


def Highest(context, ns, st):
    """
    The num:highest function returns the nodes in the node set passed as
    the first argument with the highest value.

    context - the XPath context; its node/position/size state is changed
              while evaluating and restored before returning
    ns      - node set (list) to examine
    st      - string holding the expression that computes each node's value

    Returns ns itself when it is empty; otherwise the list of nodes whose
    evaluated value sorts last under NumCmp.  Raises Exception if ns is not
    a list or st is not a string.
    """
    # "or", not "and": either bad argument alone must be rejected.
    if type(ns) != type([]) or (st and type(st) not in g_stringTypes):
        raise Exception('num:highest takes a node set and an optional string')
    if not ns: return ns
    expr = Compile(st)
    # Maps each evaluated value to the nodes that produced it.  The raw
    # value is the key, so equal numbers of different types (e.g. '5' and
    # 5.0) land in separate groups.
    nodes = {}
    orig_state = context.copyNodePosSize()
    try:
        for node in ns:
            context.setNodePosSize((node, 1, 1))
            val = expr.evaluate(context)
            if not nodes.has_key(val):
                nodes[val] = []
            nodes[val].append(node)
    finally:
        # Restore the caller's context even when evaluation fails.
        context.setNodePosSize(orig_state)
    skeys = nodes.keys()
    skeys.sort(NumCmp)
    return nodes[skeys[-1]]


def Lowest(context, ns, st):
    """
    The num:lowest function returns the nodes in the node set passed as
    the first argument with the lowest value.

    context - the XPath context; its node/position/size state is changed
              while evaluating and restored before returning
    ns      - node set (list) to examine
    st      - string holding the expression that computes each node's value

    Returns ns itself when it is empty; otherwise the list of nodes whose
    evaluated value sorts first under NumCmp.  Raises Exception if ns is
    not a list or st is not a string.
    """
    # "or", not "and": either bad argument alone must be rejected.
    if type(ns) != type([]) or (st and type(st) not in g_stringTypes):
        raise Exception('num:lowest takes a node set and an optional string')
    if not ns: return ns
    expr = Compile(st)
    # Maps each evaluated value to the nodes that produced it.  The raw
    # value is the key, so equal numbers of different types (e.g. '5' and
    # 5.0) land in separate groups.
    nodes = {}
    orig_state = context.copyNodePosSize()
    try:
        for node in ns:
            context.setNodePosSize((node, 1, 1))
            val = expr.evaluate(context)
            if not nodes.has_key(val):
                nodes[val] = []
            nodes[val].append(node)
    finally:
        # Restore the caller's context even when evaluation fails.
        context.setNodePosSize(orig_state)
    skeys = nodes.keys()
    skeys.sort(NumCmp)
    return nodes[skeys[0]]

    
def Sum(context, ns, st):
    """
    The num:sum function returns the sum of the values of the nodes in
    the node set passed as the first argument.

    context - the XPath context; its node/position/size state is changed
              while evaluating and restored before returning
    ns      - node set (list) to examine
    st      - string holding the expression that computes each node's value

    Returns the sum, starting from 0.0, of each node's value converted with
    Conversions.NumberValue; 0.0 for an empty node set.  Raises Exception
    if ns is not a list or st is not a string.
    """
    # "or", not "and": either bad argument alone must be rejected.
    if type(ns) != type([]) or (st and type(st) not in g_stringTypes):
        raise Exception('num:sum takes a node set and an optional string')
    if not ns: return 0.0
    expr = Compile(st)
    total = 0.0
    orig_state = context.copyNodePosSize()
    try:
        for node in ns:
            context.setNodePosSize((node, 1, 1))
            total = total + Conversions.NumberValue(expr.evaluate(context))
    finally:
        # Restore the caller's context even when evaluation fails.
        context.setNodePosSize(orig_state)
    return total


from xml import xpath

class FunctionElement(XsltElement):
    """
    Extension element func:function: defines an XPath extension function
    whose body is made of the element's children.

    setup() splits the children into leading xsl:param elements and the
    remaining body, then registers a callable under the expanded function
    name in xpath.g_extFunctions.  invoke() binds the call arguments to the
    parameters, instantiates the body and returns whatever a func:result
    element stored in self._result.
    """
    #definedFunctionBodies = {}
    legalAttrs = ('name', )

    def __init__(self, doc, uri=EXSL_FUNCTIONS_NS, localName='function', prefix='func', baseUri=''):
        XsltElement.__init__(self, doc, uri, localName, prefix, baseUri)
        self.setAttributeNS(XSL_NAMESPACE, 'extension-element-prefixes', prefix)
        return

    def setup(self):
        """
        Collect the parameters and body elements and register the function.

        Raises XsltException(Error.ILLEGAL_PARAM) when an xsl:param follows
        a child that is not an xsl:param.
        """
        self._nss = xml.dom.ext.GetAllNs(self)
        name_attr = self.getAttributeNS('', 'name')
        self._name = Util.ExpandQName(name_attr, namespaces=self._nss)
        self._params = []
        self._elements = []
        param_allowed = 1
        for child in self.childNodes:
            if child.namespaceURI == XSL_NAMESPACE and child.localName == 'param':
                if not param_allowed:
                    # These two names were used without being imported.
                    # NOTE(review): assumes xml.xslt exports XsltException
                    # and Error -- confirm.
                    from xml.xslt import XsltException, Error
                    raise XsltException(Error.ILLEGAL_PARAM)
                self._params.append(child)
            else:
                # Anything else ends the run of leading parameters.
                param_allowed = 0
                self._elements.append(child)
        # Register a bound method taking any number of arguments.  The former
        # eval'd lambda raised TypeError while building its argument names
        # ('a' + int), produced a broken signature, and would have dropped
        # the declared arguments instead of forwarding them to invoke().
        xpath.g_extFunctions[self._name] = self._callAsExtFunction
        return

    def _callAsExtFunction(self, context, *args):
        """Entry point stored in xpath.g_extFunctions; forwards to invoke()."""
        return self.invoke(context, context.processor, args)

    def invoke(self, context, processor, args=None):
        """
        Run the function body and return its result.

        context   - the context to evaluate in; its namespaces are replaced
                    by those of this element (the context is not reset)
        processor - the processor handed to every instantiate() call
        args      - sequence of positional argument values; parameters
                    without a matching argument are instantiated to get
                    their default

        Returns the value stored in self._result by a func:result element,
        or '' when none was instantiated.
        """
        if args is None:
            # The default used to reach len() and fail.
            args = ()
        self._result = ''

        #NOTE Don't reset the context
        context.setNamespaces(self._nss)

        origVars = context.varBindings.copy()
        try:
            # Set the parameter list
            counter = 0
            for param in self._params:
                if counter < len(args):
                    context.varBindings[param._name] = args[counter]
                else:
                    #default
                    context = param.instantiate(context, processor)[0]
                counter = counter + 1

            # Every child is handled the same way: only the first item of
            # the returned tuple is used.  ResultElement.instantiate returns
            # a 1-tuple, so unpacking two values from it failed.
            for child in self._elements:
                context = child.instantiate(context, processor)[0]
        finally:
            # Put the variable bindings back even when the body fails.
            context.varBindings = origVars
        return self._result

    def __getinitargs__(self):
        return (None, self.namespaceURI, self.localName, self.prefix,
                self.baseUri)

    def __getstate__(self):
        base_state = XsltElement.__getstate__(self)
        new_state = (base_state, self._nss, self._name, self._params,
                     self._elements)
        return new_state

    def __setstate__(self, state):
        # NOTE(review): restoring state does not put the function back into
        # xpath.g_extFunctions; confirm that setup() runs again elsewhere.
        XsltElement.__setstate__(self, state[0])
        self._nss = state[1]
        self._name = state[2]
        self._params = state[3]
        self._elements = state[4]
        return

class ResultElement(XsltElement):
    """
    Extension element func:result: computes the return value of the
    enclosing func:function element and stores it in that element's
    _result attribute.

    The value comes from the 'select' expression when that attribute is
    present, otherwise from instantiating the element's children into a
    result pushed on the processor.
    """
    legalAttrs = ('select', )

    def __init__(self, doc, uri=EXSL_FUNCTIONS_NS, localName='result', prefix='func', baseUri=''):
        XsltElement.__init__(self, doc, uri, localName, prefix, baseUri)
        return

    def setup(self):
        """
        Parse the select expression and locate the enclosing func:function.

        Raises Exception when no func:function ancestor is found.
        """
        self._nss = xml.dom.ext.GetAllNs(self)
        # NOTE(review): 'name' is not listed in legalAttrs and _name is not
        # used elsewhere in this class; kept as it was -- confirm it is needed.
        name_attr = self.getAttributeNS('', 'name')
        self._name = Util.ExpandQName(name_attr, namespaces=self._nss)
        self._select = self.getAttributeNS('', 'select')
        if self._select:
            self._expr = self.parseExpression(self._select)
        else:
            self._expr = None
        node = self
        self._function = None
        # Walk up the ancestors.  The walk ends at the element whose parent
        # is the owner document, which is therefore not examined itself.
        # The None test stops a detached node from being dereferenced, so
        # the descriptive error below is raised instead.
        while node.parentNode is not None and node.parentNode != node.ownerDocument:
            if (node.namespaceURI, node.localName) == (EXSL_FUNCTIONS_NS, 'function'):
                self._function = node
            node = node.parentNode
        if not self._function:
            raise Exception("An EXSLT func:result element must occur within a func:function element")
        return

    def instantiate(self, context, processor):
        """
        Compute the result and store it on the enclosing function element.

        Returns a 1-tuple holding the context, restored to the state it had
        on entry.
        """
        origState = context.copy()
        context.setNamespaces(self._nss)
        if self._select:
            if self._expr is None:
                # The parsed expression is not part of the pickled state, so
                # rebuild it on first use with the same call setup() makes.
                self._expr = self.parseExpression(self._select)
            result = self._expr.evaluate(context)
        else:
            processor.pushResult()
            try:
                for child in self.childNodes:
                    context = child.instantiate(context, processor)[0]
            finally:
                # Always balance pushResult, even when a child fails.
                result = processor.popResult()
            context.rtfs.append(result)

        context.set(origState)
        self._function._result = result
        return (context, )

    def __getinitargs__(self):
        return (None, self.namespaceURI, self.localName, self.prefix,
                self.baseUri)

    def __getstate__(self):
        base_state = XsltElement.__getstate__(self)
        new_state = (base_state, self._nss, self._select, self._function)
        return new_state

    def __setstate__(self, state):
        XsltElement.__setstate__(self, state[0])
        self._nss = state[1]
        self._select = state[2]
        self._function = state[3]
        # Not stored in the state; instantiate() re-parses it when needed.
        self._expr = None
        return


class DocumentElement(XsltElement):
    """
    Extension element exsl:document: writes the output produced by its
    children to the file named by the 'href' attribute value template.
    """
    def __init__(self, doc, uri=EXSL_COMMON_NS, localName='document',
                 prefix='common', baseUri=''):
        XsltElement.__init__(self, doc, uri, localName, prefix, baseUri)
        return

    def setup(self):
        """Prepare the output parameters and parse the href template."""
        self._nss = xml.dom.ext.GetAllNs(self)
        self._out = OutputParameters()
        self._out.avtParsePrep(self, self._nss)
        self._href = self.parseAVT(self.getAttributeNS('', 'href'))
        return

    def instantiate(self, context, processor):
        """
        Open the target file, route the children's output to it, then close
        it.

        The file is opened for writing (an existing file is overwritten).
        When the open fails and the parent directory does not exist, the
        directory is created and the open is retried; any other open failure
        is re-raised.

        Returns a 1-tuple holding the context, restored to the state it had
        on entry.
        """
        origState = context.copy()
        context.processorNss = self._nss

        self._out.avtParse(context)

        href = self._href.evaluate(context)
        try:
            f = open(href, 'w')
        except IOError:
            parent = os.path.split(href)[0]
            # An empty parent means href has no directory part; calling
            # makedirs('') would only hide the real error, so re-raise.
            if parent and not os.access(parent, os.F_OK):
                os.makedirs(parent)
                f = open(href, 'w')
            else:
                raise
        #overwrite = self._overwrite.evaluate(context)
        #if overwrite == 'yes':
        #    f = open(name, 'w')
        #else:
        #    f = open(name, 'a')
        try:
            processor.addHandler(self._out, f)
            try:
                for child in self.childNodes:
                    context = child.instantiate(context, processor)[0]
            finally:
                # Detach the handler even when a child fails.
                processor.removeHandler()
        finally:
            # Never leak the file handle.
            f.close()

        context.set(origState)
        return (context,)

    def __getinitargs__(self):
        return (None, self.namespaceURI, self.localName, self.prefix,
                self.baseUri)

    def __getstate__(self):
        base_state = XsltElement.__getstate__(self)
        new_state = (base_state, self._nss, self._href, self._out)
        return new_state

    def __setstate__(self, state):
        XsltElement.__setstate__(self, state[0])
        self._nss = state[1]
        self._href = state[2]
        self._out = state[3]
        return


# Extension elements, keyed by (namespace URI, local name).
ExtElements = {
    (EXSL_FUNCTIONS_NS, 'function'): FunctionElement,
    (EXSL_FUNCTIONS_NS, 'result'): ResultElement,
    (EXSL_COMMON_NS, 'document'): DocumentElement,
    }

# Extension functions of each EXSLT module, keyed the same way.
CommonFunctions = {
    (EXSL_COMMON_NS, 'node-set'): NodeSet,
    (EXSL_COMMON_NS, 'object-type'): ObjectType,
    }

# Exists and ForAll are defined above but currently not registered.
SetFunctions = {
    (EXSL_SETS_NS, 'difference'): Difference,
    (EXSL_SETS_NS, 'intersection'): Intersection,
    (EXSL_SETS_NS, 'distinct'): Distinct,
    (EXSL_SETS_NS, 'has-same-node'): HasSameNode,
    (EXSL_SETS_NS, 'leading'): Leading,
    (EXSL_SETS_NS, 'trailing'): Trailing,
    #(EXSL_SETS_NS, 'exists'): Exists,
    #(EXSL_SETS_NS, 'for-all'): ForAll,
    }

# Sum is defined above but currently not registered.
MathFunctions = {
    (EXSL_MATH_NS, 'max'): Max,
    (EXSL_MATH_NS, 'min'): Min,
    (EXSL_MATH_NS, 'highest'): Highest,
    (EXSL_MATH_NS, 'lowest'): Lowest,
#    (EXSL_MATH_NS, 'sum'): Sum,
}

# Union of all the function tables.  It is built in a fresh dictionary:
# binding ExtFunctions to CommonFunctions itself made the update() calls
# add the set and math functions to CommonFunctions as well.
ExtFunctions = {}
ExtFunctions.update(CommonFunctions)
ExtFunctions.update(SetFunctions)
ExtFunctions.update(MathFunctions)

