climate-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ah...@apache.org
Subject svn commit: r1515644 - /incubator/climate/branches/RefactorPlots/rcmet/src/main/python/rcmes/toolkit/plots.py
Date Mon, 19 Aug 2013 22:26:09 GMT
Author: ahart
Date: Mon Aug 19 22:26:08 2013
New Revision: 1515644

URL: http://svn.apache.org/r1515644
Log:
CLIMATE-259: updates to plots.py to support generation of time series, taylor, subregion,
and portrait diagrams

Modified:
    incubator/climate/branches/RefactorPlots/rcmet/src/main/python/rcmes/toolkit/plots.py

Modified: incubator/climate/branches/RefactorPlots/rcmet/src/main/python/rcmes/toolkit/plots.py
URL: http://svn.apache.org/viewvc/incubator/climate/branches/RefactorPlots/rcmet/src/main/python/rcmes/toolkit/plots.py?rev=1515644&r1=1515643&r2=1515644&view=diff
==============================================================================
--- incubator/climate/branches/RefactorPlots/rcmet/src/main/python/rcmes/toolkit/plots.py	(original)
+++ incubator/climate/branches/RefactorPlots/rcmet/src/main/python/rcmes/toolkit/plots.py	Mon Aug 19 22:26:08 2013
@@ -19,208 +19,634 @@
 
 # Import Statements
 
-from math import floor, log
-from matplotlib import pyplot as plt
-from mpl_toolkits.basemap import Basemap
+'''
+Classes: 
+    Plotter - Visualizes pre-calculated metrics
+'''
+
+import os
+from tempfile import TemporaryFile
 import matplotlib as mpl
+import matplotlib.pyplot as plt
+from mpl_toolkits.basemap import Basemap
+from mpl_toolkits.axes_grid1 import ImageGrid
+import scipy.stats.mstats as mstats
 import numpy as np
-import os
+import numpy.ma as ma
+from utils.taylor import TaylorDiagram
+#from toolkit import plots
+
 
 def pow_round(x):
     '''
      Function to round x to the nearest power of 10
     '''
     return 10 ** (floor(log(x, 10) - log(0.5, 10)))
-
-def calc_nice_color_bar_values(mymin, mymax, target_nlevs):
+def _nice_intervals(data, nlevs):
     '''
-     Function to help make nicer plots. 
-     
-     Calculates an appropriate min, max and number of intervals to use in a color bar 
-     such that the labels come out as round numbers.
-    
-     i.e. often, the color bar labels will come out as  0.1234  0.2343 0.35747 0.57546
-     when in fact you just want  0.1, 0.2, 0.3, 0.4, 0.5 etc
-    
-    
-     Method::
-         Adjusts the max,min and nlevels slightly so as to provide nice round numbers.
-    
-     Input::
-        mymin        - minimum of data range (or first guess at minimum color bar value)
-        mymax        - maximum of data range (or first guess at maximum color bar value)
-        target_nlevs - approximate number of levels/color bar intervals you would like to have
+    Purpose::
+        Calculates nice intervals between each color level for colorbars
+        and contour plots. The target minimum and maximum color levels are
+        calculated by taking the minimum and maximum of the distribution
+        after cutting off the tails to remove outliers. 
     
-     Output::
-        newmin       - minimum value of color bar to use
-        newmax       - maximum value of color bar to use
-        new_nlevs    - number of intervals in color bar to use
-        * when all of the above are used, the color bar should have nice round number labels.
+    Input::
+        data - an array of data to be plotted
+        nlevs - an int giving the target number of intervals
+        
+    Output::
+        clevs - A list of floats for the resultant colorbar levels
     '''
-    myrange = mymax - mymin
-    # Find target color bar label interval, given target number of levels.
-    #  NB. this is likely to be not a nice rounded number.
-    target_interval = myrange / float(target_nlevs)
-    
-    # Find power of 10 that the target interval lies in
-    nearest_ten = pow_round(target_interval)
+    # Find the min and max levels by cutting off the tails of the distribution
+    # This mitigates the influence of outliers
+    data = data.ravel()
+    mnlvl = mstats.scoreatpercentile(data, 5)
+    mxlvl = mstats.scoreatpercentile(data, 95)
+    locator = mpl.ticker.MaxNLocator(nlevs)
+    clevs = locator.tick_values(mnlvl, mxlvl)
+    
+    # Make sure the bounds of clevs are reasonable since sometimes
+    # MaxNLocator gives values outside the domain of the input data
+    clevs = clevs[(clevs >= mnlvl) & (clevs <= mxlvl)]
+    return clevs
+
+def _best_grid_shape(nplots, oldshape):
+    '''
+    Purpose::
+        Calculate a better grid shape in case the user enters more columns
+        and rows than needed to fit a given number of subplots.
+        
+    Input::
+        nplots - an int giving the number of plots that will be made
+        oldshape - a tuple denoting the desired grid shape (nrows, ncols) for arranging
+                    the subplots originally requested by the user. 
     
-    # Possible interval levels, 
-    #  i.e.  labels of 1,2,3,4,5 etc are OK, 
-    #        labels of 2,4,6,8,10 etc are OK too
-    #        labels of 3,6,9,12 etc are NOT OK (as defined below)
-    #  NB.  this is also true for any multiple of 10 of these values
-    #    i.e.  0.01,0.02,0.03,0.04 etc are OK too.
-    pos_interval_levels = np.array([1, 2, 5])
+    Output::
+        newshape - the smallest possible subplot grid shape needed to fit nplots
+    '''
+    nrows, ncols = oldshape
+    size = nrows * ncols
+    diff = size - nplots
+    if diff < 0:
+        raise ValueError('gridshape=(%d, %d): Cannot fit enough subplots for data' %(nrows, ncols))
+    else:
+        # If the user enters an excessively large number of
+        # rows and columns for gridshape, automatically
+        # correct it so that it fits only as many plots
+        # as needed
+        while diff >= ncols:
+            nrows -= 1
+            size = nrows * ncols
+            diff = size - nplots
+            
+        # Don't forget to remove unnecessary columns too
+        if nrows == 1:
+            ncols = nplots
+            
+        newshape = nrows, ncols
+        return newshape
     
-    # Find possible intervals to use within this power of 10 range
-    candidate_intervals = (pos_interval_levels * nearest_ten)
+def _fig_size(gridshape):
+    '''
+    Purpose::
+        Calculates the figure dimensions from a subplot gridshape
+        
+    Input::
+        gridshape - Tuple denoting the subplot gridshape
+        
+    Output::
+        width - float for width of the figure in inches
+        height - float for height of the figure in inches
+    '''
+    nrows, ncols = gridshape
     
-    # Find which of the candidate levels is closest to the target level
-    absdiff = abs(target_interval - candidate_intervals)
+    # Assuming base dimensions of 8.5" x 5.5". May change this later to be
+    # user defined.
+    if nrows >= ncols:
+        width, height = 8.5, 5.5 * nrows / ncols
+    else:
+        width, height = 8.5 * ncols / nrows, 5.5
+        
+    return width, height
     
-    rounded_interval = candidate_intervals[np.where(absdiff == min(absdiff))]
+def draw_taylor_diagram(data, data_name,refname, fname, fmt='png', ptitle='', 
+                      pos='upper right', frameon=False, radmax=1.5):
+    '''
+    Purpose::
+        Draws a Taylor diagram
+        
+    Input::
+        data - an Nx2 array containing normalized standard deviations,
+               correlation coefficients
+        dataname - N array containing names of evaluation datasets
+        refname - The name of the reference datasets
+        fname  - a string specifying the filename of the plot
+        fmt  - an optional string specifying the filetype, default is .png
+        ptitle - an optional string specifying the plot title
+        pos - an optional string or tuple of float for determining 
+                    the position of the legend
+        frameon - an optional boolean that determines whether to draw a frame
+                        around the legend box
+        radmax - an optional float to adjust the extent of the axes in terms of
+                 standard deviation.
+    '''
+    fig = plt.figure()
+    fig.suptitle(ptitle)            
+ 
+    dia = TaylorDiagram (1, fig=fig, rect=111, label=refname, radmax=radmax)
+    for i, (stddev, corrcoef) in enumerate(data):
+        name = data_name[i]
+        dia.add_sample(stddev, corrcoef, marker='$%d$' % (i + 1), ms=6, label=name)
+    
+    legend = fig.legend(dia.samplePoints, [p.get_label() for p in dia.samplePoints], handlelength=0.,
+                        prop={'size': 10}, numpoints=1, loc=pos)
+    legend.draw_frame(frameon)
+    fig.savefig('%s.%s' %(fname, fmt), bbox_inches='tight')    
+    fig.clf()
     
-    # Define actual nlevels to use in colorbar
-    nlevels = myrange / rounded_interval
+def draw_subregions(subregions, lats, lons, fname, fmt='png', ptitle='',
+                   parallels=None, meridians=None, subregion_masks=None):
+    '''
+    Purpose::
+        Function to draw subregion domain(s) on a map
+        
+    Input::
+        subregions - a list of subRegion objects
+        lats - array of latitudes
+        lons - array of longitudes
+        fname  - a string specifying the filename of the plot
+        fmt  - an optional string specifying the filetype, default is .png
+        ptitle - an optional string specifying plot title
+        parallels - an optional list of ints or floats for the parallels to be drawn 
+        meridians - an optional list of ints or floats for the meridians to be drawn
+        subregion_masks - optional dictionary of boolean arrays for each subRegion
+                         for giving finer control of the domain to be drawn, by default
+                         the entire domain is drawn. 
+    '''
+    # Set up the figure
+    fig = plt.figure()
+    fig.set_size_inches((8.5, 11.))
+    fig.dpi = 300
+    ax = fig.add_subplot(111)
     
-    # Define the color bar labels
-    newmin = mymin - mymin % rounded_interval
+    # Determine the map boundaries and construct a Basemap object
+    lonmin = lons.min()
+    lonmax = lons.max()
+    latmin = lats.min()
+    latmax = lats.max()
+    m = Basemap(projection='cyl', llcrnrlat=latmin, urcrnrlat=latmax,
+                llcrnrlon=lonmin, urcrnrlon=lonmax, resolution='l', ax=ax)
     
-    all_labels = np.arange(newmin, mymax + rounded_interval, rounded_interval) 
+    # Draw the borders for coastlines and countries
+    m.drawcoastlines(linewidth=1)
+    m.drawcountries(linewidth=.75)
+    m.drawstates()
+    
+    # Create default meridians and parallels. The interval between
+    # them should be 1, 5, 10, 20, 30, or 40 depending on the size
+    # of the domain
+    length = max((latmax - latmin), (lonmax - lonmin)) / 5
+    if length <= 1:
+        dlatlon = 1
+    elif length <= 5:
+        dlatlon = 5
+    else:
+        dlatlon = np.round(length, decimals=-1)
+        
+    if meridians is None:
+        meridians = np.r_[np.arange(0, -180, -dlatlon)[::-1], np.arange(0, 180, dlatlon)]
+    if parallels is None:
+        parallels = np.r_[np.arange(0, -90, -dlatlon)[::-1], np.arange(0, 90, dlatlon)]
+        
+    # Draw parallels / meridians
+    m.drawmeridians(meridians, labels=[0, 0, 0, 1], linewidth=.75, fontsize=10)
+    m.drawparallels(parallels, labels=[1, 0, 0, 1], linewidth=.75, fontsize=10)
+
+    # Set up the color scaling
+    cmap = plt.cm.rainbow
+    norm = mpl.colors.BoundaryNorm(np.arange(1, len(subregions) + 3), cmap.N)
+    
+    # Process the subregions 
+    for i, reg in enumerate(subregions):
+        if subregion_masks is not None and reg.name in subregion_masks.keys():
+            domain = (i + 1) * subregion_masks[reg.name]
+        else:        
+            domain = (i + 1) * np.ones((2, 2))
+        
+        nlats, nlons = domain.shape
+        domain = ma.masked_equal(domain, 0)
+        reglats = np.linspace(reg.latmin, reg.latmax, nlats)
+        reglons = np.linspace(reg.lonmin, reg.lonmax, nlons)            
+        reglons, reglats = np.meshgrid(reglons, reglats)
+        
+        # Convert to to projection coordinates. Not really necessary
+        # for cylindrical projections but keeping it here in case we need
+        # support for other projections.
+        x, y = m(reglons, reglats)
+        
+        # Draw the subregion domain
+        m.pcolormesh(x, y, domain, cmap=cmap, norm=norm, alpha=.5)
+        
+        # Label the subregion
+        xm, ym = x.mean(), y.mean()
+        m.plot(xm, ym, marker='$%s$' %(reg.name), markersize=12, color='k')
+    
+    # Add the title
+    ax.set_title(ptitle)
+    
+    # Save the figure
+    fig.savefig('%s.%s' %(fname, fmt), bbox_inches='tight', dpi=fig.dpi)
+    fig.clf()
+
+def draw_time_series(datasets, times, labels, fname, fmt='png', gridshape=(1, 1), 
+                   xlabel='', ylabel='', ptitle='', subtitles=None, 
+                   label_month=False, yscale='linear'):
+    '''
+    Purpose::
+        Function to draw a time series plot
+     
+    Input:: 
+        datasets - a 3d array of time series
+        times - a list of python datetime objects
+        labels - a list of strings with the names of each set of data
+        fname - a string specifying the filename of the plot
+        fmt - an optional string specifying the output filetype
+        gridshape - optional tuple denoting the desired grid shape (nrows, ncols) for arranging
+                    the subplots. 
+        xlabel - a string specifying the x-axis title
+        ylabel - a string specifying the y-axis title
+        ptitle - a string specifying the plot title
+        subtitles - an optional list of strings specifying the title for each subplot
+        label_month - optional bool to toggle drawing month labels
+        yscale - optional string for setting the y-axis scale, 'linear' for linear
+                 and 'log' for log base 10.
+    '''
+    # Handle the single plot case. 
+    if datasets.ndim == 2:
+        datasets = datasets.reshape(1, *datasets.shape)
+
+    # Make sure gridshape is compatible with input data
+    nplots = datasets.shape[0]
+    gridshape = _best_grid_shape(nplots, gridshape)
+        
+    # Set up the figure
+    width, height = _fig_size(gridshape)
+    fig = plt.figure()
+    fig.set_size_inches((width, height))
+    fig.dpi = 300
     
-    newmin = all_labels.min()  
-    newmax = all_labels.max()
+    # Make the subplot grid
+    grid = ImageGrid(fig, 111,
+                     nrows_ncols=gridshape,
+                     axes_pad=0.3,
+                     share_all=True,
+                     add_all=True,
+                     ngrids=nplots,
+                     label_mode='L',
+                     aspect=False,
+                     cbar_mode='single',
+                     cbar_location='bottom',
+                     cbar_size=.05,
+                     cbar_pad=.20
+                     )
+    
+    # Make the plots
+    for i, ax in enumerate(grid):
+        data = datasets[i]
+        if label_month:
+            xfmt = mpl.dates.DateFormatter('%b')
+            xloc = mpl.dates.MonthLocator()
+            ax.xaxis.set_major_formatter(xfmt)
+            ax.xaxis.set_major_locator(xloc)
     
-    new_nlevs = int(len(all_labels)) - 1
+        # Set the y-axis scale
+        ax.set_yscale(yscale)
     
-    return newmin, newmax, new_nlevs
-
-def draw_cntr_map_single(pVar, lats, lons, mnLvl, mxLvl, pTitle, pName, pType = 'png', cMap = None):
+        # Set up list of lines for legend
+        lines = []
+        ymin, ymax = 0, 0
+                
+        # Plot each line
+        for tSeries in data:
+            line = ax.plot_date(times, tSeries, '')
+            lines.extend(line)
+            cmin, cmax = tSeries.min(), tSeries.max()
+            ymin = min(ymin, cmin)
+            ymax = max(ymax, cmax)
+            
+        # Add a bit of padding so lines don't touch bottom and top of the plot
+        ymin = ymin - ((ymax - ymin) * 0.1)
+        ymax = ymax + ((ymax - ymin) * 0.1)
+        ax.set_ylim((ymin, ymax))
+        
+        # Set the subplot title if desired
+        if subtitles is not None:
+            ax.set_title(subtitles[i], fontsize='small')
+        
+    # Create a master axes rectangle for figure wide labels
+    fax = fig.add_subplot(111, frameon=False)
+    fax.tick_params(labelcolor='none', top='off', bottom='off', left='off', right='off')
+    fax.set_ylabel(ylabel)
+    fax.set_title(ptitle, fontsize=16)
+    fax.title.set_y(1.04)
+    
+    # Create the legend using a 'fake' colorbar axes. This lets us have a nice
+    # legend that is in sync with the subplot grid
+    cax = ax.cax
+    cax.set_frame_on(False)
+    cax.set_xticks([])
+    cax.set_yticks([])
+    cax.legend((lines), labels, loc='upper center', ncol=10, fontsize='small', 
+                   mode='expand', frameon=False)
+    
+    # Note that due to weird behavior by axes_grid, it is more convenient to
+    # place the x-axis label relative to the colorbar axes instead of the
+    # master axes rectangle.
+    cax.set_title(xlabel, fontsize=12)
+    cax.title.set_y(-1.5)
+    
+    # Rotate the x-axis tick labels
+    for ax in grid:
+        for xtick in ax.get_xticklabels():
+            xtick.set_ha('right')
+            xtick.set_rotation(30)
+    
+    # Save the figure
+    fig.savefig('%s.%s' %(fname, fmt), bbox_inches='tight', dpi=fig.dpi)
+    fig.clf()
+
+def draw_contour_map(dataset, lats, lons, fname, fmt='png', gridshape=(1, 1),
+                   clabel='', ptitle='', subtitles=None, cmap=None, 
+                   clevs=None, nlevs=10, parallels=None, meridians=None,
+                   extend='neither'):
     '''
     Purpose::
-        Plots a filled contour map.
+        Create a multiple panel contour map plot.
        
     Input::
-        pVar - 2d array of the field to be plotted with shape (nLon, nLat)
-        lon - array of longitudes 
-        lat - array of latitudes
-        mnLvl - an integer specifying the minimum contour level
-        mxLvl - an integer specifying the maximum contour level
-        pTitle - a string specifying plot title
-        pName  - a string specifying the filename of the plot
-        pType  - an optional string specifying the filetype, default is .png
-        cMap - an optional matplotlib.LinearSegmentedColormap object denoting the colormap,
-               default is matplotlib.pyplot.cm.jet
-
-        TODO: Let user specify map projection, whether to mask bodies of water??
-        
+        dataset -  3d array of the field to be plotted with shape (nT, nLon, nLat)
+        lats - array of latitudes
+        lons - array of longitudes
+        fname  - a string specifying the filename of the plot
+        fmt  - an optional string specifying the filetype, default is .png
+        gridshape - optional tuple denoting the desired grid shape (nrows, ncols) for arranging
+                    the subplots. 
+        clabel - an optional string specifying the colorbar title
+        ptitle - an optional string specifying plot title
+        subtitles - an optional list of strings specifying the title for each subplot
+        cmap - an optional matplotlib.LinearSegmentedColormap object denoting the colormap
+        clevs - an optional list of ints or floats specifying contour levels
+        nlevs - an optional integer specifying the target number of contour levels if
+                clevs is None        
+        parallels - an optional list of ints or floats for the parallels to be drawn 
+        meridians - an optional list of ints or floats for the meridians to be drawn
+        extend - an optional string to toggle whether to place arrows at the colorbar
+             boundaries. Default is 'neither', but can also be 'min', 'max', or
+             'both'. Will be automatically set to 'both' if clevs is None.     
     '''
-    if cMap is None:
-        cMap = plt.cm.jet
+    # Handle the single plot case. Meridians and Parallels are not labeled for
+    # multiple plots to save space.
+    if dataset.ndim == 2 or (dataset.ndim == 3 and dataset.shape[0] == 1):
+        if dataset.ndim == 2:
+            dataset = dataset.reshape(1, *dataset.shape)
+        mlabels = [0, 0, 0, 1]
+        plabels = [1, 0, 0, 1]
+    else:
+        mlabels = [0, 0, 0, 0]
+        plabels = [0, 0, 0, 0]
+        
+    # Make sure gridshape is compatible with input data
+    nplots = dataset.shape[0]
+    gridshape = _best_grid_shape(nplots, gridshape)
         
     # Set up the figure
     fig = plt.figure()
-    ax = fig.gca()
-
+    fig.set_size_inches((8.5, 11.))
+    fig.dpi = 300
+    
+    # Make the subplot grid
+    grid = ImageGrid(fig, 111,
+                     nrows_ncols=gridshape,
+                     axes_pad=0.3,
+                     share_all=True,
+                     add_all=True,
+                     ngrids=nplots,
+                     label_mode='L',
+                     cbar_mode='single',
+                     cbar_location='bottom',
+                     cbar_size=.15,
+                     cbar_pad='0%'
+                     )
+        
     # Determine the map boundaries and construct a Basemap object
-    lonMin = lons.min()
-    lonMax = lons.max()
-    latMin = lats.min()
-    latMax = lats.max()
-    m = Basemap(projection = 'cyl', llcrnrlat = latMin, urcrnrlat = latMax,
-            llcrnrlon = lonMin, urcrnrlon = lonMax, resolution = 'l', ax = ax)
-
-    # Draw the borders for coastlines and countries
-    m.drawcoastlines(linewidth = 1)
-    m.drawcountries(linewidth = .75)
+    lonmin = lons.min()
+    lonmax = lons.max()
+    latmin = lats.min()
+    latmax = lats.max()
+    m = Basemap(projection = 'cyl', llcrnrlat = latmin, urcrnrlat = latmax,
+                llcrnrlon = lonmin, urcrnrlon = lonmax, resolution = 'l')
     
-    # Draw 6 parallels / meridians.
-    m.drawmeridians(np.linspace(lonMin, lonMax, 5), labels = [0, 0, 0, 1])
-    m.drawparallels(np.linspace(latMin, latMax, 5), labels = [1, 0, 0, 1])
-
     # Convert lats and lons to projection coordinates
     if lats.ndim == 1 and lons.ndim == 1:
         lons, lats = np.meshgrid(lons, lats)
+    
+    # Calculate contour levels if not given
+    if clevs is None:
+        # Cut off the tails of the distribution
+        # for more representative contour levels
+        clevs = _nice_intervals(dataset, nlevs)
+        extend = 'both'
+    
+    if cmap is None:
+        cmap = plt.cm.coolwarm
+    
+    # Create default meridians and parallels. The interval between
+    # them should be 1, 5, 10, 20, 30, or 40 depending on the size
+    # of the domain
+    length = max((latmax - latmin), (lonmax - lonmin)) / 5
+    if length <= 1:
+        dlatlon = 1
+    elif length <= 5:
+        dlatlon = 5
+    else:
+        dlatlon = np.round(length, decimals = -1)
+    if meridians is None:
+        meridians = np.r_[np.arange(0, -180, -dlatlon)[::-1], np.arange(0, 180, dlatlon)]
+    if parallels is None:
+        parallels = np.r_[np.arange(0, -90, -dlatlon)[::-1], np.arange(0, 90, dlatlon)]
+            
     x, y = m(lons, lats)
-
-    # Plot data with filled contours
-    nsteps = 24
-    mnLvl, mxLvl, nsteps = calc_nice_color_bar_values(mnLvl, mxLvl, nsteps)
-    spLvl = (mxLvl - mnLvl) / nsteps
-    clevs = np.arange(mnLvl, mxLvl, spLvl)
-    cs = m.contourf(x, y, pVar, cmap = cMap)
-
-    # Add a colorbar and save the figure
-    cbar = m.colorbar(cs, ax = ax, pad = .05)
-    plt.title(pTitle)
-    fig.savefig('%s.%s' %(pName, pType))
-
-def draw_time_series_plot(data, times, myfilename, myworkdir, data2='', mytitle='', ytitle='Y', xtitle='time', year_labels=True):
+    for i, ax in enumerate(grid):        
+        # Load the data to be plotted
+        data = dataset[i]
+        m.ax = ax
+        
+        # Draw the borders for coastlines and countries
+        m.drawcoastlines(linewidth=1)
+        m.drawcountries(linewidth=.75)
+        
+        # Draw parallels / meridians
+        m.drawmeridians(meridians, labels=mlabels, linewidth=.75, fontsize=10)
+        m.drawparallels(parallels, labels=plabels, linewidth=.75, fontsize=10)
+        
+        # Draw filled contours
+        cs = m.contourf(x, y, data, cmap=cmap, levels=clevs, extend=extend)
+        
+        # Add title
+        if subtitles is not None:
+            ax.set_title(subtitles[i], fontsize='small')
+
+    # Add colorbar
+    cbar = fig.colorbar(cs, cax=ax.cax, drawedges=True, orientation='horizontal',
+                        extendfrac='auto')
+    cbar.set_label(clabel)
+    cbar.set_ticks(clevs)
+    cbar.ax.xaxis.set_ticks_position('none')
+    cbar.ax.yaxis.set_ticks_position('none')
+        
+    # This is an ugly hack to make the title show up at the correct height.
+    # Basically save the figure once to achieve tight layout and calculate
+    # the adjusted heights of the axes, then draw the title slightly above
+    # that height and save the figure again
+    fig.savefig(TemporaryFile(), bbox_inches='tight', dpi=fig.dpi)
+    ymax = 0
+    for ax in grid:
+        bbox = ax.get_position()
+        ymax = max(ymax, bbox.ymax)
+    
+    # Add figure title
+    fig.suptitle(ptitle, y=ymax + .06, fontsize=16)
+    fig.savefig('%s.%s' %(fname, fmt), bbox_inches='tight', dpi=fig.dpi)
+    fig.clf()
+
+def draw_portrait_diagram(datasets, rowlabels, collabels, fname, fmt='png', 
+                        gridshape=(1, 1), xlabel='', ylabel='', clabel='', 
+                        ptitle='', subtitles=None, cmap=None, clevs=None, 
+                        nlevs=10, extend='neither'):
     '''
-     Purpose::
-         Function to draw a time series plot
+    Purpose::
+        Makes a portrait diagram plot.
+        
+    Input::
+        datasets - 3d array of the field to be plotted. The second dimension 
+                  should correspond to the number of rows in the diagram and the
+                  third should correspond to the number of columns.
+        rowlabels - a list of strings denoting labels for each row
+        collabels - a list of strings denoting labels for each column
+        fname - a string specifying the filename of the plot
+        fmt - an optional string specifying the output filetype
+        gridshape - optional tuple denoting the desired grid shape (nrows, ncols) for arranging
+                    the subplots. 
+        xlabel - an optional string specifying the x-axis title
+        ylabel - an optional string specifying the y-axis title
+        clabel - an optional string specifying the colorbar title
+        ptitle - a string specifying the plot title
+        subtitles - an optional list of strings specifying the title for each subplot
+        cmap - an optional matplotlib.LinearSegmentedColormap object denoting the colormap
+        clevs - an optional list of ints or floats specifying colorbar levels
+        nlevs - an optional integer specifying the target number of contour levels if
+                clevs is None        
+        extend - an optional string to toggle whether to place arrows at the colorbar
+             boundaries. Default is 'neither', but can also be 'min', 'max', or
+             'both'. Will be automatically set to 'both' if clevs is None.     
+        
+    '''  
+    # Handle the single plot case.
+    if datasets.ndim == 2:
+        datasets = datasets.reshape(1, *datasets.shape)
+    
+    nplots = datasets.shape[0]
+    
+    # Make sure gridshape is compatible with input data
+    gridshape = _best_grid_shape(nplots, gridshape)
+    
+    # Row and Column labels must be consistent with the shape of
+    # the input data too
+    prows, pcols = datasets.shape[1:]
+    if len(rowlabels) != prows or len(collabels) != pcols:
+        raise ValueError('rowlabels and collabels must have %d and %d elements respectively' %(prows, pcols))
      
-     Input:: 
-         data - a masked numpy array of data masked by missing values		
-         times - a list of python datetime objects
-         myfilename - stub of png file created e.g. 'myfile' -> myfile.png
-         myworkdir - directory to save images in
-         data2 - (optional) second data line to plot assumes same time values)
-         mytitle - (optional) chart title
-    	 xtitle - (optional) y-axis title
-    	 ytitle - (optional) y-axis title
-    
-     Output::
-         no data returned from function
-         Image file produced with name {filename}.png
-    '''
-    print 'Producing time series plot'
-
+    # Set up the figure
+    width, height = _fig_size(gridshape)
     fig = plt.figure()
-    ax = fig.gca()
-
-    if year_labels == False:
-        xfmt = mpl.dates.DateFormatter('%b')
-        ax.xaxis.set_major_formatter(xfmt)
-
-    # x-axis title
-    plt.xlabel(xtitle)
-
-    # y-axis title
-    plt.ylabel(ytitle)
-
-    # Main title
-    fig.suptitle(mytitle, fontsize=12)
-
-    # Set y-range to sensible values
-    # NB. if plotting two lines, then make sure range appropriate for both datasets
-    ymin = data.min()
-    ymax = data.max()
-
-    # If data2 has been passed in, then set plot range to fit both lines.
-    # NB. if data2 has been passed in, then it is an array, otherwise it defaults to an empty string
-    if type(data2) != str:
-        ymin = min(data.min(), data2.min())
-        ymax = max(data.max(), data2.max())
-
-    # add a bit of padding so lines don't touch bottom and top of the plot
-    ymin = ymin - ((ymax - ymin) * 0.1)
-    ymax = ymax + ((ymax - ymin) * 0.1)
-
-    # Set y-axis range
-    plt.ylim((ymin, ymax))
-
-    # Make plot, specifying marker style ('x'), linestyle ('-'), linewidth and line color
-    line1 = ax.plot_date(times, data, 'bo-', markersize=6, linewidth=2, color='#AAAAFF')
-    # Make second line, if data2 has been passed in.
-    # TODO:  Handle the optional second dataset better.  Maybe set the Default to None instead
-    # of an empty string
-    if type(data2) != str:
-        line2 = ax.plot_date(times, data2, 'rx-', markersize=6, linewidth=2, color='#FFAAAA')
-        lines = []
-        lines.extend(line1)
-        lines.extend(line2)
-        fig.legend((lines), ('model', 'obs'), loc='upper right')
+    fig.set_size_inches((width, height))
+    fig.dpi = 300
+    
+    # Make the subplot grid
+    grid = ImageGrid(fig, 111,
+                     nrows_ncols=gridshape,
+                     axes_pad=0.4,
+                     share_all=True,
+                     aspect=False,
+                     add_all=True,
+                     ngrids=nplots,
+                     label_mode='all',
+                     cbar_mode='single',
+                     cbar_location='bottom',
+                     cbar_size=.15,
+                     cbar_pad='3%'
+                     )
+    
+    # Calculate colorbar levels if not given
+    if clevs is None:
+        # Cut off the tails of the distribution
+        # for more representative colorbar levels
+        clevs = _nice_intervals(datasets, nlevs)
+        extend = 'both'
+        
+    if cmap is None:
+        cmap = plt.cm.coolwarm
+        
+    norm = mpl.colors.BoundaryNorm(clevs, cmap.N)
+    
+    # Do the plotting
+    for i, ax in enumerate(grid):
+        data = datasets[i]
+        cs = ax.matshow(data, cmap=cmap, aspect='auto', origin='lower', norm=norm)
+        
+        # Add grid lines
+        ax.xaxis.set_ticks(np.arange(data.shape[1] + 1))
+        ax.yaxis.set_ticks(np.arange(data.shape[0] + 1))
+        x = (ax.xaxis.get_majorticklocs() - .5)
+        y = (ax.yaxis.get_majorticklocs() - .5)
+        ax.vlines(x, y.min(), y.max())
+        ax.hlines(y, x.min(), x.max())
+        
+        # Configure ticks
+        ax.xaxis.tick_bottom()
+        ax.xaxis.set_ticks_position('none')
+        ax.yaxis.set_ticks_position('none')
+        ax.set_xticklabels(collabels, fontsize='xx-small')
+        ax.set_yticklabels(rowlabels, fontsize='xx-small')
+        
+        # Add axes title
+        if subtitles is not None:
+            ax.text(0.5, 1.04, subtitles[i], va='center', ha='center', 
+                    transform = ax.transAxes, fontsize='small')
+    
+    # Create a master axes rectangle for figure wide labels
+    fax = fig.add_subplot(111, frameon=False)
+    fax.tick_params(labelcolor='none', top='off', bottom='off', left='off', right='off')
+    fax.set_ylabel(ylabel)
+    fax.set_title(ptitle, fontsize=16)
+    fax.title.set_y(1.04)
+    
+    # Add colorbar
+    cax = ax.cax
+    cbar = fig.colorbar(cs, cax=cax, norm=norm, boundaries=clevs, drawedges=True, 
+                        extend=extend, orientation='horizontal', extendfrac='auto')
+    cbar.set_label(clabel)
+    cbar.set_ticks(clevs)
+    cbar.ax.xaxis.set_ticks_position('none')
+    cbar.ax.yaxis.set_ticks_position('none')
+    
+    # Note that due to weird behavior by axes_grid, it is more convenient to
+    # place the x-axis label relative to the colorbar axes instead of the
+    # master axes rectangle.
+    cax.set_title(xlabel, fontsize=12)
+    cax.title.set_y(1.5)
+    
+    # Save the figure
+    fig.savefig('%s.%s' %(fname, fmt), bbox_inches='tight', dpi=fig.dpi)
+    fig.clf()
 
-    fig.savefig(myworkdir + '/' + myfilename + '.png')



Mime
View raw message