changeset 54:6d2aaba7ac4d

Tree - Move serialization code into classes
author Bastian Blank <bblank@thinkmo.de>
date Mon, 20 Jul 2009 12:00:06 +0200
parents d071598a93ef
children 125ce968352d
files emeraldtree/tests/test_tree.py emeraldtree/tree.py
diffstat 2 files changed, 245 insertions(+), 303 deletions(-) [+]
line wrap: on
line diff
--- a/emeraldtree/tests/test_tree.py	Mon Jul 20 11:08:47 2009 +0200
+++ b/emeraldtree/tests/test_tree.py	Mon Jul 20 12:00:06 2009 +0200
@@ -1,11 +1,10 @@
 import py.test
 from emeraldtree.tree import *
 
-def serialize(elem, **options):
-    from cStringIO import StringIO
+def serialize(elem, namespaces={}):
+    from StringIO import StringIO
     file = StringIO()
-    tree = ElementTree(elem)
-    tree.write(file, **options)
+    XMLWriter(namespaces=namespaces).write(file.write, elem)
     return file.getvalue()
 
 def test_Element():
@@ -255,19 +254,20 @@
     assert elem[1].tag == 'c'
     assert elem[2] == 'd'
 
-def test_XMLParser_namespace():
+def test_XMLParser_namespace_1():
     elem = XML('<b xmlns="c" d="e"/>')
     assert isinstance(elem.tag, QName)
     assert elem.tag == QName('b', 'c')
     assert elem.attrib == {QName('d', None): 'e'}
     assert serialize(elem) == '<ns0:b d="e" xmlns:ns0="c" />'
-    assert serialize(elem, default_namespace='c') == '<b d="e" xmlns="c" />'
+    assert serialize(elem, namespaces={'c': ''}) == '<b d="e" xmlns="c" />'
 
+def test_XMLParser_namespace_2():
     elem = XML('<a:b xmlns:a="c" d="e" a:f="g"/>')
     assert isinstance(elem.tag, QName)
     assert elem.tag == QName('b', 'c')
     assert elem.attrib == {'d': 'e', QName('f', 'c'): 'g'}
     assert serialize(elem) == '<ns0:b d="e" ns0:f="g" xmlns:ns0="c" />'
-    assert serialize(elem, default_namespace='c') == '<b d="e" f="g" xmlns="c" />'
+    assert serialize(elem, namespaces={'c': ''}) == '<b d="e" f="g" xmlns="c" />'
 
 
--- a/emeraldtree/tree.py	Mon Jul 20 11:08:47 2009 +0200
+++ b/emeraldtree/tree.py	Mon Jul 20 12:00:06 2009 +0200
@@ -1,58 +1,3 @@
-#
-# ElementTree
-# $Id: ElementTree.py 3276 2007-09-12 06:52:30Z fredrik $
-#
-# light-weight XML support for Python 2.2 and later.
-#
-# history:
-# 2001-10-20 fl   created (from various sources)
-# 2001-11-01 fl   return root from parse method
-# 2002-02-16 fl   sort attributes in lexical order
-# 2002-04-06 fl   TreeBuilder refactoring, added PythonDoc markup
-# 2002-05-01 fl   finished TreeBuilder refactoring
-# 2002-07-14 fl   added basic namespace support to ElementTree.write
-# 2002-07-25 fl   added QName attribute support
-# 2002-10-20 fl   fixed encoding in write
-# 2002-11-24 fl   changed default encoding to ascii; fixed attribute encoding
-# 2002-11-27 fl   accept file objects or file names for parse/write
-# 2002-12-04 fl   moved XMLTreeBuilder back to this module
-# 2003-01-11 fl   fixed entity encoding glitch for us-ascii
-# 2003-02-13 fl   added XML literal factory
-# 2003-02-21 fl   added ProcessingInstruction/PI factory
-# 2003-05-11 fl   added tostring/fromstring helpers
-# 2003-05-26 fl   added ElementPath support
-# 2003-07-05 fl   added makeelement factory method
-# 2003-07-28 fl   added more well-known namespace prefixes
-# 2003-08-15 fl   fixed typo in ElementTree.findtext (Thomas Dartsch)
-# 2003-09-04 fl   fall back on emulator if ElementPath is not installed
-# 2003-10-31 fl   markup updates
-# 2003-11-15 fl   fixed nested namespace bug
-# 2004-03-28 fl   added XMLID helper
-# 2004-06-02 fl   added default support to findtext
-# 2004-06-08 fl   fixed encoding of non-ascii element/attribute names
-# 2004-08-23 fl   take advantage of post-2.1 expat features
-# 2004-09-03 fl   made Element class visible; removed factory
-# 2005-02-01 fl   added iterparse implementation
-# 2005-03-02 fl   fixed iterparse support for pre-2.2 versions
-# 2005-11-12 fl   added tostringlist/fromstringlist helpers
-# 2006-07-05 fl   merged in selected changes from the 1.3 sandbox
-# 2006-07-05 fl   removed support for 2.1 and earlier
-# 2007-06-21 fl   added deprecation/future warnings
-# 2007-08-25 fl   added doctype hook, added parser version attribute etc
-# 2007-08-26 fl   added new serializer code (better namespace handling, etc)
-# 2007-08-27 fl   warn for broken /tag searches on tree level
-# 2007-09-02 fl   added html/text methods to serializer (experimental)
-# 2007-09-05 fl   added method argument to tostring/tostringlist
-# 2007-09-06 fl   improved error handling
-#
-# Copyright (c) 1999-2007 by Fredrik Lundh.  All rights reserved.
-#
-# fredrik@pythonware.com
-# http://www.pythonware.com
-#
-# --------------------------------------------------------------------
-# The ElementTree toolkit is
-#
 # Copyright (c) 1999-2007 by Fredrik Lundh
 #               2008 Bastian Blank <bblank@thinkmo.de>
 #
@@ -77,7 +22,6 @@
 # WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
 # ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
 # OF THIS SOFTWARE.
-# --------------------------------------------------------------------
 
 from __future__ import generators
 
@@ -96,7 +40,7 @@
     "tostring", "tostringlist",
     "TreeBuilder",
     "XML",
-    "XMLParser",
+    "XMLParser", "XMLWriter",
     ]
 
 ##
@@ -675,217 +619,29 @@
               xml_declaration=None,
               default_namespace=None,
               method=None,
-              namespaces=None):
+              namespaces={}):
         assert self._root is not None
         if not hasattr(file, "write"):
             file = open(file, "wb")
         write = file.write
-        if not method:
-            method = "xml"
-        elif method == 'html':
-            default_namespace = "http://www.w3.org/1999/xhtml"
         if not encoding:
             encoding = "us-ascii"
-        elif xml_declaration or (xml_declaration is None and
-                                 encoding not in ("utf-8", "us-ascii")):
-            write("<?xml version='1.0' encoding='%s'?>\n" % encoding)
-        if method == "text":
-            _serialize_text(write, self._root, encoding)
+        if not method or method == "xml":
+            Writer = XMLWriter
+        elif method == "html":
+            Writer = HTMLWriter
         else:
-            qnames, namespaces = _namespaces(
-                self._root, encoding, default_namespace, namespaces
-                )
-            if method == "xml":
-                _serialize_xml(
-                    write, self._root, encoding, qnames, namespaces
-                    )
-            elif method == "html":
-                _serialize_html(
-                    write, self._root, encoding, qnames, namespaces
-                    )
-            else:
-                raise ValueError("unknown method %r" % method)
+            Writer = TextWriter
+
+        if default_namespace:
+            namespaces = namespaces.copy()
+            namespaces[default_namespace] = ''
+
+        Writer(encoding, namespaces).write(write, self._root)
 
 # --------------------------------------------------------------------
 # serialization support
 
-def _namespaces(elem, encoding, default_namespace, namespaces):
-    # identify namespaces used in this tree
-
-    # maps qnames to *encoded* prefix:local names
-    qnames = {None: None}
-
-    # maps uri:s to prefixes
-    candidate_namespaces = _namespace_map.copy()
-    if namespaces:
-        candidate_namespaces.update(namespaces)
-    if default_namespace:
-        candidate_namespaces[default_namespace] = ""
-    used_namespaces = {}
-
-    def encode(text):
-        return text.encode(encoding)
-
-    def add_qname(qname):
-        if qname in qnames:
-            return
-
-        # calculate serialized qname representation
-        try:
-            if qname.uri is not None:
-                uri = qname.uri
-                prefix = used_namespaces.get(uri, None)
-                if prefix is None:
-                    prefix = candidate_namespaces.get(uri, None)
-                    if prefix is None:
-                        prefix = "ns%d" % len(used_namespaces)
-                    if prefix != "xml":
-                        used_namespaces[uri] = prefix
-                if prefix:
-                    qnames[qname] = encode("%s:%s" % (prefix, qname.name))
-                else:
-                    qnames[qname] = encode(qname.name)
-            else:
-                # XXX: What happens with undefined namespace?
-                qnames[qname] = encode(qname.name)
-        except TypeError:
-            _raise_serialization_error(qname)
-
-    # populate qname and namespaces table
-    if isinstance(elem, Element):
-        for elem in elem.iter():
-            if isinstance(elem, Element):
-                tag = elem.tag
-                if isinstance(tag, QName):
-                    add_qname(tag)
-                elif isinstance(tag, basestring):
-                    add_qname(QName(tag))
-                elif tag is not None:
-                    _raise_serialization_error(tag)
-
-                for key in elem.keys():
-                    if isinstance(key, QName):
-                        add_qname(key)
-                    elif isinstance(key, basestring):
-                        add_qname(QName(key))
-                    elif key is not None:
-                        _raise_serialization_error(key)
-
-    return qnames, used_namespaces
-
-def _serialize_xml(write, elem, encoding, qnames, namespaces):
-    if isinstance(elem, Element):
-        tag = qnames[elem.tag]
-
-        if tag is not None:
-            write("<" + tag)
-
-            if elem.attrib:
-                items = elem.attrib.items()
-                items.sort(key=lambda x: x[0])
-                for k, v in items:
-                    k = qnames[k]
-                    if isinstance(v, QName):
-                        v = qnames[v]
-                    else:
-                        v = _escape_attrib(v, encoding)
-                    write(' ' + k + '="' + v + '"')
-            if namespaces:
-                items = namespaces.items()
-                items.sort(key=lambda x: x[1]) # sort on prefix
-                for v, k in items:
-                    if k:
-                        k = ":" + k
-                    write(" xmlns%s=\"%s\"" % (
-                        k.encode(encoding),
-                        _escape_attrib(v, encoding)
-                        ))
-            if len(elem):
-                write(">")
-                for e in elem:
-                    _serialize_xml(write, e, encoding, qnames, None)
-                write("</" + tag + ">")
-            else:
-                write(" />")
-
-        else:
-            for e in elem:
-                _serialize_xml(write, e, encoding, qnames, None)
-
-    elif isinstance(elem, Comment):
-        write("<!--%s-->" % _escape_cdata(elem.text, encoding))
-
-    elif isinstance(elem, ProcessingInstruction):
-        text = _escape_cdata(elem.target, encoding)
-        if elem.text is not None:
-            text += ' ' + _escape_cdata(elem.text, encoding)
-        write("<?%s?>" % text)
-
-    else:
-        write(_escape_cdata(unicode(elem), encoding))
-
-HTML_EMPTY = set(("area", "base", "basefont", "br", "col", "frame", "hr",
-                  "img", "input", "isindex", "link", "meta" "param"))
-
-def _serialize_html(write, elem, encoding, qnames, namespaces):
-    if isinstance(elem, Element):
-        tag = qnames[elem.tag]
-
-        if tag is not None:
-            write("<" + tag)
-
-            if elem.attrib:
-                items = elem.attrib.items()
-                items.sort(key=lambda x: x[0])
-                for k, v in items:
-                    k = qnames[k]
-                    if isinstance(v, QName):
-                        v = qnames[v]
-                    else:
-                        v = _escape_attrib(v, encoding)
-                    # FIXME: handle boolean attributes
-                    write(' ' + k + '="' + v + '"')
-            if namespaces:
-                items = namespaces.items()
-                items.sort(key=lambda x: x[1]) # sort on prefix
-                for v, k in items:
-                    if k:
-                        k = ":" + k
-                    write(" xmlns%s=\"%s\"" % (
-                        k.encode(encoding),
-                        _escape_attrib(v, encoding)
-                        ))
-            write(">")
-
-            if tag.lower() in ('script', 'style'):
-                write(_encode(''.join(elem.itertext()), encoding))
-            else:
-                for e in elem:
-                    _serialize_html(write, e, encoding, qnames, None)
-
-            if tag not in HTML_EMPTY:
-                write("</" + tag + ">")
-
-        else:
-            for e in elem:
-                _serialize_html(write, e, encoding, qnames, None)
-
-    elif isinstance(elem, Comment):
-        write("<!--%s-->" % _escape_cdata(elem.text, encoding))
-
-    elif isinstance(elem, ProcessingInstruction):
-        text = _escape_cdata(elem.target, encoding)
-        if elem.text is not None:
-            text += ' ' + _escape_cdata(elem.text, encoding)
-        write("<?%s?>" % text)
-
-    else:
-        write(_escape_cdata(elem, encoding))
-
-def _serialize_text(write, elem, encoding):
-    for part in elem.itertext():
-        write(part.encode(encoding))
-
 ##
 # Registers a namespace prefix.  The registry is global, and any
 # existing mapping for either the given prefix or the namespace URI
@@ -924,45 +680,6 @@
         "cannot serialize %r (type %s)" % (text, type(text).__name__)
         )
 
-def _encode(text, encoding):
-    try:
-        return text.encode(encoding, "xmlcharrefreplace")
-    except (TypeError, AttributeError):
-        _raise_serialization_error(text)
-
-def _escape_cdata(text, encoding):
-    # escape character data
-    try:
-        # it's worth avoiding do-nothing calls for strings that are
-        # shorter than 500 character, or so.  assume that's, by far,
-        # the most common case in most applications.
-        if "&" in text:
-            text = text.replace("&", "&amp;")
-        if "<" in text:
-            text = text.replace("<", "&lt;")
-        if ">" in text:
-            text = text.replace(">", "&gt;")
-        return text.encode(encoding, "xmlcharrefreplace")
-    except (TypeError, AttributeError):
-        _raise_serialization_error(text)
-
-def _escape_attrib(text, encoding):
-    # escape attribute value
-    try:
-        if "&" in text:
-            text = text.replace("&", "&amp;")
-        if "<" in text:
-            text = text.replace("<", "&lt;")
-        if ">" in text:
-            text = text.replace(">", "&gt;")
-        if "\"" in text:
-            text = text.replace("\"", "&quot;")
-        if "\n" in text:
-            text = text.replace("\n", "&#10;")
-        return text.encode(encoding, "xmlcharrefreplace")
-    except (TypeError, AttributeError):
-        _raise_serialization_error(text)
-
 # --------------------------------------------------------------------
 
 ##
@@ -1434,3 +1151,228 @@
         del self.target, self._parser # get rid of circular references
         return tree
 
+class BaseWriter(object):
+    def __init__(self, encoding=None, namespaces={}):
+        self.encoding = encoding
+        self.namespaces = namespaces
+
+    def _encode(self, text):
+        if self.encoding:
+            return text.encode(self.encoding, "xmlcharrefreplace")
+        return text
+
+    def _escape_cdata(self, text):
+        # escape character data
+        # it's worth avoiding do-nothing calls for strings that are
+        # shorter than 500 character, or so.  assume that's, by far,
+        # the most common case in most applications.
+        if "&" in text:
+            text = text.replace("&", "&amp;")
+        if "<" in text:
+            text = text.replace("<", "&lt;")
+        if ">" in text:
+            text = text.replace(">", "&gt;")
+        return self._encode(text)
+
+    def _escape_attrib(self, text):
+        # escape attribute value
+        if "\"" in text:
+            text = text.replace("\"", "&quot;")
+        if "\n" in text:
+            text = text.replace("\n", "&#10;")
+        return self._escape_cdata(text)
+
+    def _namespaces(self, elem):
+        # identify namespaces used in this tree
+
+        # maps qnames to *encoded* prefix:local names
+        qnames = {None: None}
+
+        # maps uri:s to prefixes
+        candidate_namespaces = _namespace_map.copy()
+        candidate_namespaces = {}
+        candidate_namespaces.update(self.namespaces)
+        used_namespaces = {}
+
+        def add_qname(qname):
+            if qname in qnames:
+                return
+
+            # calculate serialized qname representation
+            try:
+                if qname.uri is not None:
+                    uri = qname.uri
+                    prefix = used_namespaces.get(uri, None)
+                    if prefix is None:
+                        prefix = candidate_namespaces.get(uri, None)
+                        if prefix is None:
+                            prefix = "ns%d" % len(used_namespaces)
+                        if prefix != "xml":
+                            used_namespaces[uri] = prefix
+                    if prefix:
+                        qnames[qname] = "%s:%s" % (prefix, qname.name)
+                    else:
+                        qnames[qname] = qname.name
+                else:
+                    # XXX: What happens with undefined namespace?
+                    qnames[qname] = qname.name
+            except TypeError:
+                _raise_serialization_error(qname)
+
+        # populate qname and namespaces table
+        if isinstance(elem, Element):
+            for elem in elem.iter():
+                if isinstance(elem, Element):
+                    tag = elem.tag
+                    if isinstance(tag, QName):
+                        add_qname(tag)
+                    elif isinstance(tag, basestring):
+                        add_qname(QName(tag))
+                    elif tag is not None:
+                        _raise_serialization_error(tag)
+
+                    for key in elem.keys():
+                        if isinstance(key, QName):
+                            add_qname(key)
+                        elif isinstance(key, basestring):
+                            add_qname(QName(key))
+                        elif key is not None:
+                            _raise_serialization_error(key)
+
+        return qnames, used_namespaces
+
+    def serialize_start(self, write):
+        pass
+
+    def write(self, write, element):
+        qnames, namespaces = self._namespaces(element)
+        self.serialize_start(write)
+        self.serialize(write, element, qnames, namespaces)
+
+
+class TextWriter(BaseWriter):
+    def serialize(self, write, elem, qnames=None, namespaces=None):
+        for part in elem.itertext():
+            write(self._encode(part))
+
+
+class XMLWriter(BaseWriter):
+    def serialize(self, write, elem, qnames, namespaces={}):
+        if isinstance(elem, Element):
+            tag = qnames[elem.tag]
+
+            if tag is not None:
+                write("<" + tag)
+
+                if elem.attrib:
+                    items = elem.attrib.items()
+                    items.sort(key=lambda x: x[0])
+                    for k, v in items:
+                        k = qnames[k]
+                        if isinstance(v, QName):
+                            v = qnames[v]
+                        else:
+                            v = self._escape_attrib(unicode(v))
+                        write(' ' + k + '="' + v + '"')
+                if namespaces:
+                    items = namespaces.items()
+                    items.sort(key=lambda x: x[1]) # sort on prefix
+                    for v, k in items:
+                        if k:
+                            k = ":" + k
+                        write(" xmlns%s=\"%s\"" % (
+                            self._encode(k),
+                            self._escape_attrib(v)
+                            ))
+                if len(elem):
+                    write(">")
+                    for e in elem:
+                        self.serialize(write, e, qnames)
+                    write("</" + tag + ">")
+                else:
+                    write(" />")
+
+            else:
+                for e in elem:
+                    self.serialize(write, e, encoding, qnames)
+
+        elif isinstance(elem, Comment):
+            write("<!--%s-->" % self._escape_cdata(elem.text))
+
+        elif isinstance(elem, ProcessingInstruction):
+            text = self._escape_cdata(elem.target)
+            if elem.text is not None:
+                text += ' ' + self._escape_cdata(elem.text)
+            write("<?%s?>" % text)
+
+        else:
+            write(self._escape_cdata(unicode(elem)))
+
+    def serialize_start(self, write):
+        if self.encoding and self.encoding not in ("utf-8", "us-ascii"):
+            write("<?xml version='1.0' encoding='%s'?>\n" % self.encoding)
+
+
+class HTMLWriter(BaseWriter):
+    empty_elements = frozenset(("area", "base", "basefont", "br", "col", "frame", "hr",
+                                "img", "input", "isindex", "link", "meta" "param"))
+
+    def __init__(self, encoding=None, namespaces={}):
+        namespaces["http://www.w3.org/1999/xhtml"] = ''
+        super(HTTPWriter, self).__init__(encoding, namespaces)
+
+    def serialize(self, write, elem, qnames, namespaces={}):
+        if isinstance(elem, Element):
+            tag = qnames[elem.tag]
+
+            if tag is not None:
+                write("<" + tag)
+
+                if elem.attrib:
+                    items = elem.attrib.items()
+                    items.sort(key=lambda x: x[0])
+                    for k, v in items:
+                        k = qnames[k]
+                        if isinstance(v, QName):
+                            v = qnames[v]
+                        else:
+                            v = self._escape_attrib(unicode(v))
+                        # FIXME: handle boolean attributes
+                        write(' ' + k + '="' + v + '"')
+                if namespaces:
+                    items = namespaces.items()
+                    items.sort(key=lambda x: x[1]) # sort on prefix
+                    for v, k in items:
+                        if k:
+                            k = ":" + k
+                        write(" xmlns%s=\"%s\"" % (
+                            self._encode(k),
+                            self._escape_attrib(v)
+                            ))
+                write(">")
+
+                if tag.lower() in ('script', 'style'):
+                    write(self._encode(''.join(elem.itertext())))
+                else:
+                    for e in elem:
+                        self.serialize(write, e, encoding, qnames)
+
+                if tag not in HTML_EMPTY:
+                    write("</" + tag + ">")
+
+            else:
+                for e in elem:
+                    self.serialize(write, e, encoding, qnames)
+
+        elif isinstance(elem, Comment):
+            write("<!--%s-->" % self._escape_cdata(elem.text))
+
+        elif isinstance(elem, ProcessingInstruction):
+            text = self._escape_cdata(elem.target)
+            if elem.text is not None:
+                text += ' ' + self._escape_cdata(elem.text)
+            write("<?%s?>" % text)
+
+        else:
+            write(self._escape_cdata(elem))
+