# Sketch - A Python-based interactive drawing program
# Copyright (C) 1997, 1998, 1999 by Bernhard Herzog
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or (at
# your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

###Sketch Config
#type = Import
#class_name = 'SVGLoader'
#rx_magic = '.*\\<(\\?xml|svg)'
#tk_file_type = ('Scalable Vector Graphics (SVG)', ('.svg', '.xml'))
format_name = 'SVG'
#unload = 1
###End

from types import StringType
from math import pi, tan
import os, sys
import re
from string import strip, split, atoi, lower
import string

import streamfilter

from Sketch import Document, Layer, CreatePath, ContSmooth, \
     SolidPattern, EmptyPattern, LinearGradient, RadialGradient,\
     CreateRGBColor, CreateCMYKColor, MultiGradient, \
     Trafo, Translation, Rotation, Scale, Point, Polar, \
     StandardColors, GetFont, PathText, SimpleText, const, UnionRects, \
     Bezier, Line

from Sketch.warn import INTERNAL, USER, warn_tb

from Sketch.load import GenericLoader, EmptyCompositeError

try:
    from xml.sax import saxlib, saxexts
except ImportError:
    warn_tb(USER, "The Python xml package has to be installed for the svg "
            "import filter.\n"
            "See the README for more information.")
    raise


# Conversion factors from the SVG/CSS absolute length units to points (the
# document unit).  px is treated as one point.
factors = {'pt': 1.0, 'px': 1.0, 'pc': 12.0, 'in': 72.0,
           'cm': 72.0 / 2.54, 'mm': 72.0 / 25.4}

# Multiply an angle given in degrees by this to get radians.
degrees = pi / 180.0

def length(str):
    """Convert an SVG length string to a float.

    Surrounding whitespace is ignored.  A two letter unit suffix listed in
    factors converts the value to points, a trailing '%' multiplies the
    number by 0.01, and a plain number is returned unchanged.

    Raises ValueError if what remains is not a number (including the empty
    string).
    """
    text = str.strip()
    factor = factors.get(text[-2:])
    if factor is not None:
        text = text[:-2]
    elif text[-1:] == '%':
        # slicing instead of text[-1] so that '' gives ValueError below
        # rather than IndexError
        text = text[:-1]
        factor = 0.01
    else:
        factor = 1.0
    return float(text) * factor



def _css_rgb_component(text):
    """Convert one rgb() argument (0-255 or a percentage) to 0.0 .. 1.0."""
    text = strip(text)
    if text[-1:] == '%':
        value = float(text[:-1]) / 100.0
    else:
        value = float(text) / 255.0
    # CSS clips out-of-range values
    return min(max(value, 0.0), 1.0)


def csscolor(str):
    """Convert a CSS color specification to a Sketch color.

    Accepted forms:
      '#rrggbb'       hexadecimal components,
      '#rgb'          shorthand where every digit is doubled ('#fff' is
                      white, i.e. each digit d means d / 15),
      'rgb(r, g, b)'  integer (0-255) or percentage components,
      a keyword from namedcolors, compared case-insensitively.
    Anything else, including malformed hex or rgb() values, yields
    StandardColors.black.
    """
    text = strip(str)
    if text[:1] == '#':
        digits = text[1:]
        try:
            if len(digits) == 6:
                r = atoi(digits[0:2], 16) / 255.0
                g = atoi(digits[2:4], 16) / 255.0
                b = atoi(digits[4:6], 16) / 255.0
            elif len(digits) == 3:
                # '#abc' means '#aabbcc': d * 17 / 255 == d / 15
                r = atoi(digits[0], 16) / 15.0
                g = atoi(digits[1], 16) / 15.0
                b = atoi(digits[2], 16) / 15.0
            else:
                return StandardColors.black
        except ValueError:
            return StandardColors.black
        return CreateRGBColor(r, g, b)
    if lower(text[:4]) == 'rgb(' and text[-1:] == ')':
        parts = split(text[4:-1], ',')
        if len(parts) != 3:
            return StandardColors.black
        try:
            r, g, b = map(_css_rgb_component, parts)
        except ValueError:
            return StandardColors.black
        return CreateRGBColor(r, g, b)
    # namedcolors is a module level table defined after this function; it
    # is only consulted at call time, never while the table is being built
    # (that only uses '#rrggbb' values).
    key = lower(text)
    if namedcolors.has_key(key):
        return namedcolors[key]
    return StandardColors.black


# The sixteen HTML 4 color keywords understood by csscolor.
namedcolors = {"black": csscolor("#000000"),
               "silver": csscolor("#c0c0c0"),
               "gray": csscolor("#808080"),
               "white": csscolor("#FFFFFF"),
               "maroon": csscolor("#800000"),
               "red": csscolor("#FF0000"),
               "purple": csscolor("#800080"),
               "fuchsia": csscolor("#FF00FF"),
               "green": csscolor("#008000"),
               "lime": csscolor("#00FF00"),
               "olive": csscolor("#808000"),
               "yellow": csscolor("#FFFF00"),
               "navy": csscolor("#000080"),
               "blue": csscolor("#0000FF"),
               "teal": csscolor("#008080"),
               "aqua": csscolor("#00FFFF")}

# stroke-linejoin and stroke-linecap values -> Sketch constants
join = {'miter': const.JoinMiter,
        'round': const.JoinRound,
        'bevel': const.JoinBevel}
cap = {'butt': const.CapButt,
       'round': const.CapRound,
       'square': const.CapProjecting}

# Translation table turning commas into spaces, so that comma separated
# number lists can simply be split on whitespace.
commatospace = string.maketrans(',', ' ')

# A path data command letter followed by its argument text (numbers,
# signs, exponents and whitespace).  e and E are not command letters
# because they occur in numbers.
rx_command = re.compile(r'[a-df-zA-DF-Z]([-+0-9.eE\s]*)')
# One transform function: group 1 is its name, group 2 the text between
# the parentheses.  Commas and whitespace separating functions are skipped.
# A single character class is used for the arguments instead of nested
# repetitions, which backtracked exponentially on malformed input.
rx_trafo = re.compile(r'[\s,]*([a-zA-Z]+)\s*\(([-+0-9.eE,\s]*)\)')


class SVGHandler(saxlib.HandlerBase):

    dispatch_start = {'svg': 'initsvg',
                      'g': 'begin_group',
                      'circle': 'circle',
                      'ellipse': 'ellipse',
                      'rect': 'rect',
                      'polyline': 'polyline',
                      'polygon': 'polygon',
                      'path':   'begin_path',
                      'data':   'data',
                      'use':   'use',
                      'defs':   'begin_defs',
                      }
    dispatch_end = {'g': 'end_group',
                    'path': 'end_path',
                    'defs': 'end_defs',
                    }
    
    def __init__(self, loader):
        self.loader = loader
        self.trafo = self.basetrafo = Trafo()
        self.state_stack = ()
        self.style = loader.style.Copy()
        self.style.line_pattern = EmptyPattern
        self.style.fill_pattern = EmptyPattern
        #SolidPattern(StandardColors.black)
        self.named_objects = {}
        self.in_defs = 0
        self.paths = None
        self.path = None
        self.depth = 0
        self.indent = '    '

    def _print(self, *args):
        return
        if args:
            print self.depth * self.indent + args[0],
        for s in args[1:]:
            print s,
        print

    def parse_transform(self, trafo_string):
        trafo = self.trafo
        #print trafo
        while trafo_string:
            #print trafo_string
            match = rx_trafo.match(trafo_string)
            if match:
                function = match.group(1)
                args = string.translate(match.group(2), commatospace)
                args = map(float, split(args))
                trafo_string = trafo_string[match.end(0):]
                if function == 'matrix':
                    trafo = trafo(apply(Trafo, tuple(args)))
                elif function == 'scale':
                    trafo = trafo(Scale(args[0]))
                elif function == 'translate':
                    dx, dy = args
                    trafo = trafo(Translation(dx, dy))
                elif function == 'rotate':
                    trafo = trafo(Rotation(args[0] * degrees))
                elif function == 'skewX':
                    trafo = trafo(Trafo(1, 0, tan(args[0] * degrees), 1, 0, 0))
                elif function == 'skewY':
                    trafo = trafo(Trafo(1, tan(args[0] * degrees), 0, 1, 0, 0))
            else:
                trafo_string = ''
        #print trafo
        self.trafo = trafo

    def startElement(self, name, attrs):
        self._print('(', name)
        for key, value in attrs.items():
            self._print('  -', key, `value`)
        self.depth = self.depth + 1
        self.push_state()
        if attrs.has_key('transform'):
            self.parse_transform(attrs['transform'])
        method = self.dispatch_start.get(name)
        if method is not None:
            getattr(self, method)(attrs)
        
    def endElement(self, name):
        self.depth = self.depth - 1
        self._print(')', name)
        method = self.dispatch_end.get(name)
        if method is not None:
            getattr(self, method)()
        self.pop_state()

    def error(self, exception):
        print 'error', exception

    def fatalError(self, exception):
        print 'fatalError', exception

    def warning(self, exception):
        print 'warning', exception

    def initsvg(self, attrs):
        width = length(attrs['width'])
        height = length(attrs['height'])
        self._print('initsvg', width, height)
        self.trafo = Trafo(1, 0, 0, -1, 0, height)
        self.basetrafo = self.trafo

    def parse_style(self, style):
        #print 'parse_style'
        parts = map(strip, split(style, ';'))
        #print parts
        for part in parts:
            key, val = map(strip, split(part, ':', 1))
            self._print('style', key, val)
            if key == 'fill':
                if val == 'none':
                    self.style.fill_pattern = EmptyPattern
                else:
                    color = csscolor(val)
                    self._print('fill', color)
                    self.style.fill_pattern = SolidPattern(color)
            elif key == 'stroke':
                if val == 'none':
                    self.style.line_pattern = EmptyPattern
                else:
                    color = csscolor(val)
                    self._print('stroke', color)
                    self.style.line_pattern = SolidPattern(color)
            elif key == 'stroke-width':
                width = length(val)
                self._print('width', width)
                self.style.line_width = width
            elif key == 'stroke-linejoin':
                self.style.line_join = join[val]
            elif key == 'stroke-linecap':
                self.style.line_cap = cap[val]

    def push_state(self):
        self.state_stack = self.style, self.trafo, self.state_stack
        self.style = self.style.Copy()
        
    def pop_state(self):
        self.style, self.trafo, self.state_stack = self.state_stack

    def point(self, x, y, relative = 0):
        x = strip(x)
        y = strip(y)
        
        xunit = x[-2:]
        factor = factors.get(xunit)
        if factor is not None:
            x = x[:-2]
        elif x[-1] == '%':
            # XXX this is wrong
            x = x[:-1]
            xunit = '%'
            factor = 1
        else:
            xunit = ''
            factor = 1.0
        x = float(x) * factor
        
        yunit = y[-2:]
        factor = factors.get(yunit)
        if factor is not None:
            y = y[:-2]
        elif y[-1] == '%':
            y = y[:-1]
            yunit = '%'
            factor = 0.01
        else:
            yunit = ''
            factor = 1.0
        y = float(y) * factor

        if xunit:
            if yunit:
                if relative:
                    p = self.basetrafo.DTransform(x, y)
                else:
                    p = self.basetrafo(x, y)
            else:
                # XXX ugly special case 
                pass
        else:
            if yunit:
                # XXX ugly special case 
                pass
            else:
                if relative:
                    p = self.trafo.DTransform(x, y)
                else:
                    p = self.trafo(x, y)
        return p

    def circle(self, attrs):
        if self.in_defs:
            id = attrs.get('id', '')
            if id:
                self.named_objects[id] = ('object', 'circle', attrs)
            return
        if attrs.has_key('cx'):
            x = attrs['cx']
        else:
            x = '0'
        if attrs.has_key('cy'):
            y = attrs['cy']
        else:
            y = '0'
        x, y = self.point(x, y)
        r = self.point(attrs['r'], '0', relative = 1).x
        t = Trafo(r, 0, 0, r, x, y)
        self._print('circle', t)
        style = attrs.get('style', '')
        if style:
            self.parse_style(style)
        self.loader.style = self.style
        apply(self.loader.ellipse, t.coeff())
            

    def ellipse(self, attrs):
        if self.in_defs:
            id = attrs.get('id', '')
            if id:
                self.named_objects[id] = ('object', 'ellipse', attrs)
            return
        if attrs.has_key('cx'):
            x = attrs['cx']
        else:
            x = '0'
        if attrs.has_key('cy'):
            y = attrs['cy']
        else:
            y = '0'
        x, y = self.point(x, y)
        rx, ry = self.point(attrs['rx'], attrs['ry'], relative = 1)
        t = Trafo(rx, 0, 0, ry, x, y)
        self._print('ellipse', t)
        style = attrs.get('style', '')
        if style:
            self.parse_style(style)
        self.loader.style = self.style
        apply(self.loader.ellipse, t.coeff())

    def rect(self, attrs):
        #print 'rect', attrs.map
        if self.in_defs:
            id = attrs.get('id', '')
            if id:
                self.named_objects[id] = ('object', 'rect', attrs)
            return
        if attrs.has_key('x'):
            x = attrs['x']
        else:
            x = '0'
        if attrs.has_key('y'):
            y = attrs['y']
        else:
            y = '0'
        x, y = self.point(x, y)
        wx, wy = self.point(attrs['width'], "0", relative = 1)
        hx, hy = self.point("0", attrs['height'], relative = 1)
        t = Trafo(wx, wy, hx, hy, x, y)
        self._print('rect', t)
        style = attrs.get('style', '')
        if style:
            self.parse_style(style)
        self.loader.style = self.style
        apply(self.loader.rectangle, t.coeff())

    def polyline(self, attrs):
        if self.in_defs:
            id = attrs.get('id', '')
            if id:
                self.named_objects[id] = ('object', 'polyline', attrs)
            return
        points = attrs['points']
        points = string.translate(points, commatospace)
        points = split(points)
        path = CreatePath()
        point = self.point
        for i in range(0, len(points), 2):
            path.AppendLine(point(points[i], points[i + 1]))
        style = attrs.get('style', '')
        if style:
            self.parse_style(style)
        self.loader.style = self.style
        self.loader.bezier(paths = (path,))

    def polygon(self, attrs):
        if self.in_defs:
            id = attrs.get('id', '')
            if id:
                self.named_objects[id] = ('object', 'polygon', attrs)
            return
        points = attrs['points']
        points = string.translate(points, commatospace)
        points = split(points)
        path = CreatePath()
        point = self.point
        for i in range(0, len(points), 2):
            path.AppendLine(point(points[i], points[i + 1]))
        path.AppendLine(path.Node(0))
        path.ClosePath()
        style = attrs.get('style', '')
        if style:
            self.parse_style(style)
        self.loader.style = self.style
        self.loader.bezier(paths = (path,))

    def parse_path(self, str):
        paths = self.paths
        path = self.path
        trafo = self.trafo
        str = strip(string.translate(str, commatospace))
        last_quad = None
        last_cmd = cmd = None
        f13 = 1.0 / 3.0; f23 = 2.0 / 3.0
        #print '*', str
        while 1:
            match = rx_command.match(str)
            #print match
            if match:
                last_cmd = cmd
                cmd = str[0]
                str = str[match.end():]
                #print '*', str
                points = match.group(1)
                #print '**', points
                if points:
                    points = map(float, split(points))
                #print cmd, points
                if cmd in 'mM':
                    path = CreatePath()
                    paths.append(path)
                    if cmd == 'M' or len(paths) == 1:
                        path.AppendLine(trafo(points[0], points[1]))
                    else:
                        p = trafo.DTransform(points[0], points[1])
                        path.AppendLine(paths[-2].Node(-1) + p)
                    if len(points) > 2:
                        if cmd == 'm':
                            for i in range(2, len(points), 2):
                                p = trafo.DTransform(points[i], points[i + 1])
                                path.AppendLine(path.Node(-1) + p)
                        else:
                            for i in range(2, len(points), 2):
                                path.AppendLine(trafo(points[i], points[i+1]))
                elif cmd == 'l':
                    for i in range(0, len(points), 2):
                        p = trafo.DTransform(points[i], points[i + 1])
                        path.AppendLine(path.Node(-1) + p)
                elif cmd == 'L':
                    for i in range(0, len(points), 2):
                        path.AppendLine(trafo(points[i], points[i+1]))
                elif cmd =='H':
                    for num in points:
                        path.AppendLine(Point(num, path.Node(-1).y))
                elif cmd =='h':
                    for num in points:
                        x, y = path.Node(-1)
                        dx, dy = trafo.DTransform(num, 0)
                        path.AppendLine(Point(x + dx, y + dy))
                elif cmd =='V':
                    for num in points:
                        path.AppendLine(Point(path.Node(-1).x, num))
                elif cmd =='v':
                    for num in points:
                        x, y = path.Node(-1)
                        dx, dy = trafo.DTransform(0, num)
                        path.AppendLine(Point(x + dx, y + dy))
                elif cmd == 'C':
                    if len(points) % 6 != 0:
                        self.loader.add_message("number of parameters of 'C'"\
                                                "must be multiple of 6")
                    else:
                        for i in range(0, len(points), 6):
                            p1 = trafo(points[i], points[i + 1])
                            p2 = trafo(points[i + 2], points[i + 3])
                            p3 = trafo(points[i + 4], points[i + 5])
                            path.AppendBezier(p1, p2, p3)
                elif cmd == 'c':
                    if len(points) % 6 != 0:
                        self.loader.add_message("number of parameters of 'c'"\
                                                "must be multiple of 6")
                    else:
                        for i in range(0, len(points), 6):
                            p = path.Node(-1)
                            p1 = p + trafo.DTransform(points[i], points[i + 1])
                            p2 = p + trafo.DTransform(points[i+2], points[i+3])
                            p3 = p + trafo.DTransform(points[i+4], points[i+5])
                            path.AppendBezier(p1, p2, p3)
                elif cmd == 'S':
                    if len(points) % 4 != 0:
                        self.loader.add_message("number of parameters of 'S'"\
                                                "must be multiple of 4")
                    else:
                        for i in range(0, len(points), 4):
                            type, controls, p, cont = path.Segment(-1)
                            if type == Bezier:
                                q = controls[1]
                            else:
                                q = p
                            p1 = 2 * p - q
                            p2 = trafo(points[i], points[i + 1])
                            p3 = trafo(points[i + 2], points[i + 3])
                            path.AppendBezier(p1, p2, p3)
                elif cmd == 's':
                    if len(points) % 4 != 0:
                        self.loader.add_message("number of parameters of 's'"\
                                                "must be multiple of 4")
                    else:
                        for i in range(0, len(points), 4):
                            type, controls, p, cont = path.Segment(-1)
                            if type == Bezier:
                                q = controls[1]
                            else:
                                q = p
                            p1 = 2 * p - q
                            p2 = p + trafo.DTransform(points[i], points[i + 1])
                            p3 = p + trafo.DTransform(points[i+2], points[i+3])
                            path.AppendBezier(p1, p2, p3)
                elif cmd == 'Q':
                    if len(points) % 4 != 0:
                        self.loader.add_message("number of parameters of 'Q'"\
                                                "must be multiple of 4")
                    else:
                        for i in range(0, len(points), 4):
                            q = trafo(points[i], points[i + 1])
                            p3 = trafo(points[i + 2], points[i + 3])
                            p1 = f13 * path.Node(-1) + f23 * q
                            p2 = f13 * p3 + f23 * q
                            path.AppendBezier(p1, p2, p3)
                            last_quad = q
                elif cmd == 'q':
                    if len(points) % 4 != 0:
                        self.loader.add_message("number of parameters of 'q'"\
                                                "must be multiple of 4")
                    else:
                        for i in range(0, len(points), 4):
                            p = path.Node(-1)
                            q = p + trafo.DTransform(points[i], points[i + 1])
                            p3 = p + trafo.DTransform(points[i+2], points[i+3])
                            p1 = f13 * p + f23 * q
                            p2 = f13 * p3 + f23 * q
                            path.AppendBezier(p1, p2, p3)
                            last_quad = q
                elif cmd == 'T':
                    if len(points) % 2 != 0:
                        self.loader.add_message("number of parameters of 'T'"\
                                                "must be multiple of 4")
                    else:
                        if last_cmd not in 'QqTt' or last_quad is None:
                            last_quad = path.Node(-1)
                        for i in range(0, len(points), 2):
                            p = path.Node(-1)
                            q = 2 * p - last_quad
                            p3 = trafo(points[i], points[i + 1])
                            p1 = f13 * p + f23 * q
                            p2 = f13 * p3 + f23 * q
                            path.AppendBezier(p1, p2, p3)
                            last_quad = q
                elif cmd == 't':
                    if len(points) % 2 != 0:
                        self.loader.add_message("number of parameters of 't'"\
                                                "must be multiple of 4")
                    else:
                        if last_cmd not in 'QqTt' or last_quad is None:
                            last_quad = path.Node(-1)
                        for i in range(0, len(points), 2):
                            p = path.Node(-1)
                            q = 2 * p - last_quad
                            p3 = p + trafo.DTransform(points[i], points[i + 1])
                            p1 = f13 * p + f23 * q
                            p2 = f13 * p3 + f23 * q
                            path.AppendBezier(p1, p2, p3)
                            last_quad = q

                elif cmd == 'z':
                    path.AppendLine(path.Node(0))
                    path.ClosePath()
            else:
                break
        self.path = path

    def begin_path(self, attrs):
        self.paths = []
        self.path = None
        self.parse_path(attrs['d'])
        style = attrs.get('style', '')
        if style:
            self.parse_style(style)
        self.loader.style = self.style
        
    def end_path(self):
        self.loader.bezier(paths = tuple(self.paths))
        self.paths = None
        
    def data(self, attrs):
        pass

    def begin_group(self, attrs):
        style = attrs.get('style', '')
        if style:
            self.parse_style(style)
        self.loader.begin_group()
        
    def end_group(self):
        try:
            self.loader.end_group()
        except EmptyCompositeError:
            pass

    def use(self, attrs):
        #print 'use', attrs.map
        if attrs.has_key('xlink:href'):
            name = attrs['xlink:href']
        else:
            name = attrs.get('href', '<none>')
        if name:
            data = self.named_objects.get(name[1:])
            #print name, data
            if data[0] == 'object':
                if attrs.has_key('style'):
                    self.parse_style(attrs['style'])
                self.startElement(data[1], data[2])
                self.endElement(data[1])
            

    def begin_defs(self, attrs):
        self.in_defs = 1

    def end_defs(self):
        self.in_defs = 0

class SVGLoader(GenericLoader):

    """Sketch import filter for SVG files.

    The file is parsed with a SAX parser whose document handler, an
    SVGHandler, creates the objects of a new document through this loader.
    """

    format_name = format_name

    def __init__(self, file, filename, match):
        GenericLoader.__init__(self, file, filename, match)

    def __del__(self):
        # NOTE(review): kept from the original.  An empty __del__ keeps
        # reference cycles containing the loader from being collected in
        # Python 2; check whether it is needed.
        pass

    def Load(self):
        """Parse the SVG file and return the loaded document.

        Any exception is reported with warn_tb(INTERNAL) and then re-raised.
        """
        try:
            self.document()
            self.layer()

            parser = saxexts.make_parser()
            parser.setDocumentHandler(SVGHandler(self))
            try:
                self.file.seek(0)
                stream = self.file
            except (AttributeError, IOError, OSError):
                # The stream cannot be rewound.  Presumably StringDecode
                # supplies the already read text self.match.string ahead of
                # the rest of the stream.
                stream = streamfilter.StringDecode(self.match.string,
                                                   self.file)
            parser.parseFile(stream)

            self.end_all()
            self.object.load_Completed()
            return self.object
        except:
            # top level boundary: log with traceback, then propagate
            warn_tb(INTERNAL)
            raise
