from functools import wraps
import multiprocessing
import numpy
import logging
import geodat.units
# NOTE(review): configuring the root logger at import time also affects
# logging for every other module of the application -- confirm intended.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# pyferret is optional: record whether it could be imported.  The functions
# of this module that need it raise _IMPORT_PYFERRET_ERROR when it is missing.
try:
    import pyferret
    PYFERRET_INSTALLED = True
    _IMPORT_PYFERRET_ERROR = None
except ImportError as _pyferret_import_err:
    # Keep the underlying reason in the message so that the error raised
    # later tells the user why the import failed
    logger.warning("Failed to load pyferret: %s", _pyferret_import_err)
    PYFERRET_INSTALLED = False
    _IMPORT_PYFERRET_ERROR = ImportError(
        "Failed to load pyferret: {0}".format(_pyferret_import_err))
def num2fer(data, coords, dimunits,
            varname="UNKNOWN", data_units=None, missing_value=None,
            cartesian_axes=None, dimnames=None):
    ''' Create a dictionary that resembles the Ferret data variable
    structure, to be passed to pyferret.putdata

    Args:
        data (numpy.ndarray)
        coords (a list of numpy.ndarray): one coordinate array per dimension
        dimunits (a list of str): dimension units (e.g. ['months','degrees_N'])
        varname (str, optional): variable name, also used as the title
        data_units (str, optional)
        missing_value (numeric, optional)
        cartesian_axes (a list of characters, optional): specifies the
            cartesian axes, e.g. ['T','Y','X']. If this is not specified,
            guesses are made from the dimension units with
            geodat.units.assign_caxis (say unit month will be interpreted as
            a [T]IME axis). Specifying cartesian_axes overrides the guesses.
        dimnames (a list of str, optional): dimension names
            (e.g. ['time','lat','lon'])

    Returns:
        dict: with keys 'data' (a copy of data), 'name', 'dset' (None),
        'title', 'axis_types', 'axis_units', 'axis_coords', 'axis_pos', plus
        'missing_value', 'data_unit' and 'axis_names' when the corresponding
        arguments are given.

    Raises:
        ImportError: if pyferret could not be imported
        ValueError: if the length of dimunits, coords, cartesian_axes or
            dimnames does not agree with data.ndim
    '''
    if not PYFERRET_INSTALLED:
        raise _IMPORT_PYFERRET_ERROR
    if len(dimunits) != data.ndim:
        raise ValueError("Number of dimunits does not match data.ndim")
    if len(coords) != data.ndim:
        raise ValueError("Number of coords does not match data.ndim")
    # Make guesses for the axis types from the units unless given explicitly
    if cartesian_axes is None:
        cartesian_axes = [geodat.units.assign_caxis(dimunit)
                          for dimunit in dimunits]
    if len(cartesian_axes) != data.ndim:
        raise ValueError("Number of cartesian_axes/dimunits does"
                         " not match data.ndim")
    if dimnames is not None and len(dimnames) != data.ndim:
        raise ValueError("Number of dimnames does not match data.ndim")

    # Cartesian axis -> Ferret axis type.  'T' maps to CUSTOM rather than
    # TIME: with a TIME axis type Ferret reports the calendar type as the
    # axis unit, which this module does not handle.  Any other axis is NORMAL.
    cax2ax_type = {'X': pyferret.AXISTYPE_LONGITUDE,
                   'Y': pyferret.AXISTYPE_LATITUDE,
                   'Z': pyferret.AXISTYPE_LEVEL,
                   'T': pyferret.AXISTYPE_CUSTOM}
    # Cartesian axis -> Ferret axis position; the resulting list is meant to
    # be passed as the axis_pos argument of pyferret.putdata
    axis_pos_dict = {'X': pyferret.X_AXIS,
                     'Y': pyferret.Y_AXIS,
                     'Z': pyferret.Z_AXIS,
                     'T': pyferret.T_AXIS}

    fer_var = {
        'data': data.copy(),
        'name': varname,
        'dset': None,
        'title': varname,
        'axis_types': [cax2ax_type.get(cax, pyferret.AXISTYPE_NORMAL)
                       for cax in cartesian_axes],
        'axis_units': dimunits,
        'axis_coords': coords,
        # NOTE(review): an axis other than X/Y/Z/T gets the index of its
        # first occurrence in cartesian_axes as its position; that index
        # may collide with one of the pyferret X/Y/Z/T position constants
        # -- confirm this is intended.
        'axis_pos': [axis_pos_dict[cax] if cax in axis_pos_dict
                     else cartesian_axes.index(cax)
                     for cax in cartesian_axes],
    }
    if missing_value is not None:
        fer_var['missing_value'] = missing_value
    if data_units is not None:
        fer_var['data_unit'] = data_units
    if dimnames is not None:
        fer_var['axis_names'] = dimnames
    return fer_var
def fer2num(var):
    ''' Filter the dictionary returned by pyferret.getdata

    PyFerret usually returns data with extra singleton dimensions (axes whose
    entry in var['axis_coords'] is None); those are removed here.

    Args:
        var (dict): as is returned by pyferret.getdata

    Returns:
        dict: {'data': a numpy ndarray, 'varname': the name of the variable
               (taken from var['title']),
               'coords': a list of numpy ndarrays for the dimensions,
               'dimunits': a list of strings, the units for the dimensions,
               'dimnames': a list of strings, the names for the dimensions
               (only present when var['axis_names'] is not None)}

    Raises:
        ImportError: if pyferret could not be imported
        NotImplementedError: if any axis has the Ferret TIME axis type
    '''
    if not PYFERRET_INSTALLED:
        raise _IMPORT_PYFERRET_ERROR
    # If the axis_type is TIME, the axis_unit is the calendar type which
    # is not considered yet
    if pyferret.AXISTYPE_TIME in var['axis_types']:
        raise NotImplementedError(
            "Immature function: axis_type from Ferret is TIME, "
            "not CUSTOM; a situation not taken into account yet.")

    axis_coords = var['axis_coords']
    # True for every axis that carries coordinates (i.e. is kept)
    keep = [ax is not None for ax in axis_coords]

    results = {}
    results['coords'] = [ax for ax in axis_coords if ax is not None]
    if var['axis_names'] is not None:
        results['dimnames'] = [name for name, kept
                               in zip(var['axis_names'], keep) if kept]
    results['dimunits'] = [unit for unit, kept
                           in zip(var['axis_units'], keep) if kept]
    # Index the dropped axes at 0 and keep the others whole.  NumPy requires
    # a tuple (not a list) for multidimensional indexing.
    sliceobj = tuple(slice(None) if kept else 0 for kept in keep)
    results['data'] = var['data'][sliceobj]
    results['varname'] = var['title']
    return results
def run_worker(f):
    ''' A workaround for clearing memory used by PyFerret

    Wrap ``f`` so that every call runs in a fresh single-process
    multiprocessing.Pool that is shut down right afterwards, so whatever
    the call allocated goes away with the worker process.

    Args:
        f (callable): executed through multiprocessing, so it must be
            picklable (e.g. a module-level function), and so must its
            arguments and return value

    Returns:
        callable: accepts the same arguments as ``f`` and returns its result;
        an exception raised by ``f`` propagates to the caller.
    '''
    @wraps(f)
    def run_func(*args, **kwargs):
        pool = multiprocessing.Pool(1)
        try:
            return pool.apply(f, args, kwargs)
        finally:
            # Always reap the worker process, also when f raised
            pool.close()
            pool.terminate()
            pool.join()
    return run_func
def regrid_once_primitive(var, ref_var, axis,
                          verbose=False, prerun=None, transform='@ave'):
    ''' A generic function that regrids a variable without the dependence of
    geodat.nc.Variable

    Args:
        var (dict) : arguments for num2fer
                     Required keys: data,coords,dimunits
                     An optional 'varname' key is carried over to the result
        ref_var (dict) : arguments for num2fer.
                         This supplies the grid for regridding
                         Required keys: coords,dimunits
                         (not modified by this function)
        axis (str) : the axis for regridding e.g. 'X'/'Y'/'XY'/"YX"
        verbose (bool) : whether to print progress (default: False)
        prerun (str or a list of str) : commands to be run at the start
                                        (default: None)
        transform (str): "@ave" (Conserve area average),
                     "@lin" (Linear interpolation),...see Ferret doc

    Returns:
        dict: ``var`` itself, updated in place with the regridded 'data',
        'coords', 'dimunits' (and 'dimnames'/'varname' when available), with
        the dimensions in the same order as in the input.

    Raises:
        ImportError: if pyferret could not be imported
        TypeError: if axis is not a string, or prerun is neither a string
            nor a list of strings
    '''
    if not PYFERRET_INSTALLED:
        raise _IMPORT_PYFERRET_ERROR
    # Validate the inputs before starting Ferret, so that a bad call does
    # not leave a Ferret session behind
    if not isinstance(axis, str):
        raise TypeError("axis has to be a string, e.g. 'X', 'Y' or 'XY'")
    axis = axis.upper()
    if prerun is None:
        commands = []
    elif isinstance(prerun, str):
        commands = [prerun]
    elif (isinstance(prerun, (list, tuple)) and
          all(isinstance(command, str) for command in prerun)):
        commands = list(prerun)
    else:
        raise TypeError("prerun has to be either a string or a list of "
                        "string")

    # num2fer receives the variable name as an explicit keyword, so a
    # 'varname' key (e.g. in the output of a previous regridding) must not
    # be forwarded too.  Working on copies also leaves ref_var untouched.
    source_args = {key: val for key, val in var.items() if key != 'varname'}
    dest_args = {key: val for key, val in ref_var.items()
                 if key != 'varname'}
    # Only the grid of the destination is used; fill in dummy data, which
    # num2fer requires, when none is given
    if 'data' not in dest_args:
        dest_args['data'] = numpy.zeros((1,) * len(dest_args['coords']))
    # Construct the source and destination data read by pyferret.putdata
    source_fer = num2fer(varname="source", **source_args)
    dest_fer = num2fer(varname="dest", **dest_args)

    pyferret.start(quiet=True, journal=verbose,
                   verify=False, server=True)
    try:
        # commands to run before regridding
        for command in commands:
            pyferret.run(command)
        if verbose:
            print(source_fer)
            print(dest_fer)
        pyferret.putdata(source_fer, axis_pos=source_fer['axis_pos'])
        if verbose:
            print("Put source variable")
            pyferret.run('show grid source')
        pyferret.putdata(dest_fer, axis_pos=dest_fer['axis_pos'])
        if verbose:
            print("Put destination variable")
            pyferret.run('show grid dest')
        pyferret.run('let result = source[g' + axis.lower() + '=dest' +
                     transform + ']')
        if verbose:
            print("Regridded in FERRET")
            pyferret.run('show grid result')
        result_ref = pyferret.getdata('result')
        if verbose:
            print("Get data from FERRET")
    finally:
        # End the Ferret session even if one of the steps above failed
        status = pyferret.stop()
        if verbose:
            if status:
                print("PyFerret stopped.")
            else:
                print("PyFerret failed to stop.")

    # Convert from the Ferret data structure back to num2fer-style arguments
    tmp_result = fer2num(result_ref)
    if 'varname' in var:
        tmp_result['varname'] = var['varname']
    # Preserve dimension order (Ferret reverts the order): match the axes of
    # the result to those of the input via the cartesian axis guessed from
    # each dimension unit
    tmp_caxes = [geodat.units.assign_caxis(dimunit)
                 for dimunit in tmp_result['dimunits']]
    var_caxes = [geodat.units.assign_caxis(dimunit)
                 for dimunit in var['dimunits']]
    neworder = [tmp_caxes.index(cax) for cax in var_caxes]
    tmp_result['coords'] = [tmp_result['coords'][iax] for iax in neworder]
    tmp_result['dimunits'] = [tmp_result['dimunits'][iax] for iax in neworder]
    if 'dimnames' in tmp_result:
        tmp_result['dimnames'] = [tmp_result['dimnames'][iax]
                                  for iax in neworder]
    tmp_result['data'] = tmp_result['data'].transpose(neworder).astype(
        var['data'].dtype)
    # Return the input var (updated in place) with the data and dimensions
    # replaced by the regridded ones
    var.update(tmp_result)
    return var
# Same call signature and return value as regrid_once_primitive, but each
# call runs in a single worker process (multiprocessing.Pool(1)) that is
# terminated afterwards -- the run_worker workaround for clearing memory
# used by PyFerret.  Arguments and result are pickled across processes, so
# the caller's `var` is not updated in place by this wrapper.
regrid_primitive = run_worker(regrid_once_primitive)
if __name__ == '__main__':
    # Demo: regrid the high-resolution land mask onto the grid of the
    # low-resolution one.  Both files are read from the working directory.
    from scipy.io import netcdf_file

    def _load_variable(filename, name):
        '''Read variable `name` from the netCDF file `filename` as a dict of
        num2fer arguments (data, coords, dimunits).'''
        # mmap=False reads the arrays into memory so that they remain valid
        # after the file is closed
        ncfile = netcdf_file(filename, mmap=False)
        try:
            ncvar = ncfile.variables[name]
            return dict(data=ncvar.data,
                        coords=[ncfile.variables[dim].data
                                for dim in ncvar.dimensions],
                        dimunits=[ncfile.variables[dim].units
                                  for dim in ncvar.dimensions])
        finally:
            ncfile.close()

    newvar = _load_variable("land_mask_lowres.nc", 'land_mask')
    var_high = _load_variable("land_mask_highres.nc", 'land_mask')
    regridded = regrid_primitive(var_high, newvar, 'XY')
    print(regridded['data'].shape)