"""This module contains the main classes for handling geophysical (climate)
variables and dimensions.  It also reads and writes NetCDF files.

The :class:`Variable` and :class:`Dimension` classes in this
module act as containers of :py:mod:`~numpy` arrays which can be easily
manipulated.
"""
from __future__ import print_function

import os
import sys
import copy
import warnings
import logging
from functools import wraps, partial
import datetime
import inspect

import numpy
import scipy.io.netcdf as netcdf
from scipy.ndimage import gaussian_filter
import pylab

from dateutil.relativedelta import relativedelta

from . import keepdims
from . import arrays
from . import stat
from . import math
from . import monthly
from .plot import mapplot
from . import grid_func
from . import pyferret_func
from . import units

# Module-level logger; used e.g. to warn when the netCDF4 import fails below
logger = logging.getLogger(__name__)

# If netCDF4 is not installed, some functions are not available
# When these functions are called, an Exception will be raised
def _throw_error(error):
    """Return a function that raises `error` whenever it is called.

    Used as a stand-in for netCDF4 functions when that package cannot be
    imported, so the failure only surfaces when such a function is used.

    Args:
        error (Exception): the exception instance to raise

    Returns:
        callable accepting any arguments
    """
    def new_func(*args, **kwargs):
        raise error
    return new_func


try:
    import netCDF4 as _netCDF4
except ImportError:
    logger.warning("Failed to import netCDF4 package. "+\
                   "Some functions may not work")
    _NETCDF4_IMPORT_ERROR = ImportError("The netCDF4 package is "+\
                                        "required but fail to import. "+\
                                        "See https://pypi.python.org/pypi/netCDF4")
    _num2date = _date2num = _netCDF4_Dataset = \
        _throw_error(_NETCDF4_IMPORT_ERROR)
    # Empty tuple: isinstance(x, ()) is always False
    _netCDF4_datetime = ()
else:
    _num2date = _netCDF4.num2date
    _date2num = _netCDF4.date2num
    _netCDF4_Dataset = _netCDF4.Dataset
    # netCDF4 >= 1.4 no longer ships the `netcdftime` submodule (it moved to
    # the separate cftime package); fall back to an empty tuple so that the
    # module still imports and isinstance() checks simply return False.
    _netCDF4_datetime = getattr(getattr(_netCDF4, "netcdftime", None),
                                "datetime", ())

# Finished import setup

def getvar(filename, varname, *args, **kwargs):
    ''' Short hand for retrieving variable from a netcdf file

    Args:
        filename (str): Name of input file
        varname (str): Name of the variable

    Returns:
        Variable

    Optional arguments and keyword arguments are passed to Variable

    Example:
        var = getvar("sst.nc","sst")
    '''
    # The file handle is intentionally left open: the Variable keeps a
    # reference to it (Variable._ncfile) and its data is read from it.
    return Variable(netcdf.netcdf_file(filename), varname, *args, **kwargs)
def dataset(filenames, append_code="s", *args, **kwargs): ''' Extract all variables in one file or more files Args: filenames (str or a list of str): Input files append_code (str): what to do if when variable names clash; "o" for overwriting previously loaded variables; "r" for renaming newly loaded variable (will prompt for input) "s" to skip (default) Returns: dict: str and pairs Optional arguments accepted by can be used here ''' result = {} if type(filenames) is not list: filenames = [filenames] for filename in filenames: file_handle = netcdf.netcdf_file(filename) for varname in file_handle.variables.keys(): if varname in file_handle.dimensions: # Do not add dimensions to the dataset continue if varname in result: print(varname, "alread loaded. ", end="") if append_code.lower() == 'o': print("Overwriting.") result[varname] = Variable(file_handle, varname, *args, **kwargs) elif append_code.lower() == 'r': print("Enter new name: ", end="") newname = sys.stdin.readline()[:-1] result[newname] = Variable(file_handle, varname, *args, **kwargs) result[newname].varname = newname elif append_code.lower() == "s": print("I am skipping the variable:", varname, "in", filename) continue else: raise ValueError("Invalid choice for append_code") else: result[varname] = Variable(file_handle, varname, *args, **kwargs) return result def _genereal_axis(axis): ''' Standardize keyword for axes time/lat/lon 'T': 'time','t','TIME','T' 'X': 'x','X','lon','LON','longitude','LONGITUDE' 'Y': 'y','Y','lat','LAT','latitude','LATITUDE' Anything not recognized will be returned in upper case Args: axis (str) Returns: str Example: _genereal_axis('time') -> 'T' Example: _genereal_axis('dummy') -> 'DUMMY' ''' invaxnames = {'tim':'T', 'lon':'X', 'lat':'Y', 'lev':'Z', 'dep':'Z'} if len(axis) > 1: return invaxnames[axis[:3].lower()] else: return axis.upper() def _general_region(region): ''' Standardize keyword for regional slicing. 
Use by Variable.getRegion() 'T': 'time','t','TIME','T' 'X': 'x','X','lon','LON','longitude','LONGITUDE' 'Y': 'y','Y','lat','LAT','latitude','LATITUDE' Args: region (dict) Returns: dict Example: _general_region({'TIME':(10.,1000.),'LAT':(-5.,5.)}) --> {'T': (10.,1000.),'Y':(-5.,5.)} ''' results = {} for iax, value in region.items(): results[_genereal_axis(iax)] = value return results
class Dimension(object):
    """ A container for handling physical dimensions such as time, latitude.

    It can be indexed/sliced the same way as indexing a numpy array;
    slicing returns a new Dimension with the same name, units and attributes.
    """
    def __init__(self, data, dimname=None, units=None, attributes=None,
                 parent=None):
        """
        Attributes:
            data (numpy 1-d array): Array for the physical axis
            dimname (str): Name of the dimension, e.g. "time"
            units (str): Unit of the dimension, e.g. "days since 1990-01-01"
            attributes (dict): Attributes for the dimension

        Arguments:
            data (numpy 1-d array, scalar or netcdf.netcdf_variable): values
                of the physical axis.  A scalar is wrapped into a length-1
                array.  For a netcdf variable, its data, units and
                attributes are used.
            dimname (str): Name of the dimension, e.g. "time"
            units (str): Unit of the dimension, e.g. "days since 1990-01-01";
                if given, it takes precedence over the units of `parent`
                and of a netcdf variable
            attributes (dict): Attributes for the dimension; they take
                precedence over those from `parent` or the netcdf variable
            parent (Dimension): from which dimname,units,attributes are
                copied if they are not supplied already in the arguments
        """
        self.data = data
        self.units = units
        self.attributes = {}
        self.dimname = 'UNNAMED_DIM'

        if parent is not None:
            self.units = parent.units
            self.attributes.update(parent.attributes)
            self.dimname = parent.dimname

        if units is not None:
            self.units = units

        if attributes is not None:
            self.attributes.update(attributes)

        if dimname is not None:
            self.dimname = dimname

        if isinstance(data, netcdf.netcdf_variable):
            self.data = data.data
            # An explicitly supplied `units` argument wins over the file's
            if units is None:
                self.units = self._to_str(getattr(data, 'units', None))
            self.attributes.update(
                {key: self._to_str(val)
                 for key, val in data.__dict__['_attributes'].items()})

        # Make sure the dimension data is a numpy array
        if numpy.isscalar(self.data):
            self.data = numpy.array([self.data, ],
                                    dtype=getattr(self.data, "dtype", None))

        # Explicit attributes take precedence over those read from file
        if attributes is not None:
            self.attributes.update(attributes)

        self.attributes.update(dict(units=str(self.units)))

    @staticmethod
    def _to_str(value):
        """ Return `value` decoded to str if it is a Python 3 bytes object.

        scipy.io.netcdf can return text attributes as bytes under Python 3;
        under Python 2 `bytes is str`, so nothing is decoded there.
        """
        if isinstance(value, bytes) and not isinstance(value, str):
            return value.decode('utf-8', 'replace')
        return value

    def __getitem__(self, sliceobj):
        ''' Apply the slice object on the data (numpy.ndarray) and return a
        new Dimension with the same name, units and attributes '''
        return Dimension(self.data[sliceobj], dimname=self.dimname,
                         units=self.units, attributes=self.attributes)

    def info(self, detailed=False, file_out=None):
        """ Print brief info about the dimension

        Args:
            detailed (bool): if True, attributes and length of axis are
                also printed
            file_out (file-like, optional): where to print; default stdout
        """
        info_str = 'Dim: ' + self.dimname
        if numpy.isscalar(self.data):
            info_str += ' = ' + str(self.data)
        elif len(self.data) == 0:
            info_str += ' = (empty)'
        else:
            info_str += ' = ' + str(self.data[0]) + ':' + str(self.data[-1])
        info_str += ' Unit: ' + str(self.units)
        print(info_str, file=file_out)
        if detailed:
            print('Length=', str(len(self.data)), file=file_out)
            print('Attributes:', file=file_out)
            for attname, val in self.attributes.items():
                print("  {dimname}:{attname} = {val}".format(
                    dimname=self.dimname, attname=attname, val=val),
                      file=file_out)

    def getCAxis(self):
        """ Get cartesian axis (T/Z/Y/X) for a dimension instance

        If the dimension has an "axis" or a "cartesian_axis" attribute, the
        value of the attribute is returned.  Otherwise, the unit is used as
        a clue (via units.assign_caxis).  Returns None if neither helps.

        Example:
            dim.setattr("cartesian_axis","X")
            dim.getCAxis() --> "X"

        Example:
            dim.units = "months"
            dim.getCAxis() --> "T"

        Example:
            dim.units = "degreeN"
            dim.getCAxis() --> "Y"
        """
        atts = self.attributes
        cax = atts.get('axis', atts.get('cartesian_axis', None))
        if cax is None and self.units is not None:
            cax = units.assign_caxis(self.units)
        return cax

    def setattr(self, att, value):
        ''' Java style setter for attributes '''
        self.attributes[att] = value

    def getattr(self, att, default=None):
        ''' Java style getter for attributes '''
        return self.attributes.get(att, default)

    def is_monotonic(self):
        ''' Return True if the axis is strictly monotonic, False otherwise

        For a longitude (X) axis, a single wrap-around jump caused by the
        periodic boundary (e.g. 350, 0, 10) is tolerated.
        '''
        def strict_monotonic_func(data):
            diff = numpy.diff(data)
            return (diff > 0.).all() or (diff < 0.).all()

        strict_monotonic = strict_monotonic_func(self.data)
        if not strict_monotonic and self.getCAxis() == 'X':
            # Make sure it is not because of periodic boundary condition
            # of longitude
            x_diff = numpy.diff(self.data)
            if numpy.count_nonzero(x_diff > 0.) > \
               numpy.count_nonzero(x_diff < 0.):
                # More often increasing: the wrap is the largest drop
                jump = numpy.argmin(x_diff)
            else:
                # More often decreasing: the wrap is the largest rise
                jump = numpy.argmax(x_diff)
            # Roll so that the element right after the wrap comes first
            return strict_monotonic_func(numpy.roll(self.data, -(jump+1)))
        return strict_monotonic

    def is_climo(self):
        ''' Return True if the axis is a climatological time axis '''
        if self.getCAxis() != 'T':
            return False
        return all(x == y for x, y in zip(sorted(self.getDate("m", True)),
                                          range(1, 13)))

    def time2array(self):
        ''' Given a dimension object, if it is a time axis,
        return ndarray of size (N,6) where N is the number of time point,
        and the six indices represent:
        YEAR,MONTH,DAY,HOUR,MINUTE,SECOND

        Same as getDate()
        '''
        return self.getDate()

    def time0(self):
        ''' Return a datetime.datetime object referring to the t0 of a time
        axis (the reference date after " since " in the units)

        Raises:
            Exception: if the dimension is not a time axis
        '''
        if self.getCAxis() != 'T':
            raise Exception("This axis is not a time axis")
        startpoint = self.units.split(' since ')[-1]
        # choose the appropriate time format since some data
        # does not specify the hour/minute/seconds
        if len(startpoint.split()) > 1:
            if len(startpoint.split()[-1].split(':')) == 3:
                date_format = '%Y-%m-%d %H:%M:%S'
            else:
                startpoint = startpoint.split()[0]
                date_format = '%Y-%m-%d'
        else:
            date_format = '%Y-%m-%d'
        return datetime.datetime.strptime(startpoint, date_format)

    def getDate(self, toggle="YmdHMS", no_continuous_duplicate_month=False):
        ''' Return the time axis date in an array format of
        "Year,Month,Day,Hour,Minute,Second"

        Toggle one or many among Y/m/d/H/M/S to select a particular time
        format

        Args:
            toggle (iterable of str): each item should be among Y/m/d/H/M/S
            no_continuous_duplicate_month (bool): used for toggle=="m" only
                no_continuous_duplicate_month will check if there are
                adjacent months that are identical.  If so, check if the
                data is a monthly series and correct for the duplicates

        Raises:
            RuntimeError: not a time axis, or duplicated months cannot be
                corrected
            TypeError/ValueError: invalid toggle or time axis

        Examples:
            >>> # return an array of the month of the time axis
            >>> var.getDate("m")
            array([ 1, 2, 3, 4, 5, 6 ])
            >>> # return an array with the first column showing the years,
            >>> # second column showing the months, third column
            >>> # for days
            >>> getDate("Ymd")
            array([[ 1990, 1, 15 ], [ 1990, 2, 15 ], [ 1990, 3, 15 ]])
        '''
        #------------------
        # Sanity check
        #------------------
        if self.getCAxis() != 'T':
            raise RuntimeError("Dimension.getDate: not a time axis")

        if no_continuous_duplicate_month and toggle != 'm':
            raise RuntimeError("no_continuous_duplicate_month only "+\
                               "applies when toggle=m")

        try:
            _ = iter(toggle)
        except TypeError:
            raise TypeError("toggle has to be iterable:\"Y/m/d/H/M/S\"")

        if not all(t in "YmdHMS" for t in toggle):
            raise ValueError("toggle has to be one of \"Y/m/d/H/M/S\"")

        calendar = self.attributes.get('calendar', 'standard').lower()

        #--------------------------------------------------------
        # Convert time values to datetime objects using netCDF4
        #--------------------------------------------------------
        # Named time_unit to avoid shadowing the module-level `units`
        time_unit = self.units.split()[0]
        time_unit = time_unit if time_unit.endswith("s") else time_unit+"s"

        if time_unit != 'months':
            alltimes = _num2date(self.data, self.units, calendar)
            try:
                _ = iter(alltimes)
            except TypeError:
                alltimes = [alltimes, ]
        else:
            # netcdftime does not handle month as the unit
            if 'since' not in self.units:
                raise Exception("the dimension, assumed to be a time"+\
                                " axis, should have a unit such as "+\
                                "\"days since 01-JAN-2000\"")
            if not self.is_monotonic():
                raise ValueError("The axis is not monotonic!")
            if (numpy.diff(self.data) < 0.).all():
                raise ValueError("Time going backward...really?")
            t0 = self.time0()  # is a datetime.datetime object
            # at this point we knew the unit is month
            alltimes = [t0 + relativedelta(months=int(t)) for t in self.data]

        # Convert flag to attribute names
        flag2attr = dict(Y="year", m="month", d="day",
                         H="hour", M="minute", S="second")

        #-------------------------------------
        # Case 1: Return everything toggled
        #-------------------------------------
        if toggle != "m":
            return numpy.array([[getattr(t, flag2attr[flag])
                                 for flag in toggle]
                                for t in alltimes]).squeeze()

        # toggle == "m" - Extract months only
        all_months = numpy.array([t.month for t in alltimes])

        if (numpy.diff(all_months) != 0).all():
            #-------------------------------------------------
            # Case 2: no continuous duplicate month, return
            #-------------------------------------------------
            return all_months

        if not no_continuous_duplicate_month:
            #---------------------------------------------------------
            # Case 3: not correcting duplicate month, warn and return
            #---------------------------------------------------------
            logger.warning("There are continuous duplicated months "+\
                           "but not correcting for them.")
            return all_months

        #------------------------------------------------------------------
        # Case 4: there are continuous duplicate months, and the user wants
        # to correct for them; useful for monthly data analysis
        #------------------------------------------------------------------
        # Are we really dealing with monthly data?
        # Length of one month expressed in the axis' units
        month_delta = _date2num(self.time0() + relativedelta(months=1),
                                self.units, calendar)
        # Average time step
        avg_dt = numpy.diff(self.data).mean()
        if avg_dt/month_delta < 0.5 or avg_dt/month_delta > 1.5:
            raise RuntimeError("There are continuous duplicated months "+\
                               "and no_continuous_duplicate_month is "+\
                               "True.  However it does not seem to be "+\
                               "a monthly time series.")

        # OK.  So we are dealing with monthly data.
        # If most of the time stamps are close to the end of a calendar
        # month, the problem can be fixed by rolling the time axis
        # backward or forward by half a month
        all_days = numpy.array([t.day for t in alltimes])
        near_month_edge = numpy.logical_or(all_days > 25, all_days < 5)
        if not numpy.count_nonzero(near_month_edge) > len(all_days)/2:
            raise RuntimeError("There are duplicated months and "+\
                               "no_continuous_duplicate_month is True. "+\
                               "But the calendar days provide no hint "+\
                               "for correction.")

        move_backward = numpy.count_nonzero(all_days > 25) > \
                        numpy.count_nonzero(all_days < 5)
        move_delta = relativedelta(days=15)

        if isinstance(alltimes[0], _netCDF4_datetime):
            date2num_f = partial(_date2num, units=self.units,
                                 calendar=calendar)
            num2date_f = partial(_num2date, units=self.units,
                                 calendar=calendar)
            # Half a month expressed in the axis' units
            shift = date2num_f(self.time0()+move_delta)
            if move_backward:
                alltimes = num2date_f(date2num_f(alltimes) - shift)
            else:
                alltimes = num2date_f(date2num_f(alltimes) + shift)
        else:
            # alltimes are python datetime.datetime object
            # can be added to relativedelta directly
            if move_backward:
                alltimes = numpy.array(alltimes)-move_delta
            else:
                alltimes = numpy.array(alltimes)+move_delta

        all_months = numpy.array([t.month for t in alltimes])
        if (numpy.diff(all_months) != 0).all():
            logger.warning("Months are computed by shifting the time "+\
                           "axis {}".format("backward" if move_backward
                                            else "forward"))
            return all_months
        raise RuntimeError("Failed to correct for continuous "+\
                           "duplicated months")
class Variable(object):
    """ A container for handling physical variable together with its
    dimensions so that while the variable is manipulated (e.g. averaged
    along one axis), the information of the dimensions change accordingly.

    It can be indexed/sliced the same way as indexing a numpy array
    """
    def __init__(self, reader=None, varname=None, data=None, dims=None,
                 attributes=None, history=None, parent=None,
                 ensureMasked=False, **kwargs):
        """
        Attributes:
            data (numpy.ndarray or numpy.ma.MaskedArray): Data array of the
                variable
            varname (str): Name of the variable
            dims (list of Dimension instances): Dimensions of the variable
                consistent with the shape of the data array
            attributes (dict): Attributes of the variable (e.g. units)

        Arguments:
            reader (netcdf.netcdf_file, optional): if given, the variable
                is read from the NetCDF file
            varname (str) : variable name
            data (numpy.ndarray or numpy.ma.MaskedArray): data array
            dims (a list of Dimension) : dimensions
            attributes (dict): attributes of the variables
            history (str): to be stored/appended to attributes['history']
            parent (Variable): from which varname, dims and attributes are
                copied; Copied `varname` and `dims` can be overwritten by
                assigning values in the arguments.  If `attributes` is
                copied from `parent`, the dictionary assigned to the
                argument `attributes` is used to update the copied
                `attributes`.  `parent` is left unchanged.
            ensureMasked (bool): whether the array is masked using its
                missing value (see getMissingValue) upon initialization.
                default: False

        Other keyword arguments are passed to setRegion

        Raises:
            KeyError: `varname` is not a variable of `reader`
            ValueError: `reader` given without `varname`, or the data
                shape does not match the dimensions
            AttributeError: data/varname/dims missing without reader/parent

        Examples:
            >>> var = Variable(netcdf.netcdf_file("sst.nc"),"temperature")
            >>> var = Variable(netcdf.netcdf_file("sst.nc"),"temperature",
                               lat=(-5.,5.),lon=(-170.,-120))
            >>> # Copy varname, dims, attributes from var
            >>> # If the dimension shape does not match data shape, raise an Error
            >>> var2 = Variable(data=numpy.array([1,2,3,4]),parent=var)
            >>> var = Variable(data=numpy.array([1,2,3,4]),
                               dims=[Dimension(data=numpy.array([0.,1.,2.,3.]),
                                               dimname='x')],
                               varname='name')
        """
        # Initialize the most basic properties.
        # Anything else goes to the attribute dictionary
        self.data = data
        self.dims = dims  # a list of Dimension instances
        self.varname = varname
        self.attributes = {}
        self._ncfile = None

        if reader is not None and isinstance(reader, netcdf.netcdf_file):
            if varname is None:
                raise ValueError("varname is required when reading a file")
            if varname not in reader.variables:
                raise KeyError('Unknown variable name. Available variables: '+\
                               ','.join(reader.variables.keys()))
            varobj = reader.variables[varname]
            self.data = getattr(varobj, "data", None)
            self.dims = []
            for idim, dimname in enumerate(varobj.dimensions):
                if dimname in reader.variables:
                    self.dims.append(Dimension(reader.variables[dimname],
                                               dimname))
                else:
                    # No coordinate variable: use 1-based indices (the same
                    # convention as getAxes) so the shape matches the data
                    self.dims.append(Dimension(
                        numpy.arange(1, varobj.shape[idim]+1), dimname))
            self.attributes.update(
                {key: self._to_str(val)
                 for key, val in varobj.__dict__['_attributes'].items()})
            self.addHistory('From file: ' + str(reader.filename))
            self._ncfile = reader
        elif parent is not None:
            self._copy_from_parent_(parent)
        else:
            # no recognizable reader or parent variable is given;
            # data, varname should not be None
            if data is None:
                raise AttributeError('data is not provided')
            if varname is None:
                raise AttributeError('varname is not provided')

        # If parent is given, these will overwrite the properties
        # copied from parent
        # If parent is not given, the following initializes the instances
        if data is not None:
            self.data = data
        if dims is not None:
            self.dims = dims
        if varname is not None:
            self.varname = varname
        if attributes is not None:
            self.attributes.update(attributes)

        if self.dims is None:
            raise AttributeError("dims (dimensions) is not provided")

        if history is not None:
            self.addHistory(history)

        self.setRegion(**kwargs)
        self.masked = False

        # This is the one that takes the time while initializing variables
        if ensureMasked:
            self.ensureMasked()

        # Check to make sure the variable data shape matches the dimensions'
        if not self.is_shape_matches_dims():
            raise ValueError("Dimension mismatch.")

    @staticmethod
    def _to_str(value):
        """ Return `value` decoded to str if it is a Python 3 bytes object.

        scipy.io.netcdf can return text attributes as bytes under Python 3;
        under Python 2 `bytes is str`, so nothing is decoded there.
        """
        if isinstance(value, bytes) and not isinstance(value, str):
            return value.decode('utf-8', 'replace')
        return value

    def is_shape_matches_dims(self):
        ''' Return True if the shape of the data matches the sizes of the
        dimensions, False otherwise '''
        dim_shape = tuple([dim.data.size for dim in self.dims])
        return self.data.shape == dim_shape

    def addHistory(self, string):
        """ Append to the history attribute in the variable.
        If history doesn't exist, create one
        """
        history = self._to_str(self.attributes.get('history', ''))
        self.setattr('history', history + '; ' + string)

    def __repr__(self):
        return "<{}.{} ".format(__name__, type(self).__name__) + \
               self.varname + \
               '(' + ",".join(self.getDimnames()) + '), shape: ' + \
               str(self.data.shape) + '>'

    def info(self, detailed=False, file_out=None):
        """ Print brief info about the variable: name, dimensions, shape,
        attributes and the info of each dimension """
        # varname, dim, shape
        print(self.__repr__(), file=file_out)

        # Attributes:
        print("Attributes:", file=file_out)
        for attname, val in self.attributes.items():
            print("    {varname}:{attname} = {val}".format(
                varname=self.varname, attname=attname, val=val),
                  file=file_out)

        # Dimension info:
        for dim in self.dims:
            dim.info(detailed=detailed, file_out=file_out)

    def getCAxes(self):
        """ get the cartesian axes for all the dimensions.

        Return a list of cartesian axes.
        if it is undefined, replace with dummy: A,B,C,...(excludes: T/Z/X/Y)
        """
        dummies = list('ABCDEFGHIJKLMNOPQRSUVW')
        caxes = []
        for dim in self.dims:
            # Using try-catch is clearly not ideal
            # Previously the try block was an if-statement that
            # getCAxis is called only if dim is an instance of Dimension.
            # However when the module is reloaded,
            # objects created before reloading are no longer an
            # instance of the reloaded module's Dimension
            try:
                cax = dim.getCAxis()
            except AttributeError:
                cax = None
            if cax is None:
                cax = dummies.pop(0)
            caxes.append(cax)
        return caxes

    def getDimnames(self):
        """Return a list of dimension names """
        return [dim.dimname for dim in self.dims]

    def setattr(self, att, value):
        '''Set the value of an attribute of the variable

        Java style setter
        '''
        self.attributes[att] = value

    def getattr(self, att, default=None):
        ''' Return the value of an attribute of the variable

        Java style getter
        '''
        return self.attributes.get(att, default)

    def getAxes(self):
        ''' Return the dimensions of the variable as a list of numpy arrays
        In the order of dimension

        Dimensions without data are given 1-based indices (stored back
        into the Dimension).
        '''
        axes = []
        for idim, dim in enumerate(self.dims):
            if dim.data is None:
                dim.data = numpy.arange(1, self.data.shape[idim]+1)
            axes.append(dim.data)
        return axes

    def getIAxis(self, axis):
        '''Return the integer for the required cartesian axis

        Input:
            axis (int or str): if it is an integer, do nothing and return
                axis; if it is a str, look for index of the dimension which
                matches the required cartesian axis using getCAxes

        Returns:
            int

        Raises:
            KeyError: no dimension matches the cartesian axis
            ValueError: axis is neither an integer nor a string

        See Also: getCAxes, getAxis, getDim
        '''
        if isinstance(axis, (int, numpy.integer)):
            return int(axis)

        if isinstance(axis, str):
            caxes = self.getCAxes()
            axis = _genereal_axis(axis)
            if axis not in caxes:
                raise KeyError(self.varname+" has no "+axis+" axis")
            return caxes.index(axis)

        raise ValueError("axis has to be either an integer or a string")

    def getAxis(self, axis):
        ''' Return a numpy array of an axis of a variable

        Input:
            axis (int or str): see getIAxis

        Returns:
            numpy array

        See Also: getCAxes, getIAxis, getDim
        '''
        return self.getDim(axis).data

    def getDim(self, axis):
        ''' Return the Dimension instance of an axis of a variable

        Input:
            axis (int or str): see getIAxis

        Returns:
            Dimension

        See Also: getCAxes, getIAxis, getAxis
        '''
        return self.dims[self.getIAxis(axis)]

    def getDomain(self, axis=None):
        ''' Return the domain of the variable

        If the axis is a longitude axis, make all negative degree positive
        (only for output; the variable longitude data is unchanged)

        Args:
            axis (str or int or a list of them): query the domain of
                particular dimension(s).  A str is read one character per
                cartesian axis (e.g. "XY").  If it is not specified, the
                domains of all dimensions are returned

        Returns:
            dict: cartesian axis -> (min, max)

        Examples:
            >>> # var is a regional variable within (20S-20N, 140E-140W)
            >>> var.getDomain()
            {"X": (140.,220.), "Y": (-20.,20.)}
            >>> var.getDomain("X")
            {"X": (140.,220.)}
        '''
        caxes = self.getCAxes()
        if axis is None:
            axis = caxes
        elif isinstance(axis, (int, numpy.integer)):
            axis = [axis]

        domain = {}
        for ax_name in axis:
            iax = self.getIAxis(ax_name)
            cax = caxes[iax]
            coor = numpy.asarray(self.getAxis(iax))
            if cax == 'X':
                coor = coor.copy()
                coor[coor < 0.] += 360.
            domain[cax] = (coor.min(), coor.max())
        return domain

    def _copy_from_parent_(self, parent):
        """ Copy the dimensions, attributes and varname from a parent
        variable.  Use copy.copy (shallow) instead of deepcopy """
        if not isinstance(parent.dims, list):
            raise TypeError("parent.dims must be a list")
        if any(not isinstance(dim, Dimension) for dim in parent.dims):
            raise TypeError("parent.dims must be a list of Dimension instance")
        self.dims = copy.copy(parent.dims)

        if not isinstance(parent.attributes, dict):
            raise TypeError("parent.attributes must be a dict instance")
        self.attributes = copy.copy(parent.attributes)

        if not isinstance(parent.varname, str):
            raise TypeError("parent.varname must be a string instance")
        self.varname = copy.copy(parent.varname)

    def _broadcast_dim_(self, other, result):
        ''' Return a list of dimensions suitable for operations
        (__add__...) between self and other

        Args:
            other (Variable or numpy.ndarray)
            result (numpy.ndarray) : the result of an operation e.g. __add__

        Returns:
            a list of Dimension

        Example:
            varA <Variable with shape (12,10)>
            varB <Variable with shape (1,10)>
            varC = varA + varB
            varC <Variable with shape (12,10)> inherits dimensions from varA
            But varC = varB + varA would require inheriting the first
            dimension of varA and the second dimension of varB
        '''
        dims = []
        for idim, (size1, size2) in enumerate(zip(self.data.shape,
                                                  result.shape)):
            if size1 == 1 and size2 > 1:
                dims.append(other.dims[idim])
            else:
                dims.append(self.dims[idim])
        return dims

    def _binary_op_(self, other, func, symbol, reflected=False):
        ''' Apply the element-wise `func` between self and other and wrap
        the result in a new Variable whose history is "<a><symbol><b>".

        If `reflected` is True, the operands are swapped (other OP self),
        as needed by the __r*__ methods.
        '''
        var1 = _getdata(self)
        var2 = _getdata(other)
        name1 = getattr(self, 'varname', str(self))
        name2 = getattr(other, 'varname', str(other))
        if reflected:
            data = func(var2, var1)
            history = name2 + symbol + name1
        else:
            data = func(var1, var2)
            history = name1 + symbol + name2
        return Variable(data=data, dims=self._broadcast_dim_(other, data),
                        parent=self, history=history)

    def __sub__(self, other):
        return self._binary_op_(other, lambda a, b: a - b, '-')

    def __rsub__(self, other):
        return self._binary_op_(other, lambda a, b: a - b, '-',
                                reflected=True)

    def __add__(self, other):
        return self._binary_op_(other, lambda a, b: a + b, '+')

    def __radd__(self, other):
        return self.__add__(other)

    def __div__(self, other):
        return self._binary_op_(other, lambda a, b: a / b, '/')

    def __rdiv__(self, other):
        return self._binary_op_(other, lambda a, b: a / b, '/',
                                reflected=True)

    # Python 3 uses __truediv__/__rtruediv__ for the `/` operator
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __mul__(self, other):
        return self._binary_op_(other, lambda a, b: a * b, '*')

    def __rmul__(self, other):
        return self.__mul__(other)

    def __call__(self, **region):
        '''Same as self.getRegion'''
        return self.getRegion(**region)

    def __getitem__(self, sliceobj):
        a = Variable(data=self.data, varname=self.varname, parent=self)
        sliceobj = numpy.index_exp[sliceobj]
        a.slicing(sliceobj)
        a.addHistory('__getitem__['+str(sliceobj)+']')
        return a

    def __setitem__(self, sliceobj, val):
        # A dict is interpreted as a region specification (see getSlice)
        if isinstance(sliceobj, dict):
            sliceobj = self.getSlice(**sliceobj)
        self.data[sliceobj] = val

    def getRegion(self, **kwargs):
        ''' Return a new Variable object within the region specified.

        Values have to be a length-2 iterable that specifies the range

        Keys "time","t","TIME","T" are all considered as "T" for time axis.
        Keys "x","X","lon","LON","longitude","LONGITUDE" are all considered
        as "X" for the longitude axis or an axis with an attribute of
        "cartesian_axis" set to "X"
        Keys "y","Y","lat","LAT","latitude","LATITUDE" are all considered
        as "Y" for latitude axis or an axis with an attribute of
        "cartesian_axis" set to "Y"

        Examples:
            >>> # Extracts the region where -20. <= latitude <= 20.
            >>> # and 100. <= longitude <= 200.
            >>> var.getRegion(lat=(-20.,20.), lon=(100.,200.))
        '''
        a = Variable(data=self.data, parent=self)
        a.setRegion(**kwargs)
        return a

    def getSlice(self, **kwargs):
        ''' Return a tuple of slice object corresponding a region specified,
        or None if no region is given.

        Example: variable.getSlice(lat=(-30.,30.))
        '''
        region = _general_region(kwargs)
        if region:
            return self._create_slice(region)
        return None

    def setRegion(self, **kwargs):
        ''' Change the region of interest for the variable
        This function slices the data in place and returns self.
        '''
        region = _general_region(kwargs)
        if region:
            self.slicing(self._create_slice(region))
            self.addHistory('setRegion('+str(region)+')')
        return self

    def setRegion_value(self, value, **kwargs):
        ''' Set values for a particular region

        Example: variable.setRegion_value(0.,lat=(-90.,-30.))
        '''
        sl = self.getSlice(**kwargs)
        self[sl] = value
        return self

    def _create_slice(self, region=None):
        ''' Generate a tuple of slice object for the given region
        specifications (keys are cartesian axes, see _general_region) '''
        if region is None or len(region) == 0:
            return (slice(None),)*self.data.ndim

        sliceobjs = []
        axes = self.getAxes()
        caxes = self.getCAxes()
        for iax, axis in enumerate(caxes):
            sliceobj = region.get(axis, slice(None))
            # sliceobj is a single value
            if numpy.isscalar(sliceobj):
                sliceobj = (sliceobj, sliceobj)
            if not isinstance(sliceobj, (slice, numpy.ndarray)):
                # Set modulo, if unset and axis is longitude, use 360 degree
                modulo = self.dims[iax].attributes.get('modulo', None)
                if axis == "X" and modulo is None:
                    modulo = 360.
                sliceobj = arrays.getSlice(axes[iax], sliceobj[0],
                                           sliceobj[1], modulo=modulo)
            sliceobjs.append(sliceobj)
        return tuple(sliceobjs)

    def slicing(self, sliceobj):
        ''' Perform the slicing operation on both the data and axes

        Args:
            sliceobj (tuple): slice object

        Returns:
            None

        Raises:
            ValueError: if the sliced data no longer matches the dimensions
        '''
        ndim = self.data.ndim
        self.data = self.data[sliceobj]

        num_newaxis = sum(1 for sl in sliceobj if sl is None)
        num_ellipsis = sum(1 for sl in sliceobj if sl is Ellipsis)

        # replace Ellipsis with a number of slice(None)
        # such that len(sliceobj) = ndim + num_newaxis
        # but subsequent Ellipsis should be replaced with one slice(None) only
        # This list is only needed if there is any Ellipsis at all
        slice_None = [[slice(None)]]*(num_ellipsis-1) + \
                     [[slice(None)]*(ndim+num_newaxis-len(sliceobj)+1)]
        new_sliceobj = []
        for sl in sliceobj:
            if sl is Ellipsis:
                new_sliceobj += slice_None.pop()
            else:
                new_sliceobj.append(sl)

        # Slice the Dimensions
        for iax, sl in enumerate(new_sliceobj):
            if sl is None:
                # numpy.newaxis is asked
                # create a dummy dimension
                self.dims.insert(iax, Dimension(data=numpy.nan))
            else:
                self.dims[iax] = self.dims[iax][sl]

        # If slice is an integer, numpy would squeeze the array
        # Add the singlet dimension back
        newaxis_list = [numpy.newaxis
                        if isinstance(sl, (int, numpy.integer))
                        else slice(None) for sl in new_sliceobj]
        if newaxis_list:
            # Must be a tuple: indexing with a list is an error in NumPy
            self.data = self.data[tuple(newaxis_list)]

        # Make sure the dimensions still match
        if not self.is_shape_matches_dims():
            raise ValueError("Dimension mismatch.")

    def getLatitude(self):
        ''' Return a numpy array that contains the latitude axis '''
        return self.getAxis('Y')

    def getLongitude(self):
        ''' Return a numpy array that contains the longitude axis '''
        return self.getAxis('X')

    def getTime(self):
        ''' Return a numpy array that contains the time axis '''
        return self.getAxis('T')

    def apply_mask(self, mask):
        ''' Mask the variable's last axes with a mask.

        Returns a new Variable (see the module-level apply_mask); this
        variable is left unchanged.
        '''
        return apply_mask(self, mask)

    def climatology(self, *args, **kwargs):
        ''' Compute the climatology '''
        return climatology(self, *args, **kwargs)

    def zonal_ave(self):
        ''' Compute the zonal average

        Same as wgt_ave('X')
        '''
        return wgt_ave(self, 'X')

    def time_ave(self):
        ''' Compute the time average

        Same as wgt_ave('T')
        '''
        return wgt_ave(self, 'T')

    def lat_ave(self):
        ''' Compute meridional average

        Same as wgt_ave('Y')
        '''
        return wgt_ave(self, 'Y')

    def area_ave(self):
        ''' Compute area average

        Same as wgt_ave('XY')
        '''
        return wgt_ave(self, 'XY')

    def wgt_ave(self, axis=None):
        ''' Compute average on one or more axes

        Input:
            axis - either integer or a string (T/X/Y/Z/...)
                   See getCAxes
        '''
        return wgt_ave(self, axis)

    def getMissingValue(self):
        ''' Return "missing_value" if defined in the attributes
        Otherwise "_FillValue" will be used as missing value
        If both are undefined, the numpy default for the variable data
        type is returned
        '''
        # `is not None` checks so that a legitimate fill value of 0 is used
        for candidate in (self.getattr('missing_value', None),
                          self.getattr('_FillValue', None)):
            if candidate is not None:
                return candidate
        # numpy.asscalar was removed in NumPy 1.23; .item() is equivalent
        return numpy.asarray(numpy.ma.default_fill_value(self.data)).item()

    def ensureMasked(self):
        ''' If the data in the variable is not a masked array and
        missing_value is present, read the data and convert the numpy
        ndarray into masked array.  Note that this will be slow if the
        data is large.  But this will only be done once.

        Returns:
            None

        Raises:
            AssertionError: if the data could not be turned into a masked
                array
        '''
        if self.masked:
            return None

        missing_value = self.getMissingValue()
        if isinstance(self.data, numpy.ndarray):
            # masked_values fills already-masked entries with the value
            # before comparing, so an existing mask is preserved
            self.data = numpy.ma.masked_values(self.data, missing_value)

        if not isinstance(self.data, numpy.ma.MaskedArray):
            raise AssertionError("Missing value: {}".format(missing_value))

        self.setattr('_FillValue', missing_value)
        self.masked = True
        return None

    def runave(self, N, axis=0, step=None):
        '''Running mean along an axis.  N specifies the size of the window

        Args:
            N (int or float): size of the window
                if axis is int, N is treated as the number of array
                elements along the axis
                if axis is str, N is treated as the absolute value of the
                size of window on the axis (converted to a number of
                elements using the mean grid spacing, then made odd)
            axis (int or str): axis on which running mean is computed
            step (int): how many array element is skipped for each sample

        Return:
            Variable

        Examples:
            # Running average for every 5 elements on the first axis
            >>> var.runave(5, 0)
            # Running average with a window of longitudinal-width of 40-degree
            >>> var.runave(40., "X")
            # Climatological running average with a window of 3 years
            # axis=0 for the time axis
            >>> var.runave(3, 0, step=12)
        '''
        self.ensureMasked()
        cartesian_axes = self.getCAxes()
        history = 'runave('+str(N)+','+str(axis)+',step='+str(step)+')'
        if isinstance(axis, str):
            axis = cartesian_axes.index(axis.upper())
            if self.dims[axis].is_monotonic():
                spacing = numpy.abs(numpy.diff(self.getAxes()[axis]).mean())
                # A window in axis units -> a whole number of elements
                N = int(round(N/spacing))
            else:
                logger.warning(
                    "{var}'s {dim} is not monotonic. N is treated as "
                    "integer".format(var=self.varname,
                                     dim=self.dims[axis].dimname))

        if not isinstance(N, (int, numpy.integer)):
            raise Exception("N is treated as step. It has to be an integer.")

        if N % 2 != 1:
            N = N + 1

        return Variable(data=stat.runave(self.data, N, axis, step),
                        parent=self, history=history)

    def squeeze(self):
        ''' Remove singlet dimensions; returns a new Variable '''
        var = Variable(data=self.data, parent=self)
        shape = var.data.shape
        var.dims = [dim for dim, size in zip(var.dims, shape) if size > 1]
        var.data = var.data.squeeze()
        assert var.data.ndim == len(var.dims)
        var.addHistory('squeeze()')
        return var

    def getDate(self, toggle="YmdHMS", no_continuous_duplicate_month=False):
        ''' Return the time axis date in an array format of
        "Year,Month,Day,Hour,Minute,Second"

        Toggle one or many among Y/m/d/H/M/S to select a particular time
        format.  Delegates to Dimension.getDate of the time axis.

        Args:
            toggle (iterable of str): each item should be among Y/m/d/H/M/S
            no_continuous_duplicate_month (bool): used for toggle=="m" only
                no_continuous_duplicate_month will check if there are
                adjacent months that are identical.  If so, check if the
                data is a monthly series and correct for the duplicates

        Raises:
            Exception: if the variable has no recognized time axis

        Examples:
            >>> # return an array of the month of the time axis
            >>> var.getDate("m")
            array([ 1, 2, 3, 4, 5, 6 ])
            >>> # return an array with the first column showing the years,
            >>> # second column showing the months, third column
            >>> # for days
            >>> getDate("Ymd")
            array([[ 1990, 1, 15 ], [ 1990, 2, 15 ], [ 1990, 3, 15 ]])
        '''
        caxes = self.getCAxes()
        if 'T' not in caxes:
            raise Exception("There is no recognized time axis in Variable:"+\
                            self.varname)
        return self.dims[caxes.index('T')].getDate(
            toggle=toggle,
            no_continuous_duplicate_month=no_continuous_duplicate_month)
def _getdata(other):
    ''' Return the data array behind *other* if it is a Variable

    If the input is a Variable, run its ensureMasked method and return its
    data attribute.  Otherwise return the input unchanged.

    Used by __sub__, __add__, ...
    '''
    if isinstance(other, Variable):
        other.ensureMasked()
        return other.data
    return other


def apply_mask(var, mask):
    ''' Mask the trailing axes of the data with a boolean mask array

    The data of var is copied into a new Variable, which is masked and
    returned.  var itself is left untouched.

    Args:
        var (Variable)
        mask (array of bool): True where data should be masked.  It indexes
            the trailing dimensions of var.data (``data[..., mask]``).

    Returns:
        Variable

    Example:
        apply_mask(v, land_mask > 0.)
    '''
    newvar = Variable(data=var.data.copy(), parent=var)
    newvar.ensureMasked()
    newvar.data[..., mask] = numpy.ma.masked
    return newvar
def nc_cal(func):
    ''' A decorator that lifts a numpy-array function to Variables

    The decorated function is called with var.data (after
    var.ensureMasked()) followed by any extra positional and keyword
    arguments.  Its return value becomes the data of a new Variable whose
    parent is var.  The history records the function name and arguments.

    Accept only functions that operate on a numpy array.
    '''
    @wraps(func)
    def newfun(var, *args, **kwargs):
        history = "".join([func.__name__, args.__str__(), kwargs.__str__()])
        var.ensureMasked()
        return Variable(data=func(var.data, *args, **kwargs),
                        parent=var, history=history)
    return newfun
def wgt_ave(var, axis=None, lat_weighted=True):
    '''A more general routine for averaging

    The axes to be averaged over are resolved to indices.  Each X/Y/Z axis
    being averaged is weighted by its grid spacing (numpy.gradient of the
    dimension values).  Masked data points are excluded from the
    normalisation.  If both X and Y are averaged over and lat_weighted is
    True, an additional weight from stat.lat_weights(latitude) accounts
    for the convergence of meridians.  If no axis is given, all axes are
    averaged over.

    Args:
        var (Variable)
        axis (int, str or an iterable of int/str, optional): the axes along
            which the average is computed, e.g. 0, "xy", [0, "T"].
            Default: all axes
        lat_weighted (bool, default True): whether the latitudinal weight
            is applied for an area average.  The Y axis is assumed to have
            unit=degree

    Returns:
        Variable: the averaged axes are kept with length 1

    Example:
        >>> wgt_ave(var, 'xy')   # area average
    '''
    from functools import reduce  # not a builtin in Python 3

    var.ensureMasked()
    data = var.data
    ndim = data.ndim
    cartesian_axes = var.getCAxes()
    if axis is None:
        axis = range(len(cartesian_axes))
    # If the input axis is a single integer, convert it into a list
    if type(axis) is int:
        axis = [axis]
    history = 'wgt_ave(axis='+','.join([str(ax) for ax in axis])+')'
    if type(axis) is str:
        axis = axis.upper()
    # Axes may be given by index or by cartesian name ("X", "y", ...)
    axis = [cartesian_axes.index(ax.upper()) if isinstance(ax, str) else ax
            for ax in axis]

    # apply varied lat_weights only if 'XY' are included
    caxes = [cartesian_axes[ax] for ax in axis]
    if 'X' in caxes and 'Y' in caxes and lat_weighted:
        # Broadcast the 1-D latitude weights along every non-Y axis;
        # multi-dimensional indexing requires a tuple, not a list
        sliceobj = tuple(numpy.newaxis if cax != 'Y' else slice(None)
                         for cax in cartesian_axes)
        if "degree" not in var.getDim("Y").units:
            logger.warning("Area mean is weighted by Y axis and Y is assumed"+\
                           " to have unit=degreeN/degreeE")
        lat_weights = stat.lat_weights(var.getLatitude())[sliceobj]
    else:
        lat_weights = 1.

    for iax in axis:
        if cartesian_axes[iax] in 'XYZ':
            assert var.dims[iax].is_monotonic()

    def _grid_weights(iax):
        '''1-D weights on axis iax: grid spacing if averaged X/Y/Z, else ones'''
        dim_data = var.dims[iax].data
        # numpy.gradient needs at least two points
        if iax in axis and cartesian_axes[iax] in 'XYZ' and len(dim_data) > 1:
            return numpy.gradient(dim_data)
        # float so that later multiplications never need a narrowing cast
        return numpy.ones_like(dim_data, dtype=float)

    # Outer product of the 1-D weights -> an array shaped like data
    weights = reduce(lambda x, y: x[..., numpy.newaxis]*y,
                     [_grid_weights(iax) for iax in range(ndim)])
    weights = weights*lat_weights
    # Masked data points must not contribute to the normalisation
    weights = numpy.ma.array(weights, mask=numpy.ma.getmaskarray(data))
    data = keepdims.mean(data*weights, axis=axis)/\
           keepdims.mean(weights, axis=axis)

    dims = [Dimension(numpy.array([1,], dtype='i4'),
                      var.dims[iax].dimname,
                      units=var.dims[iax].units)
            if iax in axis else var.dims[iax]
            for iax in range(ndim)]

    return Variable(data=data.astype(var.data.dtype), dims=dims,
                    parent=var, history=history)
def wgt_sum(var, axis=None):
    '''A more general routine for sum

    The axes to be summed over are resolved to indices.  For an area sum
    (both X and Y summed over), each point is weighted by
    stat.lat_weights(latitude) to account for the convergence of
    meridians.  Other sums are not weighted.  If no axis is given, all axes
    are summed over.

    Args:
        var (Variable)
        axis (str/int/list of int, optional): along which the array is summed

    Returns:
        Variable: the summed axes are kept with length 1

    Examples:
        >>> # Area sum
        >>> wgt_sum(var,'xy')
        >>> # Sum along the first axis
        >>> wgt_sum(var,0)
    '''
    var.ensureMasked()
    data = var.data
    ndim = data.ndim
    caxes = var.getCAxes()
    dimnames = var.getDimnames()
    if axis is None:
        axis = range(len(dimnames))
    # If the input axis is a single integer, convert it into a list
    if type(axis) is int:
        axis = [axis]
    history = 'wgt_sum(axis='+','.join([str(ax) for ax in axis])+')'
    if type(axis) is str:
        axis = [caxes.index(ax) for ax in axis.upper()]

    # Latitude weights only make sense for an area sum.  Checking the
    # variable's axes instead would weight e.g. a pure time sum.
    summed_caxes = [caxes[iax] for iax in axis]
    if 'X' in summed_caxes and 'Y' in summed_caxes:
        sliceobj = tuple(numpy.newaxis if cax != 'Y' else slice(None)
                         for cax in caxes)
        weights = stat.lat_weights(var.getLatitude())[sliceobj]
    else:
        weights = 1.

    data = keepdims.sum(data*weights, axis=axis)
    dims = [Dimension(numpy.array([1,], dtype='i4'),
                      var.dims[iax].dimname,
                      units=var.dims[iax].units)
            if iax in axis else var.dims[iax]
            for iax in range(ndim)]

    return Variable(data=data.astype(var.data.dtype), dims=dims,
                    parent=var, history=history)
def time_input_to_datetime(time, calendar, units):
    ''' Convert a time given as datetime, string or number to a datetime

    Args:
        time (datetime.datetime, str or number): a datetime is returned
            as is; a string must be "%Y-%m-%d %H:%M:%S" or "%Y-%m-%d";
            anything else is passed to num2date with units and calendar
        calendar (str): e.g. "julian"
        units (str): e.g. "days since 0001-01-01"

    Example:
        time_input_to_datetime("1999-01-01 00:00:00", "julian",
                               "days since 0001-01-01")
    '''
    if isinstance(time, datetime.datetime):
        return time
    if not isinstance(time, str):
        return _num2date(time, units=units, calendar=calendar)
    # Try the full timestamp first, then fall back to a date only
    try:
        parsed = datetime.datetime.strptime(time, '%Y-%m-%d %H:%M:%S')
    except ValueError:
        parsed = datetime.datetime.strptime(time, '%Y-%m-%d')
    return parsed
def time_array_to_dim(time_array, calendar, units, **kwargs):
    ''' Build a time Dimension from rows of calendar fields

    Args:
        time_array: rows of [ year, month, day, hour, minute, second ]
            (integers, as required by the "{:02d}" formatting)
        calendar (str): e.g. "standard", "julian"
        units (str): e.g. "days since 0001-01-01"

    Other keyword arguments are passed to Dimension.

    Returns:
        Dimension: numeric time values in the given units, with the
        calendar stored in its attributes
    '''
    def _row_to_num(row):
        # Go through the "%Y-%m-%d %H:%M:%S" text form that
        # time_input_to_datetime parses
        stamp = "{:04d}-{:02d}-{:02d} {:02d}:{:02d}:{:02d}".format(*row)
        moment = time_input_to_datetime(stamp, calendar=calendar, units=units)
        return _date2num(moment, calendar=calendar, units=units)

    times = numpy.array([_row_to_num(row) for row in time_array])
    return Dimension(data=times, units=units,
                     attributes={'calendar': calendar}, **kwargs)
def create_monthly(calendar, units, time0, time_end=None):
    ''' Return a generator of monthly time values

    The generator yields scalar time values (in the given units and
    calendar), starting at time0 and stepping by an approximate number of
    days between consecutive mid-months.  If time_end is given, iteration
    stops before reaching it.  Otherwise the generator never stops.

    Args:
        calendar (str): CF calendar name, e.g. "standard", "julian",
            "noleap"/"365_day", "all_leap"/"366_day", "360_day",
            "proleptic_gregorian"
        units (str): of the form "UNIT since DATETIME",
            e.g. "days since 0001-01-01 00:00:00"
        time0: first time value (datetime, str or number; see
            time_input_to_datetime)
        time_end (optional): end time (exclusive)

    TODO: rewrite using relativedelta
    '''
    time0 = time_input_to_datetime(time0, calendar=calendar, units=units)
    if time_end is not None:
        time_end = time_input_to_datetime(time_end, calendar=calendar,
                                          units=units)
    calendar = calendar.lower()

    def isleap(year):
        '''Leap-year rule of the CF calendar in use'''
        if calendar in ('noleap', '365_day'):
            return False
        if calendar in ('all_leap', '366_day'):
            return True
        if calendar == 'julian':
            return year % 4 == 0
        gregorian = (year % 4 == 0 and year % 100 != 0) or year % 400 == 0
        if calendar == 'proleptic_gregorian':
            return gregorian
        # "standard"/"gregorian": Julian rule before the 1582 reform
        return gregorian if year > 1582 else year % 4 == 0

    # Hard coded number of days between mid-months (Jan->Feb, Feb->Mar, ...)
    days = [29.5, 29.5, 30.5, 30.5, 30.5, 30.5,
            31.0, 30.5, 30.5, 30.5, 30.5, 31.]

    def days_to_next_month(time):
        '''Approximate number of days to the middle of the next month'''
        if calendar == '360_day':
            # Every month has exactly 30 days
            return 30.
        if isleap(time.year) and time.month in (1, 2):
            # 29 February lengthens Jan->Feb and Feb->Mar by half a day
            return days[time.month-1]+0.5
        return days[time.month-1]

    # Convert the end time once instead of on every iteration
    end_value = None if time_end is None else \
        _date2num(time_end, units=units, calendar=calendar)

    current_time = _date2num(time0, units=units, calendar=calendar)
    while end_value is None or current_time < end_value:
        yield current_time
        current_time += days_to_next_month(
            _num2date(current_time, units=units, calendar=calendar))
def create_climatology_dimension(calendar, units, time0=None, **dim_args):
    ''' Create a monthly dimension for climatology time axis

    Args:
        calendar (str) : e.g. "julian"
        units (str): e.g. "days since 0001-01-01 00:00:00"
        time0 (str): default "0001-01-16 00:00:00", the first value on the
            time axis

    Returns:
        Dimension: 12 monthly values named 'time', with a 'modulo'
        attribute

    Optional keyword arguments are passed to Dimension
    '''
    if time0 is None:
        time0 = '0001-01-16 00:00:00'
    time_generator = create_monthly(calendar, units, time0)
    # create_monthly is unbounded here; take exactly one year of months
    times = [next(time_generator) for _ in range(12)]
    return Dimension(data=numpy.array(times), dimname='time',
                     units=units, attributes={'modulo':""},
                     **dim_args)
def create_monthly_dimension(calendar, units, time0, time_end, **dim_args):
    ''' Create a Dimension of monthly time values

    The values come from create_monthly(calendar, units, time0, time_end).

    Args:
        calendar (str): e.g. "julian"
        units (str): e.g. "days since 0001-01-01 00:00:00"
        time0: first time value
        time_end: end time (exclusive)

    Optional keyword arguments are passed to Dimension.
    '''
    values = list(create_monthly(calendar, units, time0, time_end))
    return Dimension(data=numpy.array(values), units=units, **dim_args)
def create_monthly_dimension2(ref_dim=None):
    ''' Create a 12-point climatological time Dimension

    The 12 values are placed at roughly mid-month, int(day) days after
    time0 for day in numpy.linspace(15, 350, 12).

    Args:
        ref_dim (Dimension, optional): if given, its time0(), units,
            calendar attribute (default 'standard') and attributes are
            reused.  Otherwise year 1, 'days since 0001-1-1 0' and the
            'standard' calendar are used.

    Returns:
        Dimension: named 'time', with a 'modulo' attribute
    '''
    if ref_dim is not None:
        time0 = ref_dim.time0()
        units = ref_dim.units
        calendar = ref_dim.attributes.get('calendar', 'standard').lower()
        attributes = ref_dim.attributes.copy()
        attributes['modulo'] = " "
    else:
        time0 = datetime.datetime(1, 1, 1, 0, 0)
        units = 'days since 0001-1-1 0'
        calendar = 'standard'
        attributes = {'modulo': " "}

    newaxis = []
    for day in numpy.linspace(15, 365-15, 12):
        stamp = time0 + datetime.timedelta(days=int(day))
        newaxis.append(_date2num(stamp, units=units, calendar=calendar))
    return Dimension(data=numpy.array(newaxis), dimname='time',
                     units=units, attributes=attributes)
def climatology(var, appendname=False, no_continuous_duplicate_month=True,
                *args, **kwargs):
    ''' Compute the monthly climatology of a Variable

    Args:
        var (Variable): must have a time ("T") axis
        appendname (bool): append " climatology" to the long_name
        no_continuous_duplicate_month (bool): passed to var.getDate("m", ...)

    Other arguments and keyword arguments are passed to
    monthly.climatology.

    Returns:
        Variable: the time axis is replaced by the 12-month axis from
        create_climatology_dimension, with units
        "days since 0001-01-01 00:00:00"
    '''
    var.ensureMasked()
    data = var.data
    assert 'T' in var.getCAxes()
    months = var.getDate('m', no_continuous_duplicate_month)
    axis = var.getCAxes().index('T')
    clim_data = monthly.climatology(data=data, months=months, axis=axis,
                                    *args, **kwargs)
    history = 'climatology'
    long_name = var.getattr('long_name', '')
    if appendname:
        long_name += " climatology"

    dims = [dim for dim in var.dims]
    # units is forced to be "days since 0001-01-01 00:00:00" instead of
    # inheriting the var's time unit
    dims[axis] = create_climatology_dimension(
        calendar=var.dims[axis].getattr('calendar', 'standard').lower(),
        units='days since 0001-01-01 00:00:00',
        parent=var.dims[axis])

    return Variable(data=clim_data, dims=dims, parent=var,
                    history=history, attributes=dict(long_name=long_name))
def anomaly(var, appendname=False, clim=None,
            no_continuous_duplicate_month=True):
    ''' Compute the monthly anomaly of a Variable

    Args:
        var (Variable): must have a time ("T") axis
        appendname (bool): append " anomaly" to the long_name
        clim (Variable or array, optional): climatology to subtract.  If
            None, monthly.anomaly computes it from var itself
        no_continuous_duplicate_month (bool): passed to var.getDate("m", ...)

    Returns:
        Variable: same dimensions as var
    '''
    var.ensureMasked()
    data = var.data
    assert 'T' in var.getCAxes()
    months = var.getDate('m', no_continuous_duplicate_month)
    axis = var.getCAxes().index('T')
    anomaly_kwargs = dict(data=data, months=months, axis=axis)
    if clim is not None:
        # Accept either a Variable or a plain (masked) array
        anomaly_kwargs['clim'] = clim.data if isinstance(clim, Variable) \
                                 else clim
    # monthly.anomaly returns a sequence; its first item is the anomaly
    anom_data = monthly.anomaly(**anomaly_kwargs)[0]

    history = 'anomaly'
    long_name = var.getattr('long_name', '')
    if appendname:
        long_name += " anomaly"

    dims = [dim for dim in var.dims]
    return Variable(data=anom_data, dims=dims, parent=var,
                    history=history, attributes=dict(long_name=long_name))
def running_climatology(var, appendname, runave_window, step, need_anom=True):
    ''' Calculate a running (moving) climatology and, optionally, the anomaly

    The climatology is var.runave(runave_window, <T axis index>, step).

    Args:
        var (Variable)
        appendname (bool): whether to append "_c" (climatology) and "_a"
            (anomaly) to the varname of the outputs
        runave_window (int): size of the running average window
        step (int): step for slicing the array
        need_anom (bool): whether the anomaly is computed and returned

    Returns:
        (climatology, anomaly): anomaly is None unless need_anom is True

    Example:
        If the time axis is monthly, compute a running climatology with a
        30-year window, with appended name and anomaly returned, like this::

            >>> running_climatology(var,True,30,12,True)
    '''
    time_axis = var.getCAxes().index('T')
    climo = var.runave(runave_window, time_axis, step)
    climo.addHistory("Moving climatology with window:{}".format(runave_window))
    if appendname:
        climo.varname += '_c'

    if not need_anom:
        return climo, None

    anom = var - climo
    anom.addHistory("Anomaly on a moving climatology")
    if appendname:
        anom.varname += '_a'
    return climo, anom
def clim2long(clim, target):
    ''' Repeat a monthly climatology along the time axis of another Variable

    Args:
        clim (Variable): climatology with a time ("T") axis
        target (Variable): provides the time dimension and the calendar
            months (target.getDate("m", True)) to expand onto

    Returns:
        Variable: clim's dimensions with the time dimension replaced by
        target's, named clim.varname
    '''
    # Copy the target time dimension
    time_dim = target.dims[target.getCAxes().index("T")]
    time_idim = clim.getCAxes().index("T")
    new_dim = [time_dim if idim == time_idim else dim
               for idim, dim in enumerate(clim.dims)]
    # Pass clim's own time axis index instead of assuming time is axis 0
    return Variable(data=monthly.clim2long(clim.data, time_idim,
                                           target.getDate("m", True)),
                    dims=new_dim, attributes=clim.attributes,
                    history="clim2long({},{})".format(clim.varname,
                                                      target.varname),
                    varname=clim.varname)
def concatenate(variables, axis=0):
    ''' Concatenate a list of variables similar to numpy.concatenate

    Masked data and the dimension along axis are concatenated as well.

    Args:
        variables (list of Variable)
        axis (int): along which the variables are concatenated

    Returns:
        Variable: named and parented after variables[0]

    Raises:
        ValueError: if variables is empty
    '''
    if not variables:
        raise ValueError("concatenate needs at least one Variable")
    for var in variables:
        var.ensureMasked()
    data = numpy.ma.concatenate([var.data for var in variables], axis=axis)

    # Concatenate dimensions
    dim_data = numpy.concatenate([var.dims[axis].data for var in variables])
    dims = [dim for dim in variables[0].dims]
    dims[axis] = Dimension(dim_data, parent=dims[axis])
    return Variable(data=data, dims=dims,
                    varname=variables[0].varname,
                    parent=variables[0],
                    history=variables[0].getattr('history'))
def ensemble(variables, new_axis=None, new_axis_unit=None, **kwargs):
    ''' Stack a list of variables along a new leftmost "ensemble" axis

    Args:
        variables (list of Variable): all with the same shape
        new_axis (numpy array, optional): values of the new axis;
            default 1..len(variables)
        new_axis_unit (str, optional): unit of the new axis

    Other keyword arguments are passed to Variable.

    Returns:
        Variable: dims are the ensemble Dimension followed by
        variables[0].dims
    '''
    for d in variables:
        d.ensureMasked()
    ensdata = numpy.ma.concatenate([d.data[numpy.newaxis, ...]
                                    for d in variables], axis=0)
    if new_axis is None:
        new_axis = numpy.arange(1, len(variables)+1)
    # list() so that a tuple of dims can be extended as well
    dims = [Dimension(new_axis, dimname='ensemble', units=new_axis_unit),] + \
           list(variables[0].dims)
    return Variable(data=ensdata, parent=variables[0], dims=dims, **kwargs)
def div(u, v, varname='div', long_name='divergence', **kwargs):
    ''' Compute wind divergence by central difference

    Args:
        u (Variable): zonal wind
        v (Variable): meridional wind, on the same X/Y axes as u

    Other keyword arguments are passed to Variable.

    Returns:
        Variable
    '''
    # numpy.unwrap removes the 2*pi jumps where the longitude grid crosses
    # the dateline or the prime meridian, so the spacing stays continuous
    lon = numpy.unwrap(numpy.radians(u.getLongitude()))
    lat = numpy.radians(u.getLatitude())
    xaxis = u.getIAxis('X')
    yaxis = u.getIAxis('Y')
    assert xaxis == v.getIAxis('X')
    assert yaxis == v.getIAxis('Y')

    # dx,dy
    R = 6371000.
    coslat = numpy.cos(lat)
    dlon = numpy.gradient(lon)
    # dx is a function of latitude (convergence of meridians), laid out as
    # a 2-D array following the relative order of the Y and X axes.
    # NOTE(review): for >2-D input math.div must broadcast this 2-D dx;
    # confirm against geodat.math.div.
    if yaxis < xaxis:
        dx = coslat[:, numpy.newaxis]*dlon[numpy.newaxis, :]*R
    else:
        dx = coslat[numpy.newaxis, :]*dlon[:, numpy.newaxis]*R
    dy = numpy.gradient(lat) * R

    return Variable(data=math.div(u.data, v.data, dx, dy, xaxis, yaxis),
                    varname=varname, parent=u, history='divergence',
                    attributes=dict(long_name=long_name), **kwargs)
def gradient(var, axis, mask_boundary=True, **kwargs):
    ''' Compute the gradient of a variable taking into account the
    convergence of meridians

    Args:
        var (Variable)
        axis (str or int): the axis along which the gradient is computed
        mask_boundary (bool, default=True): whether boundary values are
            masked

    Additional keyword arguments are passed to Variable.

    Returns:
        Variable
    '''
    caxes = var.getCAxes()
    if type(axis) is str:
        axis = caxes.index(axis.upper())

    R = 6371000.
    ndim = var.data.ndim
    if caxes[axis] == 'X' and 'Y' in caxes:
        yaxis = caxes.index('Y')
        # unwrap keeps the longitude spacing continuous across 0/360 or
        # the dateline
        lon = numpy.unwrap(numpy.radians(var.getLongitude()))
        lat = numpy.radians(var.getLatitude())
        # Broadcast 1-D latitude / longitude arrays to the full rank
        lat_slice = (numpy.newaxis,)*yaxis + (slice(None),) \
                    + (numpy.newaxis,)*(ndim-yaxis-1)
        lon_slice = (numpy.newaxis,)*axis + (slice(None),) \
                    + (numpy.newaxis,)*(ndim-axis-1)
        # a function of latitude
        dx = numpy.cos(lat)[lat_slice] * numpy.gradient(lon)[lon_slice] * R
    elif caxes[axis] == 'X':
        dx = numpy.gradient(numpy.unwrap(numpy.radians(var.getAxes()[axis]))) * R
    elif caxes[axis] == 'Y':
        dx = numpy.radians(numpy.gradient(var.getAxes()[axis])) * R
    else:
        dx = numpy.gradient(var.getAxes()[axis])

    return Variable(data=math.gradient(var.data, dx, axis,
                                       mask_boundary=mask_boundary),
                    parent=var, history='gradient: '+caxes[axis],
                    **kwargs)
def integrate(var, axis, varname='int', versatile=False):
    ''' Integrate variable along one or more axes

    Args:
        var (Variable)
        axis (int, str, or list/tuple of int): the dimension(s) to be
            integrated along, e.g. 0, "xy", [1, 2]
        varname (str): name of the result
        versatile (bool): print progress messages

    Returns:
        Variable: each integrated dimension is reduced to a single value,
        the mean of its original coordinates
    '''
    var.ensureMasked()
    caxes = var.getCAxes()
    if type(axis) is str:
        axis = [caxes.index(ax) for ax in axis.upper()]
    elif isinstance(axis, tuple):
        axis = list(axis)
    elif type(axis) is not list:
        axis = [axis]
    axes_label = ''.join([caxes[iax] for iax in axis])

    # It may take some time to compute integration, notify the user first
    if versatile:
        print("Integrating along axis...", end="")

    # Compute integration
    re_data = math.integrate(var.data, axes=var.getAxes(), iax=axis)

    # History attribute
    history = 'Integrated along axis:' + axes_label

    # This long name is probably not needed
    long_name = var.attributes.get('long_name', '') + \
                ' integrated on ' + axes_label

    result = Variable(data=re_data, varname=varname, parent=var,
                      history=history, attributes=dict(long_name=long_name))

    # Reduce dimension to the mean of the domain.  New Dimension objects
    # are created (in a new list) so the input's dimensions are not
    # modified in case they are shared with the result.
    result.dims = list(result.dims)
    for ax in axis:
        result.dims[ax] = Dimension(
            numpy.array([var.dims[ax].data.mean()],
                        dtype=var.dims[ax].data.dtype),
            parent=var.dims[ax])

    if versatile:
        print('Done.')

    return result
def conform_region(*args):
    ''' Return a dictionary with the common lat-lon region

    Args:
        *args: two or more domains, each either an object with a
            getDomain() method (e.g. a Variable) or a region dictionary in
            the form accepted by _general_region

    Returns:
        dict: {'lat': (min_lat, max_lat), 'lon': (min_lon, max_lon)}, the
        intersection of all domains; suitable as keyword arguments to
        Variable.getRegion.  A domain without an X/Y entry does not
        constrain that axis.

    Raises:
        Exception: if fewer than two domains are given
    '''
    if len(args) < 2:
        raise Exception("Expect more than one domain in conform_region")

    domains = []
    for arg in args:
        # For backward compatibility, get the domains for Variable inputs
        try:
            argdomain = arg.getDomain()
        except AttributeError:
            argdomain = arg
        # Generalise the form of the dictionary
        domains.append(_general_region(argdomain))

    unbounded = (-numpy.inf, numpy.inf)
    minlon = max(domain.get('X', unbounded)[0] for domain in domains)
    maxlon = min(domain.get('X', unbounded)[1] for domain in domains)
    minlat = max(domain.get('Y', unbounded)[0] for domain in domains)
    maxlat = min(domain.get('Y', unbounded)[1] for domain in domains)
    return dict(lat=(minlat, maxlat), lon=(minlon, maxlon))
def conform_regrid(*args, **kwargs):
    ''' Conform a list of variables to a common region and grid

    The common lat-lon region is found with conform_region(*args).  Every
    variable is cut to that region and regridded with pyferret_regrid on
    the X and/or Y axes that all variables share.

    Args:
        *args: the Variables
        ref (Variable, optional keyword): reference variable whose grid is
            used as the target.  If not given, the variable with the finest
            grid (smallest cell area or spacing) within the common region
            is used.

    Remaining keyword arguments are passed to pyferret_regrid.

    Returns:
        list of Variable

    Raises:
        Exception: if the variables share neither an X nor a Y axis
    '''
    # Conform the region first
    region = conform_region(*args)
    varstoregrid = [var.getRegion(**region) for var in args]

    axes = 'X' if all(['X' in var.getCAxes() for var in varstoregrid]) else ''
    axes += 'Y' if all(['Y' in var.getCAxes() for var in varstoregrid]) else ''
    if not axes:
        raise Exception("No common X/Y axis found for regridding")

    if 'ref' in kwargs:
        ref = kwargs.pop('ref').getRegion(**region)
        # Pass the same axis and options as in the automatic case
        return [pyferret_regrid(var, ref, axis=axes, **kwargs)
                for var in varstoregrid]

    # Reference is not given
    # The variable with the finest grid would be the reference
    # area = cos(theta) dtheta dphi
    def minarea(var, axes):
        '''Smallest cell area (XY) or spacing (X or Y) of var'''
        mindelta = lambda v: numpy.abs(numpy.gradient(v)).min()
        if axes == 'XY':
            phi = numpy.radians(var.getLongitude())
            theta = numpy.radians(var.getLatitude())
            area = numpy.cos(theta)[numpy.newaxis, :]*\
                   numpy.gradient(phi)[:, numpy.newaxis]*\
                   numpy.gradient(theta)[numpy.newaxis, :]
            return numpy.abs(area).min()
        elif axes == 'X':
            return mindelta(var.getLongitude())
        return mindelta(var.getLatitude())

    # Compare resolutions within the region that is actually regridded
    ires = numpy.array([minarea(var, axes) for var in varstoregrid]).argmin()
    ref = varstoregrid[ires]
    return [pyferret_regrid(var, ref, axis=axes, **kwargs)
            if i != ires else ref
            for i, var in enumerate(varstoregrid)]
def fer2var(var):
    ''' Convert the dictionary returned by pyferret.getdata into a Variable

    Args:
        var (dict): as is returned by pyferret.getdata

    Returns:
        Variable: with history 'From Ferret'

    Raises:
        ImportError: if pyferret is not installed
    '''
    if not pyferret_func.PYFERRET_INSTALLED:
        raise ImportError("No pyferret installed")
    converted = pyferret_func.fer2num(var)
    dims = []
    for i, coord in enumerate(converted['coords']):
        dims.append(Dimension(data=coord,
                              units=converted['dimunits'][i],
                              dimname=converted['dimnames'][i]))
    return Variable(data=converted['data'], dims=dims,
                    varname=converted['varname'],
                    history='From Ferret')
def var2fer(var, name=None):
    ''' Convert a Variable into the Ferret data structure

    The result is the dictionary expected by pyferret.putdata.

    Args:
        var (Variable)
        name (str): optional, new variable name (default var.varname)

    Returns:
        dict: to be used by pyferret.putdata

    Raises:
        ImportError: if pyferret is not installed
    '''
    if not pyferret_func.PYFERRET_INSTALLED:
        raise ImportError("No pyferret installed")
    fer_input = _var_to_num_input(var)
    if name is None:
        return pyferret_func.num2fer(fer_input)
    assert isinstance(name,str)
    fer_input["varname"] = name
    return pyferret_func.num2fer(fer_input)
def _var_to_num_input(var): ''' Convert a instance to a dictionary ready to be used by pyferret_func.num2fer Arg: var ( Returns: dict ''' return dict(, missing_value=var.getMissingValue(), coords=var.getAxes(), dimunits=[dim.units for dim in var.dims], varname=var.varname, data_units=var.getattr('units', ''), cartesian_axes=var.getCAxes(), dimnames=var.getDimnames())
def pyferret_regrid(var, ref_var=None, axis='XY', nlon=None, nlat=None,
                    verbose=False, prerun=None, transform="@lin"):
    ''' Use pyferret to perform regridding.

    Args:
        var (Variable): input data
        ref_var (Variable): provide the target grid
        axis (str): which axis needs regridding
        nlon (int): if ref_var is not provided, a cartesian
            latitude-longitude global grid is created as the target grid.
            nlon is the number of longitudes
        nlat (int): number of latitudes, used with nlon when ref_var is None
        verbose (bool): default False
        prerun (str): Ferret command to be run before the regridding
        transform (str): Mode of regridding.  "@lin" means linear
            interpolation; "@ave" means preserving area mean.  See the
            Ferret documentation on regridding transformations.

    Either ref_var or (nlon and nlat) has to be specified.

    Returns:
        Variable

    Raises:
        ImportError: if pyferret is not installed
        Exception: if neither ref_var nor (nlon, nlat) is usable
    '''
    if not pyferret_func.PYFERRET_INSTALLED:
        raise ImportError("No pyferret installed")

    regrid_xy = ''.join(sorted(axis.upper())) == 'XY'

    if ref_var is None:
        # If ref_var is not given, use nlon and nlat instead
        if nlon is None or nlat is None:
            raise Exception('''reference variable is not given. nlon and nlat need to be specified''')
        if not regrid_xy:
            raise Exception('ref_var not given and therefore assumed '
                            'regridding in the XY direction.\n'
                            'The axis/axes you chose:'+str(axis))
        # Create latitude and longitude using the sphere_grid and spharm modules
        lon, lat = grid_func.grid_degree(NY=nlat, NX=nlon)
        lon = Dimension(data=lon, units="degrees_E", dimname="lon")
        lat = Dimension(data=lat, units="degrees_N", dimname="lat")
        # Create new dimensions: the new grid replaces X/Y, others are kept
        dims = []
        for idim, cax in enumerate(var.getCAxes()):
            if cax == 'X':
                dims.append(lon)
            elif cax == 'Y':
                dims.append(lat)
            else:
                dims.append(var.dims[idim])
        data_shape = [dim.data.shape[0] for dim in dims]
        ref_var = Variable(data=numpy.ones(data_shape, dtype=var.data.dtype),
                           dims=dims, parent=var)

    if regrid_xy and transform.lower() != '@ave':
        warnings.warn('''Regridding onto XY grid and transformation: {} is used.'''.format(transform))

    # Only a slice of ref_var is needed
    # No need to copy the entire variable
    # (reduce chance of running out of memory)
    ref_var_slice = tuple([slice(0, 1) if cax not in axis.upper()
                           else slice(None)
                           for cax in ref_var.getCAxes()])

    return pyferret_func.regrid_primitive(
        _var_to_num_input(var),
        _var_to_num_input(ref_var[ref_var_slice].squeeze()),
        axis, verbose=verbose, prerun=prerun, transform=transform)
def regrid(var, nlon, nlat, verbose=False):
    ''' Use spherical harmonics for regridding

    May produce ripples.  The lat-lon grid of var is assumed to cover a
    complete sphere.  The data is regridded onto a spherical grid
    (nlat, nlon) created by grid_func.grid_degree.

    Args:
        var (Variable): 2D or 3D data with X and Y axes
        nlon (int): number of longitudes of the new grid
        nlat (int): number of latitudes of the new grid
        verbose (bool): print progress messages

    Returns:
        Variable

    Raises:
        Exception: for data with more than 3 dimensions

    TODO: grid.regrid now only handles 2D or 3D data; extend the function
    to handle rank-3+ data by flattening the extra dimension into one
    dimension
    '''
    ilat = var.getCAxes().index('Y')
    ilon = var.getCAxes().index('X')
    ndim = var.data.ndim
    if ndim > 3:
        raise Exception('Right now the regrid function only take 2D or 3D data')

    if ndim == 3:
        if verbose:
            print("Perform regridding on 3-D data.")
        otherdim = [i for i in range(ndim) if i != ilat and i != ilon][0]
        newaxorder = [ilat, ilon, otherdim]
    else:
        # Also reorder 2D data, in case it is stored as (lon, lat)
        newaxorder = [ilat, ilon]

    # grid_func.regrid receives the data as (lat, lon[, other])
    trans_data = numpy.transpose(var.data, newaxorder)
    result = grid_func.regrid(var.getLongitude(), var.getLatitude(),
                              trans_data, nlon, nlat)
    # argsort of a permutation is its inverse: restore the original order
    regridded = numpy.transpose(result, numpy.argsort(newaxorder))

    newlon, newlat = grid_func.grid_degree(nlat, nlon)
    lon_d = Dimension(data=newlon, units=var.dims[ilon].units, dimname='lon')
    lat_d = Dimension(data=newlat, units=var.dims[ilat].units, dimname='lat')

    dims = []
    for i in range(ndim):
        if i == ilat:
            dims.append(lat_d)
        elif i == ilon:
            dims.append(lon_d)
        else:
            dims.append(var.dims[i])

    return Variable(data=regridded, dims=dims, parent=var, history='Regridded')
def gaus_filter(var, gausize):
    ''' Filter a variable spatially (i.e. X-Y) using a gaussian filter

    Masked values are treated as zero during filtering (a warning is
    issued), and the input mask is re-applied to the output.  The input
    Variable is not modified.

    Args:
        var (Variable)
        gausize (int): the size (sigma) of the gaussian filter, applied
            along the X and Y axes only.  If var has no recognized X/Y
            axis, all axes are filtered.

    Returns:
        Variable
    '''
    data = numpy.ma.asarray(var.data)
    mask = numpy.ma.getmaskarray(data)
    if mask.any():
        warnings.warn('''There are masked values. They are assigned zero before filtering''')

    caxes = var.getCAxes()
    if 'X' in caxes or 'Y' in caxes:
        # sigma=0 leaves the non-spatial axes (e.g. time) unfiltered
        sigma = [gausize if cax in ('X', 'Y') else 0 for cax in caxes]
    else:
        sigma = gausize

    # Filter a filled copy so the caller's data and mask stay untouched
    filtered = gaussian_filter(data.filled(0.), sigma)
    # Preserve the mask
    filtered = numpy.ma.array(filtered, mask=mask)
    newvar = Variable(data=filtered, parent=var,
                      history="Gaussian filter size:"+str(gausize))
    return newvar
def savefile(filename, listofvar, overwrite=False,
             recordax=-1, appendall=False):
    ''' Save one or more Variables to a NetCDF file

    Args:
        filename (str): output file name; '.nc' is appended if it is not
            already the suffix
        listofvar (Variable or list of Variable): variables to be saved
        overwrite (bool): overwrite an existing file if True. default=False
        recordax (int): the axis that becomes the record (unlimited) axis;
            default = -1 (no record axis).  Appending to variables that
            already exist in the file happens along this axis.
        appendall (bool): append to an existing file if True.
            default = False

    If the file exists and neither overwrite nor appendall is set, the user
    is asked on stdin whether to overwrite or append.  A new file is first
    written to filename + '.tmp' and renamed to filename once every
    variable has been written.  Clashing variable names are resolved
    interactively on stdin as well.

    Unlike the other functions in this module, this function uses the
    NetCDF4 Dataset to write data.
    '''
    # Handle endian
    endian_code = {'>':'big', '<':'little'}

    # if the file is not suffixed by .nc, add it
    if filename[-3:] != '.nc':
        filename += '.nc'
    tmpname = filename + '.tmp'

    # If the file exists, ask the user unless a flag says what to do
    if os.path.exists(filename):
        if overwrite and appendall:
            raise Exception('appendall and overwrite can\'t be both True.')
        if not overwrite and not appendall:
            print(filename,"exists. Overwrite or Append? [a/o]")
            # [:1] so an empty read (EOF) reaches the error below instead
            # of raising IndexError
            yn = sys.stdin.readline()[:1].lower()
            if yn == 'o':
                overwrite = True
            elif yn == 'a':
                appendall = True
            else:
                raise Exception('File exists.\n'
                                'Action must be either to append or to overwrite')

    write_new = overwrite or not os.path.exists(filename)
    if write_new:
        # Create temporary file
        ncfile = _netCDF4_Dataset(tmpname, 'w', format='NETCDF3_CLASSIC')
    else:
        # Append existing file
        assert appendall
        ncfile = _netCDF4_Dataset(filename, 'a', format='NETCDF3_CLASSIC')

    try:
        # Add history to the file (stack()[1] is the caller of savefile)
        ncfile.history = 'Created from script: '+ inspect.stack()[1][1]

        # if listofvar is a single object, convert it into a list
        if type(listofvar) is not list:
            listofvar = [listofvar]

        for var in listofvar:
            varname = var.varname if var.varname is not None else 'var'
            if var.dims is None and var.data.ndim > 0:
                raise Exception("There is/are missing dimension(s) for "+\
                                varname)

            # Scalar variables without dims have no dimensions at all
            dimnames = []
            axes = []
            if var.dims is not None:
                # Save dimension arrays; copy the names so renaming below
                # never modifies the Variable's own list
                dimnames = list(var.getDimnames())
                axes = var.getAxes()
                for idim, dimname in enumerate(dimnames):
                    if idim == recordax:
                        dimsize = None
                    else:
                        dimsize = var.data.shape[idim]

                    # check if the dimension has already been saved
                    isnewdim = dimname not in ncfile.dimensions

                    # the dimension name exists already and is not a record axis
                    if not isnewdim and idim != recordax:
                        olddim = ncfile.variables[dimname]
                        # check if the dimensions are in fact the same one,
                        # if not, it is a new dimension
                        isnewdim = axes[idim].shape != olddim[:].shape or\
                                   (not numpy.allclose(axes[idim], olddim[:]))

                    # create new dimension
                    if isnewdim:
                        # Rename the dimension if it is unique but has name
                        # collision within the file
                        dimsuffix = ''
                        newDcount = 0
                        while dimname+dimsuffix in ncfile.dimensions:
                            newDcount += 1
                            dimsuffix = str(newDcount)
                        dimnames[idim] += dimsuffix
                        ncfile.createDimension(dimnames[idim], dimsize)
                        endian = endian_code.get(axes[idim].dtype.byteorder,
                                                 'native')
                        if not numpy.isscalar(axes[idim]):
                            dimvar = ncfile.createVariable(
                                dimnames[idim], numpy.dtype(axes[idim].dtype),
                                (dimnames[idim],), endian=endian)
                            dimvar[:] = axes[idim]
                            dimvar.setncatts(var.dims[idim].attributes)

            varappend = varname in ncfile.variables and appendall
            if varname in ncfile.variables and not varappend:
                # Check again
                print("Variable {} exists. ".format(varname),
                      "Append variable? [y/N]")
                append_code = sys.stdin.readline()
                if append_code[:1].lower() == 'y':
                    varappend = True
                else:
                    # Likely an unintended collision of variable names
                    # Change it!
                    print("Rename variable as :")
                    varname = sys.stdin.readline().rstrip('\r\n')
                    var.varname = varname

            if not varappend:
                endian = endian_code.get(var.data.dtype.byteorder, 'native')
                ncfile.createVariable(varname, numpy.dtype(var.data.dtype),
                                      tuple(dimnames), endian=endian,
                                      fill_value=var.getMissingValue())

            var.ensureMasked()
            data2save = var.data

            if varappend:
                existing_fill = ncfile.variables[varname].getncattr('_FillValue')
                if float(var.getMissingValue()) != float(existing_fill):
                    print("Warning: Existing var missing value:",
                          existing_fill,
                          "Appending var's missing value:",
                          var.getMissingValue())
                oldvar = ncfile.variables[varname]
                olddim = ncfile.variables[var.getDimnames()[recordax]]
                oldnrec = oldvar.shape[recordax]
                newnrec = data2save.shape[recordax]
                # Write the new records after the existing ones
                s_lice = (slice(None),)*recordax + \
                         (slice(oldnrec, newnrec+oldnrec),)
                print("Appending variable:", varname)
                oldvar[s_lice] = data2save
                print("Appending dimensions:", var.getDimnames()[recordax])
                olddim[oldnrec:newnrec+oldnrec] = axes[recordax]
            else:
                if data2save.ndim == 0:
                    ncfile.variables[varname].assignValue(data2save)
                else:
                    slice_obj = (slice(None),)*data2save.ndim
                    ncfile.variables[varname][slice_obj] = data2save

            # Update attributes (_FillValue is set at creation time)
            atts = {att:val for att, val in var.attributes.items()
                    if att != '_FillValue'}
            ncfile.variables[varname].setncatts(atts)
    finally:
        # Close the file even if writing fails
        ncfile.close()

    if write_new:
        os.rename(tmpname, filename)
        print("Saved to file:", filename)
    else:
        print("Appended to file:", filename)
def TimeSlices(var, lower, upper, toggle, no_continuous_duplicate_month=False):
    """ Return a time segment of the variable

    The segment is selected by the lower (inclusive) and upper (inclusive)
    limits.  If upper < lower, the selection wraps around, e.g. November
    to February.

    Args:
        var (Variable)
        lower (numeric): lower time limit
        upper (numeric): upper time limit
        toggle (str): Y/m/d/H/M/S to select a particular time format
        no_continuous_duplicate_month (bool): default False; make sure the
            difference between calendar months in the time axis is always
            larger than or equal to 1; only suitable for dealing with
            monthly data.

    Returns:
        Variable

    Examples:
        >>> # time segments in Nov, Dec, Jan and Feb
        >>> TimeSlices(var,11.,2.,"m")

        >>> # time segments from year 1990 to 2000 (inclusive)
        >>> TimeSlices(var,1990,2000,"Y")

        >>> # Say 01-01-0001 and 31-01-0001 are two adjacent time steps.
        >>> # As far as monthly data is concerned, the second time step
        >>> # 31-01-0001 should be considered as the beginning of February
        >>> # and not as January, so we want
        >>> # no_continuous_duplicate_month=True
        >>> TimeSlices(var,1,2,"m",True)
    """
    time = var.getDate(toggle, no_continuous_duplicate_month)
    taxis = var.getCAxes().index('T')
    if upper < lower:
        # The range wraps around the end of the cycle
        selected = numpy.logical_or(time >= lower, time <= upper)
    else:
        selected = numpy.logical_and(time >= lower, time <= upper)
    slices = (slice(None),)*taxis + (selected,) + \
             (slice(None),)*(var.data.ndim-taxis-1)
    return var[slices]
def plot_vs_axis(var, axis, *args, **kwargs):
    ''' Line-plot the data of var against one of its axes

    Args:
        var (Variable)
        axis (str): cartesian axis, e.g. "T"; for the time axis, up to
            about ten ticks are labelled with "Y-m-d" dates

    Other arguments and keyword arguments are passed to pylab.plot.

    Returns:
        the output of pylab.plot
    '''
    axis = axis.upper()
    line = pylab.plot(var.getAxis(axis), var.data, *args, **kwargs)
    # Use date for the time axis
    if axis == 'T':
        times = var.getAxis(axis)
        # Integer step (Python 3 "/" gives a float) and never zero, so
        # short time axes still get ticks
        tick_step = max(len(times)//10, 1)
        iticks = range(0, len(times), tick_step)
        all_dates = var.getDate()
        xticks = [times[i] for i in iticks]
        dates = ["{}-{}-{}".format(*all_dates[i]) for i in iticks]
        pylab.gca().set_xticks(xticks)
        pylab.gca().set_xticklabels(dates, rotation=20)
    return line
def UseMapplot(f_pylab):
    """ A decorator for using mapplot functions on an object

    f_pylab is the pylab function for map plotting (e.g. contour,
    contourf,...)
    """
    def plot_func(variable, *args, **kwargs):
        ''' Use mpl_toolkits.basemap.Basemap to plot

        Args:
            variable (Variable): should be 2D (singlet dimensions are
                removed when calling this function)
            basemap_kwargs (dict): optional. If provided, it is passed to
                mpl_toolkits.basemap.Basemap while setting up the map
            Other arguments and keyword arguments are passed to f_pylab
            (the pylab function f_pylab provided).

        Returns:
            m, cs (mpl_toolkits.basemap.Basemap, output of f_pylab)

        Raises:
            ValueError: if the squeezed variable is not 2D

        If the dimensions are not recognized as latitudes and longitudes,
        no map is made; f_pylab(x, y, data) is called and its output(s)
        are returned.  When the last axis is 'Z', the data are transposed
        so that Z is drawn on the vertical axis.
        '''
        basemap_kwargs = kwargs.pop("basemap_kwargs", None)

        # args needed for quiver
        args = list(args)

        # Squeeze extra array-like inputs and unwrap Variable objects.
        # NOTE(review): assumes Variable exposes its values as ``.data``.
        for iarg, arg in enumerate(args):
            if hasattr(arg, "squeeze"):
                args[iarg] = arg.squeeze()
            if isinstance(args[iarg], Variable):
                args[iarg] = args[iarg].data

        var_squeeze = variable.squeeze()
        caxes = var_squeeze.getCAxes()
        data = var_squeeze.data

        if len(caxes) != 2:
            raise ValueError('UseMapplot is supposed to be used on 2D data')

        if 'X' in caxes and 'Y' in caxes:
            # Lat-Lon plot
            lons = variable.getLongitude()
            lats = variable.getLatitude()
            m, cs = mapplot.MapSetup(f_pylab)(
                lons, lats, data, basemap_kwargs, *args, **kwargs)
            return m, cs

        # f_pylab(x, y, data) expects data of shape (len(y), len(x)); the
        # axes come back in dimension order (rows, columns).
        y, x = var_squeeze.getAxes()
        if caxes[-1] == 'Z':
            # Z axis is preferred as the vertical axis: transpose the data
            # (and any extra array arguments) AND swap the coordinates so
            # that y is the Z axis and the shapes still agree.
            data = data.T
            for iarg, arg in enumerate(args):
                if hasattr(arg, "T"):
                    args[iarg] = arg.T
            x, y = y, x
        return f_pylab(x, y, data, *args, **kwargs)
    return plot_func


contour = UseMapplot(pylab.contour)
contourf = UseMapplot(pylab.contourf)
quiver = UseMapplot(pylab.quiver)
pcolor = UseMapplot(pylab.pcolor)
def spatial_corr(var1, var2):
    """ Pearson correlation coefficient between the values of two variables

    Both variables are flattened, so they must hold the same number of
    elements (presumably the same grid — confirm at call sites).  Elements
    masked in either variable are excluded pairwise.

    Args:
        var1 (Variable): first variable
        var2 (Variable): second variable

    Returns:
        scalar correlation coefficient
    """
    # NOTE(review): assumes the values are exposed as ``.data``.
    # numpy.ma keeps the masks through ravel and honours them in corrcoef.
    values1 = numpy.ma.ravel(var1.data)
    values2 = numpy.ma.ravel(var2.data)
    return numpy.ma.corrcoef(values1, values2)[0, 1]
def regress(var1, var2):
    """ Regress ``var1`` against ``var2`` along the leading dimension

    Args:
        var1 (Variable): variable whose first dimension is regressed out
        var2 (Variable): regressor (presumably matching var1's leading
            dimension — confirm at call sites)

    Returns:
        Variable: the first element of the signal.regress result, with the
        dimensions of ``var1`` minus its first one, named
        "<var1>_<var2>"
    """
    # Package-relative import, consistent with the module's other
    # ``from . import ...`` imports, instead of relying on the top-level
    # name ``geodat`` being bound in this module.
    from . import signal as _signal
    # NOTE(review): assumes the values are exposed as ``.data``.
    coefficient = _signal.regress(var1.data, var2.data)[0]
    return Variable(data=coefficient,
                    dims=var1.dims[1:],
                    varname="{}_{}".format(var1.varname, var2.varname),
                    history="{} regress to {}".format(var1.varname,
                                                      var2.varname))
Fork me on GitHub