# -*- coding: utf-8 -*-
# :-
# :Copyright: (c) 2020-2025 Franz Glasner
# :License:   BSD-3-Clause
# :-
r"""Generate and verify checksums for directory trees.

"""

from __future__ import print_function, absolute_import


# NOTE(review): an empty ``__all__`` makes ``from <module> import *``
# export nothing; presumably callers import e.g. main() or gen_opts()
# explicitly -- confirm this is intended.
__all__ = []


import argparse
import base64
import binascii
import datetime
import os
import sys
import time

from . import (__version__, __revision__)
from . import util
from .util import cm
from .util import digest
from .util import walk


def main(argv=None):
    """Command line entry point: parse *argv* and run :func:`treesum`.

    :param argv: the argument list to parse; ``None`` lets argparse use
      ``sys.argv[1:]``.  Arguments may also be read from files given as
      ``@FILE`` (``fromfile_prefix_chars``).
    :returns: whatever :func:`treesum` returns

    """
    aparser = argparse.ArgumentParser(
        description="Generate checksums for directory trees",
        fromfile_prefix_chars='@')
    aparser.add_argument(
        "--algorithm", "-a", action="store", type=util.argv2algo,
        help="1 (aka sha1), 224, 256, 384, 512, "
             "3 (alias for sha3-512), 3-224, 3-256, 3-384, 3-512, "
             "blake2b, blake2b-256 (default), blake2s, "
             "blake2 (alias for blake2b), blake2-256 (alias for blake2b-256), "
             "md5")
    aparser.add_argument(
        "--append-output", action="store_true", dest="append_output",
        help="Append to the output file instead of overwriting it.")
    aparser.add_argument(
        "--base64", action="store_true",
        help="Output checksums in base64 notation, not hexadecimal (OpenBSD).")
    aparser.add_argument(
        "--logical", "-L", dest="logical", action="store_true", default=None,
        help="Follow symbolic links given on command line arguments."
             " Note that this is a different setting than following symbolic"
             " links to directories when traversing a directory tree.")
    aparser.add_argument(
        "--mmap", action="store_true", dest="mmap", default=None,
        help="Use mmap if available. Default is to determine automatically "
             "from the filesize.")
    aparser.add_argument(
        "--no-mmap", action="store_false", dest="mmap", default=None,
        help="Don't use mmap.")
    aparser.add_argument(
        "--output", "-o", action="store", metavar="OUTPUT",
        help="Put the checksum into given file. If not given or if it is given"
             " as `-' then stdout is used.")
    aparser.add_argument(
        "--physical", "-P", dest="logical", action="store_false", default=None,
        help="Do not follow symbolic links given on command line arguments."
             " This is the default.")
    aparser.add_argument(
        "--version", "-v", action="version",
        version="%s (rv:%s)" % (__version__, __revision__))
    aparser.add_argument(
        "directories", nargs="*", metavar="DIRECTORY")

    opts = aparser.parse_args(args=argv)

    # No --algorithm given: fall back to the documented default.
    if not opts.algorithm:
        opts.algorithm = util.argv2algo("blake2b-256")

    return treesum(opts)


def gen_opts(directories=None,
             algorithm="BLAKE2b-256",
             append_output=False,
             base64=False,
             logical=None,
             mmap=None,
             output=None):
    """Build an options namespace for :func:`treesum` without argv parsing.

    :param directories: iterable of directories to process; ``None`` or
      empty means "use the default" (decided by the consumer of *opts*)
    :param algorithm: algorithm tag; stored as the pair
      ``(util.algotag2algotype(algorithm), algorithm)``
    :param append_output: append to *output* instead of truncating it
    :param base64: emit digests in base64 instead of hexadecimal
    :param logical: follow a symlinked root directory
    :param mmap: ``True``/``False`` to force or forbid mmap, ``None`` for
      automatic selection
    :param output: output file name; ``None`` or ``"-"`` means stdout
    :returns: an :class:`argparse.Namespace` with the attributes above

    """
    # Always store a private list: a shared mutable default (or the
    # caller's own list) must not be affected if ``opts.directories`` is
    # later modified in place.
    directories = list(directories) if directories is not None else []
    opts = argparse.Namespace(directories=directories,
                              algorithm=(util.algotag2algotype(algorithm),
                                         algorithm),
                              append_output=append_output,
                              base64=base64,
                              logical=logical,
                              mmap=mmap,
                              output=output)
    return opts


def treesum(opts):
    """Run the tree checksum operation selected by *opts*.

    Only checksum generation is implemented so far, so this delegates to
    :func:`generate_treesum` and hands back its result unchanged.
    """
    # XXX TBD: also support opts.check and opts.checklist (as in shasum.py)
    result = generate_treesum(opts)
    return result


def generate_treesum(opts):
    """Write tree checksums for every directory in *opts*.

    :param opts: namespace as built by :func:`main` or :func:`gen_opts`;
      reads ``directories``, ``output``, ``append_output``, ``algorithm``,
      ``mmap``, ``base64`` and ``logical``

    When ``opts.directories`` is empty (or ``None``) the current directory
    ``"."`` is processed.  Output goes to stdout when ``opts.output`` is
    ``None`` or ``"-"``; otherwise to that file, which is opened in binary
    mode and either appended to or truncated depending on
    ``opts.append_output``.

    """
    # Use a local list instead of appending to opts.directories: that list
    # may be owned by the caller (or be None / a tuple).
    directories = opts.directories or ["."]

    if opts.output is None or opts.output == "-":
        # Prefer the binary buffer of stdout where it exists; otherwise use
        # sys.stdout itself.
        out_cm = cm.nullcontext(getattr(sys.stdout, "buffer", sys.stdout))
    else:
        out_cm = open(opts.output, "ab" if opts.append_output else "wb")

    with out_cm as outfp:
        for d in directories:
            generate_treesum_for_directory(
                outfp, d, opts.algorithm, opts.mmap, opts.base64, opts.logical)


def generate_treesum_for_directory(
        outfp, root, algorithm, use_mmap, use_base64, handle_root_logical):
    """Write the header and all checksum lines for the tree at *root*.

    Emitted in this order: a ``ROOT`` line, a ``FLAGS`` line (only when a
    non-default traversal flag is set), ``TIMESTAMP`` and ``ISOTIMESTAMP``
    lines, then checksum lines for the entries of the tree.  Each
    directory digest is computed over its entries (type tag, name and the
    entry's own digest), so it depends on the complete subtree.

    :param outfp: a *binary* file with a "write()" and a "flush()" method
    :param root: the path of the tree to process, as given by the caller
    :param algorithm: a pair: ``algorithm[0]()`` creates a fresh hash
      object and ``algorithm[1]`` is the algorithm tag written per line
    :param use_mmap: passed on to :func:`digest.compute_digest_file`
    :param use_base64: write digests in base64 instead of hexadecimal
    :param handle_root_logical: if true, *root* is handed to the walker
      even when it is a symbolic link; if false and *root* is a symbolic
      link, only the link target string is hashed and nothing is walked

    """
    # Header: the root as given (a line without a digest value)
    outfp.write(format_bsd_line("ROOT", None, root, False))
    outfp.flush()

    # Note given non-default flags that are relevant for directory traversal
    flags = []
    if handle_root_logical:
        flags.append("logical")
    if flags:
        outfp.write(format_bsd_line("FLAGS", ",".join(flags), None, False))

    # Write execution timestamps in POSIX epoch and ISO format
    ts = time.time()
    outfp.write(format_bsd_line("TIMESTAMP", ts, None, False))
    # NOTE: utcfromtimestamp() is deprecated since Python 3.12.  Its naive
    # result has no UTC offset in isoformat(); a timezone-aware replacement
    # would append "+00:00" and thus change the emitted line.
    ts = (datetime.datetime.utcfromtimestamp(ts)).isoformat("T")
    outfp.write(format_bsd_line("ISOTIMESTAMP", ts, None, False))

    # Digests of directories processed so far, keyed by their ``top``
    # path-component tuple (see the walk loop below).
    dir_digests = {}

    if not handle_root_logical and os.path.islink(root):
        # *root* is a symlink that must not be followed: hash the link
        # target string (not what it points to) and emit a single digest
        # for the pseudo entry "./@".  No traversal takes place.
        linktgt = util.fsencode(os.readlink(root))
        linkdgst = algorithm[0]()
        linkdgst.update(linktgt)
        dir_dgst = algorithm[0]()
        # Netstring-style "<len>:<data>," records: type tag "S" (symlink)
        # followed by the entry name "./@"
        dir_dgst.update(b"1:S,3:./@,")
        dir_dgst.update(linkdgst.digest())
        outfp.write(format_bsd_line(
            algorithm[1],
            dir_dgst.digest(),
            "./@",
            use_base64))
        outfp.flush()
        return

    # ``top`` appears to be a tuple of path components relative to *root*
    # (empty for *root* itself), judging from the "/".join(top) and
    # ``top + (name,)`` uses below.
    # NOTE(review): the dir_digests lookup for subdirectories only works if
    # walk.walk() yields children before their parent (bottom-up), and the
    # directory digest depends on the order in which entries are yielded
    # (presumably sorted) -- confirm in util.walk.
    for top, dirs, nondirs in walk.walk(root, follow_symlinks=False):
        dir_dgst = algorithm[0]()
        for dn in dirs:
            if dn.is_symlink:
                # Symlinked directory: its contents are not hashed here;
                # the link target string is hashed instead (tag "S").
                linktgt = util.fsencode(os.readlink(dn.path))
                linkdgst = algorithm[0]()
                linkdgst.update(linktgt)
                dir_dgst.update(b"1:S,%d:%s," % (len(dn.fsname), dn.fsname))
                dir_dgst.update(linkdgst.digest())
                opath = "/".join(top) + "/" + dn.name if top else dn.name
                # The emitted line carries the plain link-target digest
                # under the pseudo name "<path>/./@".
                outfp.write(
                    format_bsd_line(
                        algorithm[1],
                        linkdgst.digest(),
                        "%s/./@" % (opath,),
                        use_base64))
                outfp.flush()
            else:
                # fetch from dir_digests
                # (a KeyError here would mean the subdirectory was not
                # yielded before its parent)
                dgst = dir_digests[top + (dn.name,)]
                dir_dgst.update(b"1:d,%d:%s," % (len(dn.fsname), dn.fsname))
                dir_dgst.update(dgst)
        for fn in nondirs:
            # Regular entry (tag "f"): name record, then the file digest.
            # NOTE(review): nondirs may include symlinks to files; whether
            # compute_digest_file follows them is defined in util.digest.
            dir_dgst.update(b"1:f,%d:%s," % (len(fn.fsname), fn.fsname))
            dgst = digest.compute_digest_file(
                algorithm[0], fn.path, use_mmap=use_mmap)
            dir_dgst.update(dgst)
            opath = "/".join(top) + "/" + fn.name if top else fn.name
            outfp.write(format_bsd_line(
                algorithm[1], dgst, opath, use_base64))
            outfp.flush()
        # Directory line: path with a trailing "/" ("" for *root* itself).
        # The digest is remembered so that the parent can include it.
        opath = "/".join(top) + "/" if top else ""
        outfp.write(format_bsd_line(
            algorithm[1], dir_dgst.digest(), opath, use_base64))
        outfp.flush()
        dir_digests[top] = dir_dgst.digest()


def format_bsd_line(digestname, value, filename, use_base64):
    """Render one output line as bytes, terminated by ``os.linesep``.

    Produced forms:

    * ``TIMESTAMP = <value>`` (``value`` formatted with ``%d``)
    * ``ISOTIMESTAMP = <value>`` / ``FLAGS = <value>``
    * ``<digestname> (<filename>)`` when *value* is ``None``
    * ``<digestname> (<filename>) = <digest>`` otherwise, with the digest
      in hexadecimal or (if *use_base64*) base64 notation; *filename* is
      normalized with ``util.normalize_filename`` unless it is ``./@``

    Text arguments are encoded: *digestname* and header values as ASCII,
    *filename* via ``util.fsencode``.

    """
    if isinstance(os.linesep, bytes):
        eol = os.linesep
    else:
        eol = os.linesep.encode("utf-8")
    tag = (digestname if isinstance(digestname, bytes)
           else digestname.encode("ascii"))

    # Header lines "NAME = VALUE" never carry a filename.
    if tag == b"TIMESTAMP":
        assert filename is None
        return b"TIMESTAMP = %d%s" % (value, eol)
    if tag in (b"ISOTIMESTAMP", b"FLAGS"):
        assert filename is None
        raw = value if isinstance(value, bytes) else value.encode("ascii")
        return b"%s = %s%s" % (tag, raw, eol)

    # Everything else refers to a (pseudo) file name.
    assert filename is not None
    fname = (filename if isinstance(filename, bytes)
             else util.fsencode(filename))
    if value is None:
        return b"%s (%s)%s" % (tag, fname, eol)
    encoded = (base64.b64encode(value) if use_base64
               else binascii.hexlify(value))
    if fname != b"./@":
        fname = util.normalize_filename(fname, True)
    return b"%s (%s) = %s%s" % (tag, fname, encoded, eol)


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    sys.exit(exit_status)
