# view pyramid_upgrade.py @ 2:33ab2058c6d9 draft
#
# planemo upload for repository https://github.com/ohsu-comp-bio/ashlar commit 72a33d7c3ad18e717ec61c6b845099d69b9b7abd
# author goeckslab
# date Tue, 20 Sep 2022 17:33:50 +0000
# parents f183d9de4622
# children
# line wrap: on
# line source

import argparse
import dataclasses
import fractions
import io
import os
import re
import reprlib
import struct
import sys
import xml.etree.ElementTree
from typing import Any, List


# Maps each TIFF field datatype code to the struct format character used for
# ONE element of that type. RATIONAL and SRATIONAL map to the format of a
# single integer; TiffSurgeon.tag_value doubles the element count for those
# types because each value is stored as a pair of integers.
datatype_formats = {
    1: "B",  # BYTE
    2: "s",  # ASCII
    3: "H",  # SHORT
    4: "I",  # LONG
    5: "I",  # RATIONAL (pairs)
    6: "b",  # SBYTE
    7: "B",  # UNDEFINED
    8: "h",  # SSHORT
    9: "i",  # SLONG
    10: "i",  # SRATIONAL (pairs)
    11: "f",  # FLOAT
    12: "d",  # DOUBLE
    13: "I",  # IFD
    16: "Q",  # LONG8
    17: "q",  # SLONG8
    18: "Q",  # IFD8
}
# Datatype codes whose values are stored as (numerator, denominator) pairs.
rational_datatypes = {5, 10}


class TiffSurgeon:
    """Read, manipulate and write IFDs in BigTIFF files.

    The file is opened on construction and its header is validated right
    away; only BigTIFF (version 43 with 8-byte offsets) is accepted. Instances
    can be used as context managers so the file is closed on exit.

    Args:
        path: Path of the TIFF file to open.
        writeable: Open the file for in-place modification ("r+b") instead of
            read-only ("rb").
        encoding: Text encoding used to decode ASCII tag values on read and to
            encode str values on write. When unset, ASCII values are handled
            as raw bytes.

    Raises:
        FormatError: If the file is not a well-formed BigTIFF.

    """

    def __init__(self, path, *, writeable=False, encoding=None):
        self.path = path
        self.writeable = writeable
        self.encoding = encoding
        self.endian = ""
        self.ifds = None
        self.file = open(self.path, "r+b" if self.writeable else "rb")
        try:
            self._validate()
        except BaseException:
            # Don't leak the file handle when the header is rejected.
            self.file.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def _validate(self):
        """Parse the file header, setting endian and first_ifd_offset."""
        signature = self.read("2s")
        signature = signature.decode("ascii", errors="ignore")
        if signature == "II":
            self.endian = "<"
        elif signature == "MM":
            self.endian = ">"
        else:
            raise FormatError(f"Not a TIFF file (signature is '{signature}').")
        version = self.read("H")
        if version == 42:
            raise FormatError("Cannot process classic TIFF, only BigTIFF.")
        offset_size, reserved, first_ifd_offset = self.read("H H Q")
        if version != 43 or offset_size != 8 or reserved != 0:
            raise FormatError("Malformed TIFF, giving up!")
        self.first_ifd_offset = first_ifd_offset

    def read(self, fmt, *, file=None):
        """Read and unpack one struct of format `fmt` (without endian prefix).

        Reads from `file` if given, otherwise from the TIFF file at its
        current position. Returns a single value for one-item formats and a
        tuple otherwise.

        Raises:
            FormatError: If fewer bytes are available than `fmt` requires.

        """
        if file is None:
            file = self.file
        endian = self.endian or "="
        size = struct.calcsize(endian + fmt)
        raw = file.read(size)
        if len(raw) < size:
            # Report truncation as a format problem rather than letting a
            # bare struct.error escape from unpack.
            raise FormatError(
                f"Unexpected end of data (needed {size} bytes, got {len(raw)})."
            )
        return self.unpack(fmt, raw)

    def write(self, fmt, *values):
        """Pack `values` per `fmt` and write them at the current position.

        Raises:
            ValueError: If the file was opened read-only.

        """
        if not self.writeable:
            raise ValueError("File is opened as read-only.")
        raw = self.pack(fmt, *values)
        self.file.write(raw)

    def unpack(self, fmt, raw):
        """Unpack `raw` per `fmt` using the file's byte order.

        Returns the bare value when the format yields exactly one item,
        otherwise the tuple of values.

        """
        assert self.endian or re.match(r"\d+s", fmt), \
            "can't unpack non-string before endianness is detected"
        fmt = self.endian + fmt
        size = struct.calcsize(fmt)
        values = struct.unpack(fmt, raw[:size])
        if len(values) == 1:
            return values[0]
        else:
            return values

    def pack(self, fmt, *values):
        """Pack `values` per `fmt` using the file's byte order."""
        assert self.endian, "can't pack without endian set"
        fmt = self.endian + fmt
        raw = struct.pack(fmt, *values)
        return raw

    def read_ifds(self):
        """Read the whole IFD chain into self.ifds.

        Raises:
            FormatError: If an IFD is malformed or the chain revisits an
                offset (which would otherwise loop forever).

        """
        offset = self.first_ifd_offset
        seen = set()
        ifds = []
        while True:
            if offset in seen:
                raise FormatError(f"IFD chain loops back to offset {offset}.")
            seen.add(offset)
            ifd = self.read_ifd(offset)
            ifds.append(ifd)
            offset = ifd.offset_next
            if not offset:
                break
        self.ifds = ifds

    def read_ifd(self, offset):
        """Read the IFD stored at file position `offset` and return an Ifd."""
        self.file.seek(offset)
        num_tags = self.read("Q")
        # Slurp all 20-byte entries and the next-IFD offset before decoding
        # any tag, because decoding may seek elsewhere to fetch tag values.
        buf = io.BytesIO(self.file.read(num_tags * 20))
        offset_next = self.read("Q")
        try:
            tags = TagSet([self.read_tag(buf) for _ in range(num_tags)])
        except FormatError as e:
            raise FormatError(f"IFD at offset {offset}, {e}") from None
        ifd = Ifd(tags, offset, offset_next)
        return ifd

    def read_tag(self, buf):
        """Read one IFD entry from `buf` and return it as a decoded Tag."""
        tag = Tag(*self.read("H H Q 8s", file=buf))
        value, offset_range = self.tag_value(tag)
        tag = dataclasses.replace(tag, value=value, offset_range=offset_range)
        return tag

    def append_ifd_sequence(self, ifds):
        """Write list of IFDs as a chained sequence at the end of the file.

        Returns a list of new Ifd objects with updated offsets.

        """
        self.file.seek(0, os.SEEK_END)
        new_ifds = []
        last_index = len(ifds) - 1
        for index, ifd in enumerate(ifds):
            offset = self.file.tell()
            self.write("Q", len(ifd.tags))
            for tag in ifd.tags:
                self.write_tag(tag)
            # The next IFD starts right after the 8-byte "next offset" field
            # we are about to write; the final IFD terminates the chain.
            offset_next = self.file.tell() + 8 if index < last_index else 0
            self.write("Q", offset_next)
            new_ifd = dataclasses.replace(
                ifd, offset=offset, offset_next=offset_next
            )
            new_ifds.append(new_ifd)
        return new_ifds

    def append_tag_data(self, code, datatype, value):
        """Build new tag and write data to the end of the file if necessary.

        Returns a Tag object corresponding to the passed parameters. This
        function only writes any "overflow" data and not the IFD entry itself,
        so the returned Tag must still be written to an IFD.

        If the value is small enough to fit in the data field within an IFD, no
        data will actually be written to the file and the returned Tag object
        will have the value encoded in its data attribute. Otherwise the data
        will be appended to the file and the returned Tag's data attribute will
        encode the corresponding offset.

        `value` may be a str (requires `encoding` to be set), bytes, a single
        number, or a sequence of numbers. For rational datatypes each item
        must provide `as_integer_ratio()`.

        """
        fmt = datatype_formats[datatype]
        # FIXME Should we perform our own check that values match datatype?
        # struct.pack will do it but the exception won't be as understandable.
        original_value = value
        if isinstance(value, str):
            if not self.encoding:
                raise ValueError(
                    "ASCII tag values must be bytes if encoding is not set"
                )
            value = [value.encode(self.encoding) + b"\x00"]
            count = len(value[0])
        elif isinstance(value, bytes):
            value = [value + b"\x00"]
            count = len(value[0])
        else:
            try:
                len(value)
            except TypeError:
                value = [value]
            count = len(value)
        if datatype in rational_datatypes:
            # The tag's count is the number of rationals, but each one is
            # packed as two integers (numerator, denominator).
            value = [i for v in value for i in v.as_integer_ratio()]
            struct_count = count * 2
        else:
            struct_count = count
        byte_count = struct_count * struct.calcsize(fmt)
        if byte_count <= 8:
            data = self.pack(str(struct_count) + fmt, *value)
            data += bytes(8 - byte_count)
        else:
            self.file.seek(0, os.SEEK_END)
            data = self.pack("Q", self.file.tell())
            self.write(str(struct_count) + fmt, *value)
        # TODO Compute and set offset_range.
        tag = Tag(code, datatype, count, data, original_value)
        return tag

    def write_first_ifd_offset(self, offset):
        """Point the file header at the IFD located at `offset`."""
        self.file.seek(8)
        self.write("Q", offset)

    def write_tag(self, tag):
        """Write the 20-byte IFD entry for `tag` at the current position."""
        self.write("H H Q 8s", tag.code, tag.datatype, tag.count, tag.data)

    def tag_value(self, tag):
        """Return decoded tag data and the file offset range.

        Values of at most 8 bytes are decoded from the entry's inline data
        field and get an empty offset range; larger values are read from the
        file offset stored in that field. ASCII values are stripped of
        trailing NULs and decoded if `encoding` is set; rational values become
        Fraction objects (a bare Fraction when there is only one).

        Raises:
            FormatError: If the datatype is unknown, the data is truncated, or
                an ASCII value can't be decoded with the configured encoding.

        """
        try:
            fmt = datatype_formats[tag.datatype]
        except KeyError:
            raise FormatError(
                f"tag {tag.code}: unknown datatype {tag.datatype}"
            ) from None
        count = tag.count
        if tag.datatype in rational_datatypes:
            count *= 2
        byte_count = count * struct.calcsize(fmt)
        if byte_count <= 8:
            value = self.unpack(str(count) + fmt, tag.data)
            offset_range = range(0, 0)
        else:
            offset = self.unpack("Q", tag.data)
            self.file.seek(offset)
            value = self.read(str(count) + fmt)
            offset_range = range(offset, offset + byte_count)
        if tag.datatype == 2:
            value = value.rstrip(b"\x00")
            if self.encoding:
                try:
                    value = value.decode(self.encoding)
                except UnicodeDecodeError as e:
                    raise FormatError(f"tag {tag.code}: {e}") from None
        elif tag.datatype in rational_datatypes:
            value = [
                fractions.Fraction(*v) for v in zip(value[::2], value[1::2])
            ]
            if len(value) == 1:
                value = value[0]
        return value, offset_range

    def close(self):
        """Close the underlying file."""
        self.file.close()


@dataclasses.dataclass(frozen=True)
class Tag:
    """Immutable record of one IFD entry plus its decoded value.

    `data` is the raw 8-byte data field of the entry, `value` the decoded
    value and `offset_range` the file range the value occupies when it is
    stored outside the entry.
    """

    code: int
    datatype: int
    count: int
    data: bytes
    value: Any = None
    offset_range: range = None

    # Shared helper that abbreviates long values so reprs stay readable.
    _vrepr = reprlib.Repr()
    _vrepr.maxstring = 60
    _vrepr.maxother = 60
    vrepr = _vrepr.repr

    def __repr__(self):
        # offset_range is deliberately left out of the representation.
        short_value = self.vrepr(self.value)
        fields = (
            f"code={self.code!r}, datatype={self.datatype!r}, "
            f"count={self.count!r}, data={self.data!r}, "
            f"value={short_value}"
        )
        return f"{self.__class__.__qualname__}({fields})"


@dataclasses.dataclass(frozen=True)
class TagSet:
    """Container for Tag objects as stored in a TIFF IFD.

    Tag objects are maintained in a list that's always sorted in ascending order
    by the tag code. Only one tag for a given code may be present, which is where
    the "set" name comes from.

    The constructor copies and sorts the tags it is given, so the ordering
    holds regardless of the input order and the caller's list is never
    modified.

    Raises:
        ValueError: If two tags share the same code.

    """

    # Forward reference: only type checkers need to resolve Tag here.
    tags: List["Tag"] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        ordered = sorted(self.tags, key=lambda t: t.code)
        # After sorting, duplicates are necessarily adjacent.
        if any(a.code == b.code for a, b in zip(ordered, ordered[1:])):
            raise ValueError("Duplicate tag codes are not allowed.")
        # The dataclass is frozen, so rebinding the field must bypass the
        # generated __setattr__.
        object.__setattr__(self, "tags", ordered)

    def __repr__(self):
        ret = type(self).__name__ + "(["
        if self.tags:
            ret += "\n"
        ret += "".join([f"    {t},\n" for t in self.tags])
        ret += "])"
        return ret

    @property
    def codes(self):
        """List of tag codes, in storage order."""
        return [t.code for t in self.tags]

    def __getitem__(self, code):
        for t in self.tags:
            if t.code == code:
                return t
        raise KeyError(code)

    def __delitem__(self, code):
        for i, t in enumerate(self.tags):
            if t.code == code:
                del self.tags[i]
                return
        raise KeyError(code)

    def __contains__(self, code):
        return any(t.code == code for t in self.tags)

    def __len__(self):
        return len(self.tags)

    def __iter__(self):
        return iter(self.tags)

    def get(self, code, default=None):
        """Return the tag with the given code, or `default` if absent."""
        try:
            return self[code]
        except KeyError:
            return default

    def get_value(self, code, default=None):
        """Return the value of the tag with the given code, or `default`."""
        tag = self.get(code)
        if tag is None:
            return default
        return tag.value

    def insert(self, tag):
        """Add a new tag or replace an existing one."""
        for i, t in enumerate(self.tags):
            if t.code == tag.code:
                self.tags[i] = tag
                return
            if t.code > tag.code:
                # First larger code marks the sorted insertion point.
                self.tags.insert(i, tag)
                return
        self.tags.append(tag)


@dataclasses.dataclass(frozen=True)
class Ifd:
    """One image file directory: its tags and where it sits in the file.

    `offset` is the file position of the IFD itself and `offset_next` the
    position of the following IFD in the chain (0 for the last one).
    """

    tags: TagSet
    offset: int
    offset_next: int

    @property
    def nbytes(self):
        """Serialized size: 20 bytes per tag plus 16 bytes of framing."""
        # Framing is the 8-byte tag count plus the 8-byte next-IFD offset.
        entry_bytes = 20 * len(self.tags)
        return entry_bytes + 16

    @property
    def offset_range(self):
        """File range occupied by the serialized IFD."""
        start = self.offset
        return range(start, start + self.nbytes)


class FormatError(Exception):
    """Raised when a file's contents are not valid for the expected format."""


def fix_attrib_namespace(elt):
    """Prefix un-namespaced XML attributes with the tag's namespace.

    Operates in place on `elt` and all of its descendants. Elements whose tag
    carries no namespace are left untouched, as are nodes whose tag is not a
    string (comments and processing instructions).
    """
    # This fixes ElementTree's inability to round-trip XML with a default
    # namespace ("cannot use non-qualified names with default_namespace option"
    # error). 7-year-old BPO issue here: https://bugs.python.org/issue17088
    # Code inspired by https://gist.github.com/provegard/1381912 .
    for node in elt.iter():
        tag = node.tag
        # Comment/PI nodes use a callable as their tag; skip them instead of
        # failing on the string operations below.
        if not isinstance(tag, str) or not tag.startswith("{"):
            continue
        uri, _ = tag[1:].rsplit("}", 1)
        # For un-namespaced attributes, copy namespace from element.
        node.attrib = {
            (name if name.startswith("{") else f"{{{uri}}}{name}"): value
            for name, value in node.attrib.items()
        }


def parse_args():
    """Parse the command line and return the resulting namespace.

    The namespace carries `image` (path of the file to convert) and
    `channel_names` (list of names given with -n, empty when omitted).
    """
    description = (
        "Convert an OME-TIFF legacy pyramid to the BioFormats 6"
        " OME-TIFF pyramid format in-place."
    )
    names_help = (
        "Channel names to be inserted into OME metadata. Number of names"
        " must match number of channels in image. Be sure to put quotes"
        " around names containing spaces or other special shell characters."
    )
    cli = argparse.ArgumentParser(description=description)
    cli.add_argument("image", help="OME-TIFF file to convert")
    cli.add_argument(
        "-n",
        dest="channel_names",
        nargs="+",
        default=[],
        metavar="NAME",
        help=names_help,
    )
    return cli.parse_args()


def _stale_filler(length):
    """Return exactly `length` bytes of recognizable "unused" filler."""
    if length < 2:
        # Too short for the surrounding brackets.
        return b" " * length
    unit = b"unused "
    reps, pad = divmod(length - 2, len(unit))
    return b"[" + unit * reps + b" " * pad + b"]"


def _upgrade_pyramid(tiff, channel_names):
    """Validate the opened legacy pyramid and rewrite its headers in place.

    Prints a message and exits with status 1 if the file can't be converted.
    """
    # Tag decoding (including UTF-8 decoding of ASCII tags such as
    # ImageDescription) happens while the IFDs are read, so this is where a
    # FormatError can surface.
    try:
        tiff.read_ifds()
    except FormatError as e:
        print(f"TIFF format error: {e}")
        sys.exit(1)

    # ElementTree doesn't parse xml declarations so we'll just run some sanity
    # checks that we do have UTF-8 and give it a decoded string instead of raw
    # bytes. The raw tag bytes were already decoded as UTF-8 above; here we
    # ensure that the declaration encoding is UTF-8 if present. Encoding names
    # are case-insensitive in XML, hence IGNORECASE.
    omexml = tiff.ifds[0].tags.get_value(270, "")
    if re.match(
        r'<\?xml [^>]*encoding="(?!UTF-8)[^"]*"', omexml, re.IGNORECASE
    ):
        print("OME-XML is encoded with something other than UTF-8.")
        sys.exit(1)

    xml_ns = {"ome": "http://www.openmicroscopy.org/Schemas/OME/2016-06"}

    if xml_ns["ome"] not in omexml:
        print("Not an OME-TIFF.")
        sys.exit(1)
    # A legacy pyramid names "Faas" in the Software tag (305) and has no
    # SubIFDs tag (330) yet.
    if (
        "Faas" not in tiff.ifds[0].tags.get_value(305, "")
        or 330 in tiff.ifds[0].tags
    ):
        print("Not a legacy OME-TIFF pyramid.")
        sys.exit(1)

    # All XML manipulation assumes the document is valid OME-XML!
    root = xml.etree.ElementTree.fromstring(omexml)
    image = root.find("ome:Image", xml_ns)
    pixels = image.find("ome:Pixels", xml_ns)
    size_x = int(pixels.get("SizeX"))
    size_y = int(pixels.get("SizeY"))
    size_c = int(pixels.get("SizeC"))
    size_z = int(pixels.get("SizeZ"))
    size_t = int(pixels.get("SizeT"))
    num_levels = len(root.findall("ome:Image", xml_ns))
    # Tags 256/257 are ImageWidth/ImageLength.
    page_dims = [(ifd.tags[256].value, ifd.tags[257].value) for ifd in tiff.ifds]

    if len(root) != num_levels:
        print("Top-level OME-XML elements other than Image are not supported.")
        # Must stop here: the metadata rewrite below drops every root child
        # but the first, which would silently delete such elements.
        sys.exit(1)
    if size_z != 1 or size_t != 1:
        print("Z-stacks and multiple timepoints are not supported.")
        sys.exit(1)
    if size_c * num_levels != len(tiff.ifds):
        print("TIFF page count does not match OME-XML Image elements.")
        sys.exit(1)
    if any(dims != (size_x, size_y) for dims in page_dims[:size_c]):
        print(f"TIFF does not begin with SizeC={size_c} full-size pages.")
        sys.exit(1)
    for level in range(1, num_levels):
        level_dims = page_dims[level * size_c: (level + 1) * size_c]
        if len(set(level_dims)) != 1:
            print(
                f"Pyramid level {level + 1} out of {num_levels} has inconsistent"
                f" sizes:\n{level_dims}"
            )
            sys.exit(1)
    if channel_names and len(channel_names) != size_c:
        print(
            f"Wrong number of channel names -- image has {size_c} channels but"
            f" {len(channel_names)} names were specified:"
        )
        for i, n in enumerate(channel_names, 1):
            print(f"{i:4}: {n}")
        sys.exit(1)

    print("Input image summary")
    print("===================")
    print(f"Dimensions: {size_x} x {size_y}")
    print(f"Number of channels: {size_c}")
    print(f"Pyramid sub-resolutions ({num_levels - 1} total):")
    for dim_x, dim_y in page_dims[size_c::size_c]:
        print(f"    {dim_x} x {dim_y}")
    software = tiff.ifds[0].tags.get_value(305, "<not set>")
    print(f"Software: {software}")
    print()

    print("Updating OME-XML metadata...")
    # We already verified there is nothing but Image elements under the root.
    for other_image in root[1:]:
        root.remove(other_image)
    for tiffdata in pixels.findall("ome:TiffData", xml_ns):
        pixels.remove(tiffdata)
    new_tiffdata = xml.etree.ElementTree.Element(
        f"{{{xml_ns['ome']}}}TiffData",
        attrib={"IFD": "0", "PlaneCount": str(size_c)},
    )
    # A valid OME-XML Pixels begins with size_c Channels; then comes TiffData.
    pixels.insert(size_c, new_tiffdata)

    if channel_names:
        print("Renaming channels...")
        channels = pixels.findall("ome:Channel", xml_ns)
        for channel, name in zip(channels, channel_names):
            channel.attrib["Name"] = name

    fix_attrib_namespace(root)
    # ElementTree.tostring would have been simpler but it only supports
    # xml_declaration and default_namespace starting with Python 3.8.
    xml_file = io.BytesIO()
    tree = xml.etree.ElementTree.ElementTree(root)
    tree.write(
        xml_file,
        encoding="utf-8",
        xml_declaration=True,
        default_namespace=xml_ns["ome"],
    )
    new_omexml = xml_file.getvalue()

    print("Writing new TIFF headers...")
    stale_ranges = [ifd.offset_range for ifd in tiff.ifds]
    main_ifds = tiff.ifds[:size_c]
    # For each channel, collect its reduced-resolution pages in level order.
    channel_sub_ifds = [tiff.ifds[c + size_c::size_c] for c in range(size_c)]
    for i, (main_ifd, sub_ifds) in enumerate(zip(main_ifds, channel_sub_ifds)):
        for ifd in sub_ifds:
            if 305 in ifd.tags:
                stale_ranges.append(ifd.tags[305].offset_range)
                del ifd.tags[305]
            # NewSubfileType (254) = 1 marks a reduced-resolution image.
            ifd.tags.insert(tiff.append_tag_data(254, 3, 1))
        if i == 0:
            stale_ranges.append(main_ifd.tags[305].offset_range)
            stale_ranges.append(main_ifd.tags[270].offset_range)
            # Break up "Faas" so the result no longer passes the legacy check.
            old_software = main_ifd.tags[305].value.replace("Faas", "F*a*a*s")
            new_software = f"pyramid_upgrade.py (was {old_software})"
            main_ifd.tags.insert(tiff.append_tag_data(305, 2, new_software))
            main_ifd.tags.insert(tiff.append_tag_data(270, 2, new_omexml))
        else:
            if 305 in main_ifd.tags:
                stale_ranges.append(main_ifd.tags[305].offset_range)
                del main_ifd.tags[305]
        sub_ifds[:] = tiff.append_ifd_sequence(sub_ifds)
        offsets = [ifd.offset for ifd in sub_ifds]
        # SubIFDs (330) points the full-size page at its pyramid levels.
        main_ifd.tags.insert(tiff.append_tag_data(330, 16, offsets))
    main_ifds = tiff.append_ifd_sequence(main_ifds)
    tiff.write_first_ifd_offset(main_ifds[0].offset)

    print("Clearing old headers and tag values...")
    # We overwrite all the old IFDs and referenced data values with obvious
    # "filler" as a courtesy to anyone who might need to poke around in the TIFF
    # structure down the road. A real TIFF parser wouldn't see the stale data,
    # but a human might just scan for the first thing that looks like a run of
    # OME-XML and not realize it's been replaced with something else. The filler
    # content is the repeated string "unused " with square brackets at the
    # beginning and end of each filled IFD or data value.
    for r in stale_ranges:
        # Values stored inline in an IFD entry have an empty range starting at
        # 0; writing filler for those would clobber the file header.
        if not r:
            continue
        tiff.file.seek(r.start)
        tiff.file.write(_stale_filler(len(r)))


def main():
    """Convert the legacy pyramid named on the command line in place."""
    args = parse_args()

    # Use the parsed positional argument; sys.argv[1] is not the image path
    # when options are given first.
    try:
        tiff = TiffSurgeon(args.image, encoding="utf-8", writeable=True)
    except FormatError as e:
        print(f"TIFF format error: {e}")
        sys.exit(1)

    try:
        _upgrade_pyramid(tiff, args.channel_names)
    finally:
        # Also runs on sys.exit() and unexpected errors.
        tiff.close()

    print()
    print("Success!")

# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()