changeset 185:3f01add9d84f

Saving and amending the timestamps databases is fully working
author Franz Glasner <hg@dom66.de>
date Sun, 09 Sep 2018 17:54:37 +0200
parents 0161a1e2ce12
children aa0a269494fc
files extensions/timestamps.py
diffstat 1 files changed, 72 insertions(+), 41 deletions(-) [+]
line wrap: on
line diff
--- a/extensions/timestamps.py	Sun Sep 09 16:37:29 2018 +0200
+++ b/extensions/timestamps.py	Sun Sep 09 17:54:37 2018 +0200
@@ -111,6 +111,7 @@
         ("s", "save", None, _("save modification times")),
         ("r", "restore", None, _("restore modification times")),
         ("", "tsconfig", "", _("use an alternate configuration file"), _("TSCONFIG")),
+        ("", "amend", None, _("amend the current database file instead of generating a fresh one")),
     ],
     _("hg timestamps [OPTION]..."))
 def timestamps(ui, repo, **opts):
@@ -127,30 +128,38 @@
     if not ctx:
         raise error.Abort(_("no Mercurial working directory"))
     if opts.get("save"):
-        save_timestamps(ui, repo, ctx, tsconfig=opts.get("tsconfig"))
+        save_timestamps(ui,
+                        repo,
+                        ctx,
+                        tsconfig=opts.get("tsconfig"),
+                        amend=opts.get("amend"))
     elif opts.get("restore"):
         restore_timestamps(ui, repo, ctx, tsconfig=opts.get("tsconfig"))
     else:
         raise error.Abort(_("must give a command: --save or --restore"))
 
 
-def save_timestamps(ui, repo, ctx,
-                    tsconfig=None):
+def save_timestamps(ui,
+                    repo,
+                    ctx,
+                    tsconfig=None,
+                    amend=False):
     if not repo.local():
         raise error.Abort(_("repository is not local"))
     matcher = gen_matcher(repo, ctx, tsconfig=tsconfig)
-    with io.open(repo.wjoin(TIMESTAMPS_DATABASE), "wb") as db:
-        with FloatTimesInStat():
-            dbwriter = db_writer(db)
-            dbwriter.send(None)   # prime the coroutine
-            dbwriter.send(("version=1",))
-            dbwriter.send(("encoding=binary",))
-            for fn in ctx:
-                if matcher(fn):
-                    st = os.lstat(repo.wjoin(fn))
-                    print (fn, st, st.st_mtime, to_isoformat(st.st_mtime))
-                    dbwriter.send((fn, to_isoformat(st.st_mtime)))
-            dbwriter.close()
+    if amend:
+        ts = Timestamps.from_filename(repo.wjoin(TIMESTAMPS_DATABASE))
+    else:
+        ts = Timestamps()
+        ts.version = 1
+        ts.encoding = "binary"
+    with FloatTimesInStat():
+        for fn in ctx:
+            if matcher(fn):
+                st = os.lstat(repo.wjoin(fn))
+                ts[fn] = to_isoformat(st.st_mtime)
+        with io.open(repo.wjoin(TIMESTAMPS_DATABASE), "wb") as f:
+            ts.write(f)
 
 
 def restore_timestamps(ui, repo, ctx,
@@ -159,9 +168,7 @@
         raise error.Abort(_("repository is not local"))
     cmdutil.bailifchanged(repo)
     matcher = gen_matcher(repo, ctx, tsconfig=tsconfig)
-    ts = Timestamps()
-    with io.open(repo.wjoin(TIMESTAMPS_DATABASE), "rb") as db:
-        ts.read(db)
+    ts = Timestamps.from_filename(repo.wjoin(TIMESTAMPS_DATABASE))
 
 
 def gen_matcher(repo, ctx, tsconfig=None):
@@ -351,6 +358,27 @@
         return dt.isoformat()
 
 
+def dt_from_isoformat(v):
+    """Parse an full ISO timestamp string into a :class:`datetime.datetime`.
+
+    """
+    mo = TIMESTAMP_FORMAT.search(v)
+    if mo:
+        dtparts = [mo.group("year"),
+                   mo.group("month"),
+                   mo.group("day"),
+                   mo.group("hour"),
+                   mo.group("minute"),
+                   mo.group("second")]
+        try:
+            dtparts.append(mo.group("ms"))
+        except LookupError:
+            pass   # no milliseconds
+        return datetime.datetime(*[int(d, 10) for d in dtparts])
+    else:
+        raise ValueError("invalid timestamp format in line %d" % lineno)
+
+
 class FloatTimesInStat(object):
     """Context manager to ensure that stat returns float values.
 
@@ -386,9 +414,18 @@
     """
 
     def __init__(self):
-        self._d = None
-        self._version = None
-        self._encoding = None
+        self._init()
+
+    def _init(self):
+        self._d = collections.OrderedDict()
+        self._version = self._encoding = None
+
+    @classmethod
+    def from_filename(cls_, filename):
+        with io.open(filename, "rb") as f:
+            ts = cls_()
+            ts.read(f)
+            return ts
 
     @property
     def version(self):
@@ -412,7 +449,7 @@
         :param db: an opened binary file ready for reading
 
         """
-        self._d = collections.OrderedDict()
+        self._init()
         lineno = datano = 0
         for record in db_reader(db):
             lineno += 1
@@ -423,24 +460,8 @@
                 if k.startswith("/"):
                     raise ValueError(
                         "invalid absolute path in line %d" % lineno)
-                mo = TIMESTAMP_FORMAT.search(v)
-                if mo:
-                    dtparts = [mo.group("year"),
-                               mo.group("month"),
-                               mo.group("day"),
-                               mo.group("hour"),
-                               mo.group("minute"),
-                               mo.group("second")]
-                    try:
-                        dtparts.append(mo.group("ms"))
-                    except LookupError:
-                        pass   # no milliseconds
-                    self._d[k] = datetime.datetime(
-                    *[int(d, 10) for d in dtparts])
-                    datano += 1
-                else:
-                    raise ValueError("invalid timestamp format in line %d"
-                                     % lineno)
+                self._d[k] = v
+                datano += 1
             elif not record:
                 self._d["/-%d/" % lineno] = None
             elif len(record) == 1:
@@ -475,6 +496,7 @@
     def write(self, db):
         """Write the internal representation into the file `db`"""
 
+        assert self._d is not None
         assert self._version == 1
 
         dbwriter = db_writer(db)
@@ -507,5 +529,14 @@
                 else:
                     raise ValueError("unknown key type in timestamps: %r" % k)
             else:
-                dbwriter.send((k, v.isoformat() + "Z"))
+                dbwriter.send((k, v))
         dbwriter.close()
+
+    def __setitem__(self, key, value):
+        self._d[key] = value
+
+    def __getitem__(self, key):
+        return self._d[key]
+
+    def __contains__(self, key):
+        return key in self._d