view cutils/shasum.py @ 72:ae2df602beb4

Make shasum.py and dos2unix sub-modules to the new "cutils" package
author Franz Glasner <fzglas.hg@dom66.de>
date Sat, 26 Feb 2022 19:20:20 +0100
parents shasum.py@29fb33aa639a
children c3268f4e752f
line wrap: on
line source

r"""
:Author:    Franz Glasner
:Copyright: (c) 2020-2022 Franz Glasner.
            All rights reserved.
:License:   BSD 3-Clause "New" or "Revised" License.
            See :ref:`LICENSE <license>` for details.
            If you cannot find LICENSE see
            <https://opensource.org/licenses/BSD-3-Clause>
:ID:        @(#) $HGid$

"""

from __future__ import print_function, absolute_import

from . import (__version__, __revision__, __date__)

import argparse
import base64
import binascii
import errno
import hashlib
try:
    from hmac import compare_digest
except ImportError:
    compare_digest = None
import io
try:
    import mmap
except ImportError:
    mmap = None
import os
try:
    import pathlib
except ImportError:
    pathlib = None
import re
import stat
import sys


# True when running under Python 2 (major version below 3).
PY2 = sys.version_info[0] < 3

# Types that compute_digest_file() treats as a filesystem path (anything
# else is handed to the OS-level calls as an already opened descriptor).
if PY2:
    PATH_TYPES = (unicode, str)    # noqa: F821 (undefined name 'unicode')
else:
    if pathlib:
        PATH_TYPES = (str, bytes, pathlib.Path)
    else:
        # `pathlib` could not be imported: only plain string types count
        PATH_TYPES = (str, bytes)

# Number of bytes requested per read() call when hashing without mmap (1 MiB).
CHUNK_SIZE = 1024*1024
# Maximum size of a single mmap window when hashing regular files (64 MiB).
MAP_CHUNK_SIZE = 64*1024*1024


def main(argv=None):
    """Command line entry point: parse `argv` and run :func:`shasum`.

    :param argv: the command line arguments without the program name;
                 passed unchanged to :meth:`argparse.ArgumentParser.parse_args`
    :return: the result of :func:`shasum` for the parsed options
    :raises SystemExit: with code 78 if text mode is requested and with
                        code 64 if both `--check` and `--checklist` are
                        given; argparse itself also exits on usage errors
                        and for `--version`

    Arguments prefixed with ``@`` are read from the named file
    (`fromfile_prefix_chars`).

    """
    aparser = argparse.ArgumentParser(
        description="Python implementation of shasum",
        fromfile_prefix_chars='@')
    aparser.add_argument(
        "--algorithm", "-a", action="store", type=argv2algo,
        help="1 (default), 224, 256, 384, 512, 3-224, 3-256, 3-384, 3-512, blake2b, blake2s, md5")
    aparser.add_argument(
        "--base64", action="store_true",
        help="Output checksums in base64 notation, not hexadecimal (OpenBSD).")
    aparser.add_argument(
        "--binary", "-b", action="store_false", dest="text_mode", default=False,
        help="Read in binary mode (default)")
    aparser.add_argument(
        "--bsd", "-B", action="store_true", dest="bsd", default=False,
        help="Write BSD style output. This is also the default output format of :command:`openssl dgst`.")
    aparser.add_argument(
        "--check", "-c", action="store_true",
        help="""Read digests from FILEs and check them.
If this option is specified, the FILE options become checklists. Each
checklist should contain hash results in a supported format, which will
be verified against the specified paths. Output consists of the digest
used, the file name, and an OK, FAILED, or MISSING for the result of
the comparison. This will validate any of the supported checksums.
If no file is given, stdin is used.""")
    aparser.add_argument(
        "--checklist", "-C", metavar="CHECKLIST",
        help="""Compare the checksum of each FILE against the checksums in
the CHECKLIST. Any specified FILE that is not listed in the CHECKLIST will
generate an error.""")

    # --bsd, --reverse and --tag all write to the same destination `bsd`
    aparser.add_argument(
        "--reverse", "-r", action="store_false", dest="bsd", default=False,
        help="Explicitely select normal coreutils style output (to be option compatible with BSD style commands and :command:`openssl dgst -r`)")
    aparser.add_argument(
        "--tag", action="store_true", dest="bsd", default=False,
        help="Alias for the `--bsd' option (to be compatible with :command:`b2sum`)")
    # --binary and --text share the destination `text_mode`
    aparser.add_argument(
        "--text", "-t", action="store_true", dest="text_mode", default=False,
        help="Read in text mode (not supported)")
    aparser.add_argument(
        "--version", "-v", action="version", version="%s (rv:%s)" % (__version__, __revision__))
    aparser.add_argument(
        "files", nargs="*", metavar="FILE")

    opts = aparser.parse_args(args=argv)

    if opts.text_mode:
        print("ERROR: text mode not supported", file=sys.stderr)
        sys.exit(78)   # :manpage:`sysexits(3)`  EX_CONFIG

    if opts.check and opts.checklist:
        print("ERROR: only one of --check or --checklist allowed",
              file=sys.stderr)
        sys.exit(64)   # :manpage:`sysexits(3)`  EX_USAGE

    # No -a/--algorithm given: fall back to SHA1
    if not opts.algorithm:
        opts.algorithm = argv2algo("1")

    # The command line has no option for an output destination; the
    # worker functions fall back to sys.stdout when `dest` is None.
    opts.dest = None

    return shasum(opts)


def gen_opts(files=None, algorithm="SHA1", bsd=False, text_mode=False,
             checklist=False, check=False, dest=None, base64=False):
    """Build an options object for :func:`shasum` without going through
    the command line parser.

    :param files: the files to be processed; `None` (the default) yields
                  a fresh empty list
    :param str algorithm: the algorithm tag as understood by
                          :func:`algotag2algotype` (e.g. ``SHA1``)
    :param bool bsd: write BSD style output
    :param bool text_mode: must be false: text mode is not supported
    :param checklist: the name of a checklist file or a false value
    :param bool check: treat `files` as checklists to be verified
    :param dest: the output destination; if `None` the worker functions
                 use :data:`sys.stdout`
    :param bool base64: write digests base64 encoded instead of hexadecimal
    :return: the options with `algorithm` converted to a tuple
             ``(digest_factory, tag)``
    :rtype: argparse.Namespace
    :raises ValueError: if `text_mode` is true, if both `checklist` and
                        `check` are given or if `algorithm` is unknown

    """
    if text_mode:
        raise ValueError("text mode not supported")
    if checklist and check:
        raise ValueError("only one of `checklist' or `check' is allowed")
    if files is None:
        # A new list per call: a shared default list would leak changes
        # made to `opts.files` into all later calls.
        files = []
    opts = argparse.Namespace(files=files,
                              algorithm=(algotag2algotype(algorithm),
                                         algorithm),
                              bsd=bsd,
                              checklist=checklist,
                              check=check,
                              text_mode=False,
                              dest=dest,
                              base64=base64)
    return opts


def shasum(opts):
    """Run the operation selected by `opts` and return its exit code.

    `opts.check` selects the verification of checklist files,
    `opts.checklist` the verification of files against one checklist;
    otherwise digests are generated.

    """
    if opts.check:
        handler = verify_digests_from_files
    elif opts.checklist:
        handler = verify_digests_with_checklist
    else:
        handler = generate_digests
    return handler(opts)


def generate_digests(opts):
    """Compute the digests of all files in `opts.files` (or of stdin) and
    write them to the output destination.

    :param opts: the options as built by :func:`main` or :func:`gen_opts`;
                 `opts.algorithm` is a tuple ``(digest_factory, tag)``
    :return: always 0
    :rtype: int

    If no file or only ``-`` is given the data is read from stdin.
    The output goes to `opts.dest` or -- if that is not set -- to
    :data:`sys.stdout`; this holds for the stdin case too.

    """
    if opts.bsd:
        out = out_bsd
    else:
        out = out_std
    # One destination for both input variants
    dest = opts.dest or sys.stdout
    if not opts.files or (len(opts.files) == 1 and opts.files[0] == '-'):
        if PY2:
            if sys.platform == "win32":
                # Python 2 on Windows: stdin must be switched to binary
                # mode explicitly
                import os, msvcrt   # noqa: E401
                msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
            source = sys.stdin
        else:
            # the underlying bytes stream of the text-mode stdin
            source = sys.stdin.buffer
        out(dest,
            compute_digest_stream(opts.algorithm[0], source),
            None,
            opts.algorithm[1],
            True,
            opts.base64)
    else:
        for fn in opts.files:
            out(dest,
                compute_digest_file(opts.algorithm[0], fn),
                fn,
                opts.algorithm[1],
                True,
                opts.base64)
    return 0


def compare_digests_equal(given_digest, expected_digest, algo):
    """Compare a newly computed binary digest `given_digest` with a digest
    string (hex or base64) in `expected_digest`.

    :param bytes given_digest: the computed digest in binary form
    :param expected_digest: digest (as bytes) or hexlified or base64 encoded
                            digest (as str or ASCII bytes)
    :type expected_digest: str or bytes or bytearray
    :param algo: The algorithm (factory)
    :return: `True` if the digests are equal, `False` if not
    :rtype: bool

    An `expected_digest` with exactly twice the digest size is taken as
    hexadecimal, every other length as base64. A digest that cannot be
    decoded yields `False`.

    """
    digest_size = algo().digest_size
    if isinstance(expected_digest, (bytes, bytearray)) \
       and len(expected_digest) == digest_size:
        # already a raw binary digest of the right size
        exd = expected_digest
    else:
        if isinstance(expected_digest, (bytes, bytearray)) \
           and not isinstance(expected_digest, str):
            # An encoded digest given as bytes (Python 3) or bytearray:
            # the str patterns below cannot be applied to it, so decode.
            try:
                expected_digest = bytes(expected_digest).decode("ascii")
            except UnicodeDecodeError:
                return False
        # Python 2 decoders raise TypeError, Python 3 raises
        # binascii.Error (a ValueError there)
        decode_errors = (TypeError, ValueError, binascii.Error)
        if len(expected_digest) == digest_size * 2:
            # hex
            if not re.search(r"\A[a-fA-F0-9]+\Z", expected_digest):
                return False
            try:
                exd = binascii.unhexlify(expected_digest)
            except decode_errors:
                return False
        else:
            # base64
            if not re.search(
                    r"\A(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{2}==)?\Z",
                    expected_digest):
                return False
            try:
                exd = base64.b64decode(expected_digest)
            except decode_errors:
                return False
    if compare_digest:
        # constant-time comparison if available
        return compare_digest(given_digest, exd)
    else:
        return given_digest == exd


def verify_digests_with_checklist(opts):
    """Check stdin or every file in `opts.files` against the digests
    listed in the checklist file `opts.checklist`.

    :param opts: the options as built by :func:`main` or :func:`gen_opts`
    :return: 0 if every input has a matching digest in the checklist,
             1 otherwise
    :rtype: int

    For every input one result line (``OK``, ``FAILED`` or ``MISSING``
    if the checklist has no entry for the input) is printed to
    `opts.dest` or :data:`sys.stdout`.

    """
    report = opts.dest or sys.stdout
    status = 0
    from_stdin = not opts.files or (
        len(opts.files) == 1 and opts.files[0] == '-')

    if from_stdin:
        if PY2:
            if sys.platform == "win32":
                # Python 2 on Windows: force binary mode for stdin
                import os, msvcrt   # noqa: E401
                msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
            stream = sys.stdin
        else:
            stream = sys.stdin.buffer
        entry = get_parsed_digest_line_from_checklist(
            opts.checklist, opts, None)
        if entry is None:
            status = 1
            print("-: MISSING", file=report)
            return status
        tag, algo, _unused_name, listed_digest = entry
        actual_digest = compute_digest_stream(algo, stream)
        if compare_digests_equal(actual_digest, listed_digest, algo):
            verdict = "OK"
        else:
            verdict = "FAILED"
            status = 1
        print("{}: {}: {}".format(tag, "-", verdict), file=report)
        return status

    for fn in opts.files:
        entry = get_parsed_digest_line_from_checklist(
            opts.checklist, opts, fn)
        if entry is None:
            print("{}: MISSING".format(fn), file=report)
            status = 1
            continue
        tag, algo, _unused_name, listed_digest = entry
        actual_digest = compute_digest_file(algo, fn)
        if compare_digests_equal(actual_digest, listed_digest, algo):
            verdict = "OK"
        else:
            status = 1
            verdict = "FAILED"
        print("{}: {}: {}".format(tag, fn, verdict), file=report)
    return status


def verify_digests_from_files(opts):
    """Read digest lines from the checklists in `opts.files` (or from
    stdin) and verify the files named in these lines.

    :param opts: the options as built by :func:`main` or :func:`gen_opts`
    :return: 0 if all listed files verify successfully, 1 otherwise
    :rtype: int
    :raises ValueError: if a non-blank line cannot be parsed

    For every digest line a result line is printed to `opts.dest` or
    :data:`sys.stdout`. Blank lines are skipped. Checklist files are read
    as UTF-8 text.

    """
    dest = opts.dest or sys.stdout
    exit_code = 0
    if not opts.files or (len(opts.files) == 1 and opts.files[0] == '-'):
        if not _verify_checklines(opts, sys.stdin, dest):
            exit_code = 1
    else:
        for checklist_fn in opts.files:
            with io.open(checklist_fn, "rt", encoding="utf-8") as checkfile:
                if not _verify_checklines(opts, checkfile, dest):
                    exit_code = 1
    return exit_code


def _verify_checklines(opts, lines, dest):
    """Verify every digest line in `lines`, report to `dest` and return
    whether all of them are ok."""
    all_ok = True
    for checkline in lines:
        # Iterated lines keep their line terminator: a blank line is
        # "\n" and not "", so test the stripped line.
        if not checkline.strip():
            continue
        r, fn, tag = handle_checkline(opts, checkline)
        print("{}: {}: {}".format(tag, fn, r.upper()), file=dest)
        if r != "ok":
            all_ok = False
    return all_ok


def handle_checkline(opts, line):
    """Verify the file that is named in the digest line `line`.

    :param opts: the options (used to parse coreutils style lines)
    :param str line: one line of a digest file
    :return: a tuple with static "ok", "missing", or "failed", the filename and
             the digest used
    :rtype: tuple(str, str, str)
    :raises ValueError: if `line` cannot be parsed

    An :exc:`EnvironmentError` while checking the file is reported
    as "missing".

    """
    parsed = parse_digest_line(opts, line)
    if not parsed:
        raise ValueError(
            "improperly formatted digest line: {}".format(line))
    tag, algo, fn, digest = parsed
    try:
        matches = compare_digests_equal(
            compute_digest_file(algo, fn), digest, algo)
    except EnvironmentError:
        return ("missing", fn, tag)
    return ("ok" if matches else "failed", fn, tag)


def get_parsed_digest_line_from_checklist(checklist, opts, filename):
    """Search the checklist file `checklist` for the first entry that
    belongs to `filename`.

    :param checklist: the name of the checklist file (read as UTF-8 text)
    :param opts: the options (used to parse coreutils style lines)
    :param filename: the filename to search for; `None` means stdin and
                     matches the entries ``-``, ``stdin`` and the empty name
    :return: the parsed line as returned by :func:`parse_digest_line` or
             `None` if no entry matches
    :raises ValueError: if a non-blank line cannot be parsed

    Filenames are compared after normalization (forward slashes, no
    leading ``./``). Blank lines are skipped.

    """
    if filename is None:
        filenames = ("-", "stdin", "", )
    else:
        filenames = (
            normalize_filename(filename, strip_leading_dot_slash=True),)
    with io.open(checklist, "rt", encoding="utf-8") as clf:
        for checkline in clf:
            # Iterated lines keep their line terminator: a blank line is
            # "\n" and not "", so test the stripped line.
            if not checkline.strip():
                continue
            parts = parse_digest_line(opts, checkline)
            if not parts:
                raise ValueError(
                    "improperly formatted digest line: {}".format(checkline))
            fn = normalize_filename(parts[2], strip_leading_dot_slash=True)
            if fn in filenames:
                return parts
    return None


def parse_digest_line(opts, line):
    """Split one `line` of a digest file into its components.

    :param opts: the options; `opts.algorithm` (a tuple of digest factory
                 and tag) supplies the algorithm for coreutils style lines
    :param str line: the line to parse
    :return: a tuple of the algorithm tag, the algorithm constructor,
             the filename and the digest string as found in the line;
             `None` if `line` matches neither format
    :rtype: tuple(str, obj, str, str) or None

    The BSD format (``TAG (filename) = digest``) is tried first, then the
    coreutils format (``digest filename``).

    """
    # BSD style: the line itself names the algorithm
    bsd_match = re.search(r"\A(\S+)\s*\((.*)\)\s*=\s*(.+)\n?\Z", line)
    if bsd_match:
        tag, filename, digest = bsd_match.groups()
        return (tag, algotag2algotype(tag), filename, digest)

    # coreutils style: the algorithm comes from the options
    std_match = re.search(r"([^\ ]+) [\*\ ]?(.+)\n?\Z", line)
    if not std_match:
        return None
    digest, filename = std_match.groups()
    return (opts.algorithm[1], opts.algorithm[0], filename, digest)


def argv2algo(s):
    """Translate an algorithm specifier from the command line into the
    internal digest specification.

    :param str s: the specifier from the command line
    :return: the internal digest specification
    :rtype: a tuple (digest_type_or_factory, name_in_output)
    :raises argparse.ArgumentTypeError: if `s` is not a known specifier

    The specifier is matched case-insensitively.

    """
    s = s.lower()
    # (accepted specifiers, attribute name in hashlib, tag for the output);
    # the hashlib attribute is looked up only for the selected algorithm
    known = (
        (("1", "sha1"), "sha1", "SHA1"),
        (("224", "sha224"), "sha224", "SHA224"),
        (("256", "sha256"), "sha256", "SHA256"),
        (("384", "sha384"), "sha384", "SHA384"),
        (("512", "sha512"), "sha512", "SHA512"),
        (("3-224", "sha3-224"), "sha3_224", "SHA3-224"),
        (("3-256", "sha3-256"), "sha3_256", "SHA3-256"),
        (("3-384", "sha3-384"), "sha3_384", "SHA3-384"),
        (("3-512", "sha3-512"), "sha3_512", "SHA3-512"),
        (("blake2b", "blake2b-512"), "blake2b", "BLAKE2b"),
        (("blake2s", "blake2s-256"), "blake2s", "BLAKE2s"),
        (("md5",), "md5", "MD5"),
    )
    for specifiers, attrname, tag in known:
        if s in specifiers:
            return (getattr(hashlib, attrname), tag)
    raise argparse.ArgumentTypeError(
        "`{}' is not a recognized algorithm".format(s))


def algotag2algotype(s):
    """Map the algorithm tag of a BSD-style digest line to the type/factory
    of the corresponding algorithm.

    :param str s: the tag (i.e. normalized name) of the algorithm
    :return: the digest type or factory for `s`
    :raises ValueError: if `s` is not a known tag

    The tag is matched case-sensitively.

    """
    # tag -> attribute name in hashlib; the attribute is looked up only
    # for the selected algorithm
    attrnames = {
        "SHA1": "sha1",
        "SHA224": "sha224",
        "SHA256": "sha256",
        "SHA384": "sha384",
        "SHA512": "sha512",
        "SHA3-224": "sha3_224",
        "SHA3-256": "sha3_256",
        "SHA3-384": "sha3_384",
        "SHA3-512": "sha3_512",
        "BLAKE2b": "blake2b",
        "BLAKE2s": "blake2s",
        "MD5": "md5",
    }
    attrname = attrnames.get(s)
    if attrname is None:
        raise ValueError("unknown algorithm: {}".format(s))
    return getattr(hashlib, attrname)


def out_bsd(dest, digest, filename, digestname, binary, use_base64):
    """Write one digest in BSD format, which is also the output format of
    :command:`openssl dgst` and :command:`b2sum --tag`.

    :param dest: the file object to print to
    :param bytes digest: the digest in binary form
    :param filename: the name of the hashed file; if `None` only the
                     encoded digest is written
    :param str digestname: the algorithm tag to write
    :param binary: not used by this format
    :param bool use_base64: encode the digest as base64 instead of hex

    """
    encode = base64.b64encode if use_base64 else binascii.hexlify
    encoded = encode(digest).decode("ascii")
    if filename is None:
        line = encoded
    else:
        line = "{} ({}) = {}".format(
            digestname, normalize_filename(filename), encoded)
    print(line, file=dest)


def out_std(dest, digest, filename, digestname, binary, use_base64):
    """Write one digest in coreutils format (:command:`shasum` et al.).

    :param dest: the file object to print to
    :param bytes digest: the digest in binary form
    :param filename: the name of the hashed file; `None` is written as ``-``
    :param digestname: not used by this format
    :param bool binary: mark the file name with ``*`` (binary mode) instead
                        of a blank
    :param bool use_base64: encode the digest as base64 instead of hex

    """
    encode = base64.b64encode if use_base64 else binascii.hexlify
    encoded = encode(digest).decode("ascii")
    if binary:
        mode_marker = '*'
    else:
        mode_marker = ' '
    if filename is None:
        shown_name = '-'
    else:
        shown_name = normalize_filename(filename)
    print("{} {}{}".format(encoded, mode_marker, shown_name), file=dest)


def compute_digest_file(hashobj, path, use_mmap=True):
    """Compute the digest of the contents of a file or file descriptor.

    :param hashobj: a :mod:`hashlib` compatible hash algorithm type or factory
    :param path: filename within the filesystem or a file descriptor opened in
                 binary mode (also a socket or pipe)
    :param bool use_mmap: use the :mod:`mmap` module if available
    :return: the digest in binary form
    :rtype: bytes

    If a file descriptor is given it must support :func:`os.read`.

    A file that is opened here (because `path` is one of `PATH_TYPES`)
    is closed again before returning; a descriptor given by the caller
    is left open.

    """
    h = hashobj()
    if isinstance(path, PATH_TYPES):
        # Optional flags are 0 on platforms where they do not exist
        flags = os.O_RDONLY | getattr(os, "O_BINARY", 0) \
            | getattr(os, "O_SEQUENTIAL", 0) | getattr(os, "O_NOCTTY", 0)
        fd = os.open(path, flags)
        own_fd = True
    else:
        fd = path
        own_fd = False
    try:
        try:
            st = os.fstat(fd)
        except TypeError:
            #
            # "fd" is most probably a Python socket object.
            # (a pipe typically supports fstat)
            #
            use_mmap = False
        else:
            # Only regular files are mapped. `filesize` is bound only in
            # that case; the mmap branch below is the only place that
            # reads it and it is skipped whenever `use_mmap` is false.
            if stat.S_ISREG(st[stat.ST_MODE]):
                filesize = st[stat.ST_SIZE]
            else:
                use_mmap = False
        if mmap is None or not use_mmap:
            # No mmap available or not to be used
            # -> use traditional low-level file IO
            while True:
                try:
                    buf = os.read(fd, CHUNK_SIZE)
                except OSError as e:
                    # EAGAIN/EWOULDBLOCK: no data yet, just try again
                    if e.errno not in (errno.EAGAIN, errno.EWOULDBLOCK):
                        raise
                else:
                    # an empty read signals end of file
                    if len(buf) == 0:
                        break
                    h.update(buf)
        else:
            #
            # Use mmap
            #
            # NOTE: On Windows mmapped files with length 0 are not supported.
            #       So ensure to not call mmap.mmap() if the file size is 0.
            #
            # `madvise` is None if this mmap implementation has no such method
            madvise = getattr(mmap.mmap, "madvise", None)
            if filesize < MAP_CHUNK_SIZE:
                mapsize = filesize
            else:
                mapsize = MAP_CHUNK_SIZE
            mapoffset = 0
            rest = filesize
            # Hash the file in windows of at most MAP_CHUNK_SIZE bytes;
            # an empty file never enters the loop.
            while rest > 0:
                m = mmap.mmap(fd,
                              mapsize,
                              access=mmap.ACCESS_READ,
                              offset=mapoffset)
                if madvise:
                    madvise(m, mmap.MADV_SEQUENTIAL)
                try:
                    h.update(m)
                finally:
                    m.close()
                rest -= mapsize
                mapoffset += mapsize
                # the last window covers just the remaining bytes
                if rest < mapsize:
                    mapsize = rest
    finally:
        if own_fd:
            os.close(fd)
    return h.digest()


def compute_digest_stream(hashobj, instream):
    """Compute the digest of all data that can be read from `instream`.

    :param hashobj: a :mod:`hashlib` compatible hash algorithm type or factory
    :param instream: a bytes input stream to read the data to be hashed from
    :return: the digest in binary form
    :rtype: bytes

    The stream is read in chunks of `CHUNK_SIZE` bytes until a read yields
    an empty result. A read that yields `None` or fails with `EAGAIN` or
    `EWOULDBLOCK` is retried.

    """
    hasher = hashobj()
    retry_errnos = (errno.EAGAIN, errno.EWOULDBLOCK)
    while True:
        try:
            chunk = instream.read(CHUNK_SIZE)
        except OSError as exc:
            if exc.errno in retry_errnos:
                # no data available yet: try again
                continue
            raise
        if chunk is None:
            # nothing was read this time: try again
            continue
        if len(chunk) == 0:
            # end of stream
            break
        hasher.update(chunk)
    return hasher.digest()


def normalize_filename(filename, strip_leading_dot_slash=False):
    """Return `filename` with all backslashes replaced by forward slashes.

    :param str filename: the filename to normalize
    :param bool strip_leading_dot_slash: also remove every leading ``./``
    :return: the normalized filename
    :rtype: str

    """
    normalized = filename.replace("\\", "/")
    if not strip_leading_dot_slash:
        return normalized
    while normalized[:2] == "./":
        normalized = normalized[2:]
    return normalized


# Run as a script: the return value of main() becomes the exit status.
if __name__ == "__main__":
    sys.exit(main())