from vsi.tools import Try
import numpy as np
import os
[docs]class Register(object):
def __init__(self):
self.readers = []
[docs] def register(self, reader):
if reader not in self.readers:
self.readers.append(reader)
registered_readers=Register()
registered_writers=Register()
# @@@ Generic classes @@@
[docs]class Reader(object):
def __init__(self, filename, autoload=False, *args, **kwargs):
self.filename = filename
if autoload:
self.load(*args, **kwargs)
Reader.extensions = None #Default all
[docs]class Writer(object):
def __init__(self, array, dtype=None):
if dtype:
self.dtype = dtype
self.array = np.asarray(array, self.dtype)
else:
self.array = array
self.dtype = array.dtype
# @@@ tifffile classes @@@
with Try(ImportError):
import tifffile
class TifffileReader(Reader):
#Assume series 0, this if for multiple files being treated as a series
#Not currently considered
def load(self, **kwargs):
self.object = tifffile.TiffFile(self.filename, **kwargs)
def raster(self, segment=0,**kwargs):
return self.object.asarray(key=segment, **kwargs)
def shape(self, segment=0):
return tuple(self.object.pages[segment].shape)
def dtype(self, segment=0):
return self.object.series[0]['dtype']
def bpp(self, segment=0):
return self.dtype(segment).itemsize*8
def band_names(self):
raise Exception('Unimplemented. Used PilReader')
def endian(self):
return self.object.byteorder
def bands(self, segment=0):
if len(self.object.pages[segment].shape)>2:
return self.object.pages[segment].shape[2]
else:
return 1
@property
def segments(self):
return len(self.object.pages)
TifffileReader.extensions=['tif', 'tiff']
registered_readers.register(TifffileReader)
#Monkey patching to add JPEG compress TIFF support via PIL
with Try(ImportError):
from PIL import Image
def decode_jpeg(encoded, tables=b'', photometric=None,
ycbcr_subsampling=None, ycbcr_positioning=None):
''' ycbcr resampling is missing in both tifffile and PIL '''
from StringIO import StringIO
from PIL import JpegImagePlugin
return JpegImagePlugin.JpegImageFile(StringIO(tables + encoded)).tobytes()
tifffile.TIFF_DECOMPESSORS['jpeg'] = decode_jpeg
tifffile.decodejpg = decode_jpeg
class TifffileWriter(Writer):
def save(self, filename, **kwargs):
tifargs = {}
for key in ('byteorder', 'bigtiff', 'software', 'writeshape'):
if key in kwargs:
tifargs[key] = kwargs[key]
del kwargs[key]
if 'writeshape' not in kwargs:
kwargs['writeshape'] = True
if 'bigtiff' not in tifargs and self.array.size*self.array.dtype.itemsize > 2000*2**20:
tifargs['bigtiff'] = True
self.object = tifffile.TiffWriter(filename, **tifargs)
self.object.save(self.array, **kwargs)
TifffileWriter.extensions=['tif', 'tiff']
registered_writers.register(TifffileWriter)
# @@@ PIL classes @@@
with Try(ImportError):
from PIL import Image
class PilReader(Reader):
def load(self, mode='r'):
self.object = Image.open(self.filename, mode)
def _get_mode_info(self, segment=0):
''' TODO: Cache this if it's being called a lot and slowing down'''
self.object.seek(segment)
mode = Image._MODE_CONV[self.object.mode]
mode = {'endian':mode[0][0],
'type':mode[0][1],
'bpp':int(mode[0][2])*8,
'bands':mode[1]}
if mode['type'] == 'b':
mode['bpp'] = 1
if mode['bands'] is None:
mode['bands'] = 1
if mode['type'] == 'b':
mode['type'] = np.bool
elif mode['type'] == 'u':
mode['type'] = getattr(np, 'uint%d' % mode['bpp'])
elif mode['type'] == 'i':
mode['type'] = getattr(np, 'int%d' % mode['bpp'])
elif mode['type'] == 'f':
mode['type'] = getattr(np, 'float%d' % mode['bpp'])
else:
raise Exception('Unknown mode type')
return mode
def endian(self, segment=0):
return self._get_mode_info(segment)['endian']
def raster(self, segment=0):
self.object.seek(segment)
return np.array(self.object)
def bpp(self, segment=0):
self.object.seek(segment)
return self._get_mode_info()['bpp']
def dtype(self, segment=0):
self.object.seek(segment)
return self._get_mode_info()['type']
def bands(self, segment=0):
self.object.seek(segment)
return self._get_mode_info()['bands']
def band_names(self, segment=0):
self.object.seek(segment)
try:
band_names = Image._MODEINFO[self.object.mode][2]
if len(band_names) == 1:
return ('P',)
return band_names
except KeyError:
return ('P',) #Panchromatic
def shape(self, segment=0):
#shape is height, width, bands
self.object.seek(segment)
shape = self.object.size
return (shape[1], shape[0])+shape[2:]
registered_readers.register(PilReader)
class PilWriter(Writer):
def __init__(self, array, dtype=None, *args, **kwargs):
super(PilWriter, self).__init__(array, dtype)
self.object = Image.fromarray(self.array, *args, **kwargs)
def save(self, filename, *args, **kwargs):
self.object.save(filename, *args, **kwargs)
registered_writers.register(PilWriter)
# @@@ GDAL classes @@@
with Try(ImportError):
from osgeo import gdal
[docs] class GdalReader(Reader):
def __init__(self, *args, **kwargs):
#default
self._segment = 0
super(GdalReader, self).__init__(*args, **kwargs)
def _change_segment(self, segment=0, mode=gdal.GA_ReadOnly):
if segment != self._segment:
self._dataset = gdal.Open(self.object.GetSubDatasets()[segment][0], mode)
self._segment = segment
[docs] def load(self, mode=gdal.GA_ReadOnly, *args, **kwargs):
self.object = gdal.Open(self.filename, mode, *args, **kwargs)
if self.object is None:
raise Exception('Gdal can not determine driver')
self._dataset = self.object
[docs] def raster(self, segment=0, *args, **kwargs):
#return self.object.GetRasterBand(band).ReadAsArray()
self._change_segment(segment)
raster = self._dataset.ReadAsArray()
if len(raster.shape)==3:
return raster.transpose((1,2,0))
else:
return raster
[docs] def raster_roi(self, segment=0, *args, **kwargs):
'''This isn't written yet'''
self._change_segment(segment)
band = self.object.GetRasterBand(band)
scanline = band.ReadRaster( 0, 0, band.XSize, 1, \
band.XSize, 1, GDT_Float32 )
import struct
tuple_of_floats = struct.unpack('f' * b2.XSize, scanline)
#(use a numpy array instead of unpack)
[docs] def bands(self, segment=0):
self._change_segment(segment)
return self._dataset.RasterCount
[docs] def shape(self, segment=0):
self._change_segment(segment)
if self.object.RasterCount > 1:
return (self.object.RasterYSize, self.object.RasterXSize, self.object.RasterCount)
else:
return (self.object.RasterYSize, self.object.RasterXSize)
#def saveas(self, filename, strict=False): THIS IS CRAP
# ''' Copy the current object and save it to disk as a different file '''
#
# destination = self.object.GetDriver().CreateCopy(filename, self.object, strict)
#There is a LOT unimplemented here. I do NOT know GDAL enough to fill in the gaps
registered_readers.register(GdalReader)
from osgeo.gdal_array import codes as gdal_codes
[docs] class GdalWriter(Writer):
gdal_array_types = {np.dtype(v):k for k,v in gdal_codes.iteritems()}
[docs] def save(self, filename, driver=None, *args, **kwargs):
if driver is None:
ext = os.path.splitext(filename)[1][1:]
if ext.lower() in ['tif', 'tiff']:
driver = gdal.GetDriverByName('GTiff')
else:
raise Exception('Unkown extension. Can not determine driver')
bands = self.array.shape[2] if len(self.array.shape)>2 else 1
self.object = driver.Create(filename, self.array.shape[1],
self.array.shape[0], bands, GdalWriter.gdal_array_types[np.dtype(self.dtype)])
if bands==1:
self.object.GetRasterBand(1).WriteArray(self.array)
else:
for band in range(bands):
self.object.GetRasterBand(band+1).WriteArray(self.array[:,:,band])
#del self.object
#Need to be deleted to actually save
# dst_ds.SetGeoTransform( [ 444720, 30, 0, 3751320, 0, -30 ] )
# srs = osr.SpatialReference()
# srs.SetUTM( 11, 1 )
# srs.SetWellKnownGeogCS( 'NAD27' )
# dst_ds.SetProjection( srs.ExportToWkt() )
# @@@ Common feel functions @@@
[docs]def imread(filename, *args, **kwargs):
extension = os.path.splitext(filename)[1][1:]
for reader in registered_readers.readers:
if not reader.extensions or extension in reader.extensions:
try:
return reader(filename, autoload=True)
except:
pass
return None
[docs]def imwrite(img, filename, *args, **kwargs):
""" write the numpy array as an image """
_, ext = os.path.splitext(filename)
is_multiplane = len(img.shape) > 2
if has_tifffile and (ext == '.tiff' or ext == '.tif') and is_multiplane:
# if image is tiff, use tifffile module
tifffile.imsave(filename, img)
else:
pilImg = Image.fromarray(img)
if pilImg.mode == 'L':
pilImg.convert('I') # convert to 32 bit signed mode
pilImg.save(filename)
return
[docs]def imwrite_geotiff(img, filename, transform, wkt_projection=None):
if wkt_projection == None:
import osr
projection = osr.SpatialReference()
projection.SetWellKnownGeogCS('WGS84')
wkt_projection = projection.ExportToWkt()
gdal_writer = GdalWriter(img)
gdal_writer.save(filename)
gdal_writer.object.SetGeoTransform(transform)
gdal_writer.object.SetProjection(wkt_projection)
[docs]def imwrite_byte(img, vmin, vmax, filename):
""" write the 2-d numpy array as an image, scale to byte range first """
img_byte = np.uint8(np.zeros_like(img))
img_norm = (img - vmin)/(vmax-vmin)
img_norm = img_norm.clip(0.0, 1.0)
img_byte[:] = img_norm * 255
imwrite(img_byte, filename)