changeset 291:edf5cc1ffd26

Provide an optional "strict" keyword flag to all YAML load functions to detect and prevent duplicate keys within a single YAML document
author Franz Glasner <f.glasner@feldmann-mg.com>
date Wed, 10 Feb 2021 14:47:41 +0100
parents aec97edf7945
children 6a044778371a
files CHANGES.txt configmix/yaml.py
diffstat 2 files changed, 143 insertions(+), 87 deletions(-) [+]
line wrap: on
line diff
--- a/CHANGES.txt	Wed Feb 10 13:43:29 2021 +0100
+++ b/CHANGES.txt	Wed Feb 10 14:47:41 2021 +0100
@@ -16,6 +16,12 @@
    :version: 0.13.dev1
    :released: n/a
 
+   .. change::
+      :tags: feature
+
+      All YAML load functions got a new optional keyword `strict` to detect
+      and prevent duplicate keys within a single YAML document.
+
 .. changelog::
    :version: 0.12
    :released: 2020-12-07
--- a/configmix/yaml.py	Wed Feb 10 13:43:29 2021 +0100
+++ b/configmix/yaml.py	Wed Feb 10 14:47:41 2021 +0100
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # :-
-# :Copyright: (c) 2015-2020, Franz Glasner. All rights reserved.
+# :Copyright: (c) 2015-2021, Franz Glasner. All rights reserved.
 # :License:   3-clause BSD. See LICENSE.txt for details.
 # :-
 """Simple wrapper for :mod:`yaml` to support all-unicode strings when
@@ -42,58 +42,66 @@
 
     """
 
+    def __init__(self, *args, **kwds):
+        strict = kwds.pop("strict", False)
+        self.__allow_duplicate_keys = not strict
+        yaml.Loader.__init__(self, *args, **kwds)
+
     def construct_yaml_str(self, node):
         return self.construct_scalar(node)
 
-    if OrderedDict:
+    #
+    # From https://pypi.python.org/pypi/yamlordereddictloader/0.1.1
+    # (MIT License)
+    #
 
-        #
-        # From https://pypi.python.org/pypi/yamlordereddictloader/0.1.1
-        # (MIT License)
-        #
+    def construct_yaml_map(self, node):
+        data = DictImpl()
+        yield data
+        value = self.construct_mapping(node)
+        data.update(value)
 
-        def construct_yaml_map(self, node):
-            data = OrderedDict()
-            yield data
-            value = self.construct_mapping(node)
-            data.update(value)
+    def construct_mapping(self, node, deep=False):
+        if isinstance(node, yaml.MappingNode):
+            self.flatten_mapping(node)
+        else:
+            raise yaml.constructor.ConstructorError(
+                None,
+                None,
+                'expected a mapping node, but found %s' % node.id,
+                node.start_mark)
 
-        def construct_mapping(self, node, deep=False):
-            if isinstance(node, yaml.MappingNode):
-                self.flatten_mapping(node)
-            else:
+        mapping = DictImpl()
+        for key_node, value_node in node.value:
+            key = self.construct_object(key_node, deep=deep)
+            try:
+                hash(key)
+            except TypeError as err:
                 raise yaml.constructor.ConstructorError(
-                    None,
-                    None,
-                    'expected a mapping node, but found %s' % node.id,
-                    node.start_mark)
-
-            mapping = OrderedDict()
-            for key_node, value_node in node.value:
-                key = self.construct_object(key_node, deep=deep)
-                try:
-                    hash(key)
-                except TypeError as err:
-                    raise yaml.constructor.ConstructorError(
-                        'while constructing a mapping', node.start_mark,
-                        'found unacceptable key (%s)' % (err,
-                                                         key_node.start_mark)
-                    )
-                value = self.construct_object(value_node, deep=deep)
-                mapping[key] = value
-            return mapping
+                    'while constructing a mapping', node.start_mark,
+                    'found unacceptable key (%s)' % (err,
+                                                     key_node.start_mark)
+                )
+            value = self.construct_object(value_node, deep=deep)
+            if not self.__allow_duplicate_keys and key in mapping:
+                raise yaml.constructor.ConstructorError(
+                    'while constructing a mapping', node.start_mark,
+                    'found duplicate key %r (%s)' % (key,
+                                                     key_node.start_mark)
+                )
+            mapping[key] = value
+        return mapping
 
 
 ConfigLoader.add_constructor(
     u("tag:yaml.org,2002:str"),
     ConfigLoader.construct_yaml_str)
-if OrderedDict:
-    ConfigLoader.add_constructor(
-        u("tag:yaml.org,2002:map"),
-        ConfigLoader.construct_yaml_map)
-    ConfigLoader.add_constructor(
-        u("tag:yaml.org,2002:omap"),
-        ConfigLoader.construct_yaml_map)
+ConfigLoader.add_constructor(
+    u("tag:yaml.org,2002:map"),
+    ConfigLoader.construct_yaml_map)
+ConfigLoader.add_constructor(
+    u("tag:yaml.org,2002:omap"),
+    ConfigLoader.construct_yaml_map)
 
 
 class ConfigSafeLoader(yaml.SafeLoader):
@@ -108,65 +116,93 @@
 
     """
 
+    def __init__(self, *args, **kwds):
+        strict = kwds.pop("strict", False)
+        self.__allow_duplicate_keys = not strict
+        yaml.SafeLoader.__init__(self, *args, **kwds)
+
     def construct_yaml_str(self, node):
         return self.construct_scalar(node)
 
-    if OrderedDict:
+    #
+    # From https://pypi.python.org/pypi/yamlordereddictloader/0.1.1
+    # (MIT License)
+    #
 
-        #
-        # From https://pypi.python.org/pypi/yamlordereddictloader/0.1.1
-        # (MIT License)
-        #
+    def construct_yaml_map(self, node):
+        data = DictImpl()
+        yield data
+        value = self.construct_mapping(node)
+        data.update(value)
 
-        def construct_yaml_map(self, node):
-            data = OrderedDict()
-            yield data
-            value = self.construct_mapping(node)
-            data.update(value)
+    def construct_mapping(self, node, deep=False):
+        if isinstance(node, yaml.MappingNode):
+            self.flatten_mapping(node)
+        else:
+            raise yaml.constructor.ConstructorError(
+                None,
+                None,
+                'expected a mapping node, but found %s' % node.id,
+                node.start_mark)
 
-        def construct_mapping(self, node, deep=False):
-            if isinstance(node, yaml.MappingNode):
-                self.flatten_mapping(node)
-            else:
+        mapping = DictImpl()
+        for key_node, value_node in node.value:
+            key = self.construct_object(key_node, deep=deep)
+            try:
+                hash(key)
+            except TypeError as err:
                 raise yaml.constructor.ConstructorError(
-                    None,
-                    None,
-                    'expected a mapping node, but found %s' % node.id,
-                    node.start_mark)
-
-            mapping = OrderedDict()
-            for key_node, value_node in node.value:
-                key = self.construct_object(key_node, deep=deep)
-                try:
-                    hash(key)
-                except TypeError as err:
-                    raise yaml.constructor.ConstructorError(
-                        'while constructing a mapping', node.start_mark,
-                        'found unacceptable key (%s)' % (err,
-                                                         key_node.start_mark)
-                    )
-                value = self.construct_object(value_node, deep=deep)
-                mapping[key] = value
-            return mapping
+                    'while constructing a mapping', node.start_mark,
+                    'found unacceptable key (%s)' % (err,
+                                                     key_node.start_mark)
+                )
+            value = self.construct_object(value_node, deep=deep)
+            if not self.__allow_duplicate_keys and key in mapping:
+                raise yaml.constructor.ConstructorError(
+                    'while constructing a mapping', node.start_mark,
+                    'found duplicate key %r (%s)' % (key,
+                                                     key_node.start_mark)
+                )
+            mapping[key] = value
+        return mapping
 
 
 ConfigSafeLoader.add_constructor(
     u("tag:yaml.org,2002:str"),
     ConfigSafeLoader.construct_yaml_str)
-if OrderedDict:
-    ConfigSafeLoader.add_constructor(
-        u("tag:yaml.org,2002:map"),
-        ConfigSafeLoader.construct_yaml_map)
-    ConfigSafeLoader.add_constructor(
-        u("tag:yaml.org,2002:omap"),
-        ConfigSafeLoader.construct_yaml_map)
+ConfigSafeLoader.add_constructor(
+    u("tag:yaml.org,2002:map"),
+    ConfigSafeLoader.construct_yaml_map)
+ConfigSafeLoader.add_constructor(
+    u("tag:yaml.org,2002:omap"),
+    ConfigSafeLoader.construct_yaml_map)
 
 
-def load(stream, Loader=ConfigLoader):
+def config_loader_factory(strict=False):
+    def _real_factory(*args, **kwds):
+        kwds["strict"] = strict
+        return ConfigLoader(*args, **kwds)
+    return _real_factory
+
+
+def config_safe_loader_factory(strict=False):
+    def _real_factory(*args, **kwds):
+        kwds["strict"] = strict
+        return ConfigSafeLoader(*args, **kwds)
+    return _real_factory
+
+
+def load(stream, Loader=None, strict=False):
     """Parse the given `stream` and return a Python object constructed
     from for the first document in the stream.
 
+    If `strict` is `True` then duplicate mapping keys within a YAML
+    document are detected and prevented. If a `Loader` is given then
+    `strict` does not apply.
+
     """
+    if Loader is None:
+        Loader = config_loader_factory(strict=strict)
     data = yaml.load(stream, Loader)
     # Map an empty document to an empty dict
     if data is None:
@@ -176,11 +212,17 @@
     return data
 
 
-def load_all(stream, Loader=ConfigLoader):
+def load_all(stream, Loader=None, strict=False):
     """Parse the given `stream` and return a sequence of Python objects
     corresponding to the documents in the `stream`.
 
+    If `strict` is `True` then duplicate mapping keys within a YAML
+    document are detected and prevented. If a `Loader` is given then
+    `strict` does not apply.
+
     """
+    if Loader is None:
+        Loader = config_loader_factory(strict=strict)
     data_all = yaml.load_all(stream, Loader)
     rdata = []
     for data in data_all:
@@ -193,15 +235,19 @@
     return rdata
 
 
-def safe_load(stream):
+def safe_load(stream, strict=False):
     """Parse the given `stream` and return a Python object constructed
     from for the first document in the stream.
 
     Recognizes only standard YAML tags and cannot construct an
     arbitrary Python object.
 
+    If `strict` is `True` then duplicate mapping keys within a YAML document
+    are detected and prevented.
+
     """
-    data = yaml.load(stream, Loader=ConfigSafeLoader)
+    data = yaml.load(stream,
+                     Loader=config_safe_loader_factory(strict=strict))
     # Map an empty document to an empty dict
     if data is None:
         return DictImpl()
@@ -210,14 +256,18 @@
     return data
 
 
-def safe_load_all(stream):
+def safe_load_all(stream, strict=False):
     """Return the list of all decoded YAML documents in the file `stream`.
 
     Recognizes only standard YAML tags and cannot construct an
     arbitrary Python object.
 
+    If `strict` is `True` then duplicate mapping keys within a YAML document
+    are detected and prevented.
+
     """
-    data_all = yaml.load_all(stream, Loader=ConfigSafeLoader)
+    data_all = yaml.load_all(stream,
+                             Loader=config_safe_loader_factory(strict=strict))
     rdata = []
     for data in data_all:
         if data is None: